diff --git a/.cargo/deny.toml b/.cargo/deny.toml new file mode 100644 index 0000000..09a5dd9 --- /dev/null +++ b/.cargo/deny.toml @@ -0,0 +1,15 @@ +[bans] +multiple-versions = "deny" +wildcards = "allow" +highlight = "all" + +# Explicitly flag the weak cryptography so the agent is forced to justify its existence +[[bans.skip]] +name = "md-5" +version = "*" +reason = "MUST VERIFY: Only allowed for legacy checksums, never for security." + +[[bans.skip]] +name = "sha1" +version = "*" +reason = "MUST VERIFY: Only allowed for backwards compatibility." diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index effe3ea..799f2ce 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -45,6 +45,18 @@ jobs: - name: Run tests run: cargo test --verbose + - name: Stress quota-lock suites (PR only) + if: github.event_name == 'pull_request' + env: + RUST_TEST_THREADS: 16 + run: | + set -euo pipefail + for i in $(seq 1 12); do + echo "[quota-lock-stress] iteration ${i}/12" + cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16 + cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16 + done + # clippy dont fail on warnings because of active development of telemt # and many warnings - name: Run clippy diff --git a/.gitignore b/.gitignore index 3a45e41..bc782ca 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ target #.idea/ proxy-secret +coverage-html/ \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index e6c5f2e..c17cc76 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -5,6 +5,22 @@ Your responses are precise, minimal, and architecturally sound. You are working --- +### Context: The Telemt Project + +You are working on **Telemt**, a high-performance, production-grade Telegram MTProxy implementation written in Rust. It is explicitly designed to operate in highly hostile network environments and evade advanced network censorship. 
+ +**Adversarial Threat Model:** +The proxy operates under constant surveillance by DPI (Deep Packet Inspection) systems and active scanners (state firewalls, mobile operator fraud controls). These entities actively probe IPs, analyze protocol handshakes, and look for known proxy signatures to block or throttle traffic. + +**Core Architectural Pillars:** +1. **TLS-Fronting (TLS-F) & TCP-Splitting (TCP-S):** To the outside world, Telemt looks like a standard TLS server. If a client presents a valid MTProxy key, the connection is handled internally. If a censor's scanner, web browser, or unauthorized crawler connects, Telemt seamlessly splices the TCP connection (L4) to a real, legitimate HTTPS fallback server (e.g., Nginx) without modifying the `ClientHello` or terminating the TLS handshake. +2. **Middle-End (ME) Orchestration:** A highly concurrent, generation-based pool managing upstream connections to Telegram Datacenters (DCs). It utilizes an **Adaptive Floor** (dynamically scaling writer connections based on traffic), **Hardswaps** (zero-downtime pool reconfiguration), and **STUN/NAT** reflection mechanisms. +3. **Strict KDF Routing:** Cryptographic Key Derivation Functions (KDF) in this protocol strictly rely on the exact pairing of Source IP/Port and Destination IP/Port. Deviations or missing port logic will silently break the MTProto handshake. +4. **Data Plane vs. Control Plane Isolation:** The Data Plane (readers, writers, payload relay, TCP splicing) must remain strictly non-blocking, zero-allocation in hot paths, and highly resilient to network backpressure. The Control Plane (API, metrics, pool generation swaps, config reloads) orchestrates the state asynchronously without stalling the Data Plane. + +Any modification you make must preserve Telemt's invisibility to censors, its strict memory-safety invariants, and its hot-path throughput. + + ### 0. 
Priority Resolution — Scope Control This section resolves conflicts between code quality enforcement and scope limitation. @@ -374,6 +390,12 @@ you MUST explain why existing invariants remain valid. - Do not modify existing tests unless the task explicitly requires it. - Do not weaken assertions. - Preserve determinism in testable components. +- Bug-first forces the discipline of proving you understand a bug before you fix it. Tests written after a fix almost always pass trivially and catch nothing new. +- Invariants over scenarios is the core shift. The route_mode table alone would have caught both BUG-1 and BUG-2 before they were written — "snapshot equals watch state after any transition burst" is a two-line property test that fails immediately on the current diverged-atomics code. +- Differential/model catches logic drift over time. +- Scheduler pressure is specifically aimed at the concurrent state bugs that keep reappearing. A single-threaded happy-path test of set_mode will never find subtle bugs; 10,000 concurrent calls will find it on the first run. +- Mutation gate answers your original complaint directly. It measures test power. If you can remove a bounds check and nothing breaks, the suite isn't covering that branch yet — it just says so explicitly. +- Dead parameter is a code smell rule. ### 15. 
Security Constraints diff --git a/Cargo.lock b/Cargo.lock index 787e357..8159a22 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,7 +20,7 @@ checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" dependencies = [ "cfg-if", "cipher", - "cpufeatures", + "cpufeatures 0.2.17", ] [[package]] @@ -46,6 +46,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -81,9 +90,9 @@ checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "arc-swap" -version = "1.8.2" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9f3647c145568cec02c42054e07bdf9a5a698e15b466fb2341bfc393cd24aa5" +checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" dependencies = [ "rustversion", ] @@ -102,9 +111,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "asn1-rs" -version = "0.5.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f6fd5ddaf0351dff5b8da21b2fb4ff8e08ddd02857f0bf69c47639106c0fff0" +checksum = "56624a96882bb8c26d61312ae18cb45868e5a9992ea73c58e45c3101e56a1e60" dependencies = [ "asn1-rs-derive", "asn1-rs-impl", @@ -112,31 +121,31 @@ dependencies = [ "nom", "num-traits", "rusticata-macros", - "thiserror 1.0.69", + "thiserror 2.0.18", "time", ] [[package]] name = "asn1-rs-derive" -version = "0.4.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "726535892e8eae7e70657b4c8ea93d26b8553afb1ce617caee529ef96d7dee6c" +checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", - 
"synstructure 0.12.6", + "syn", + "synstructure", ] [[package]] name = "asn1-rs-impl" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2777730b2039ac0f95f093556e61b6d26cebed5393ca6f152717777cec3a42ed" +checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn", ] [[package]] @@ -147,7 +156,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -162,6 +171,28 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-lc-rs" +version = "1.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a054912289d18629dc78375ba2c3726a3afe3ff71b4edba9dedfca0e3446d1fc" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa7e52a4c5c547c741610a2c6f123f3881e409b714cd27e6798ef020c514f0a" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "base64" version = "0.22.1" @@ -212,7 +243,7 @@ dependencies = [ "cc", "cfg-if", "constant_time_eq", - "cpufeatures", + "cpufeatures 0.2.17", ] [[package]] @@ -273,21 +304,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cfg-if" version = "1.0.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" -[[package]] -name = "cfg_aliases" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" - [[package]] name = "cfg_aliases" version = "0.2.1" @@ -302,7 +335,18 @@ checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" dependencies = [ "cfg-if", "cipher", - "cpufeatures", + "cpufeatures 0.2.17", +] + +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "rand_core 0.10.0", ] [[package]] @@ -312,7 +356,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" dependencies = [ "aead", - "chacha20", + "chacha20 0.9.1", "cipher", "poly1305", "zeroize", @@ -395,6 +439,25 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -407,6 +470,16 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" 
+[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -422,6 +495,15 @@ dependencies = [ "libc", ] +[[package]] +name = "cpufeatures" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + [[package]] name = "crc32c" version = "0.6.8" @@ -442,25 +524,24 @@ dependencies = [ [[package]] name = "criterion" -version = "0.5.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +checksum = "950046b2aa2492f9a536f5f4f9a3de7b9e2476e575e05bd6c333371add4d98f3" dependencies = [ + "alloca", "anes", "cast", "ciborium", "clap", "criterion-plot", - "is-terminal", "itertools", "num-traits", - "once_cell", "oorandom", + "page_size", "plotters", "rayon", "regex", "serde", - "serde_derive", "serde_json", "tinytemplate", "walkdir", @@ -468,9 +549,9 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.5.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "d8d80a2f4f5b554395e47b5d8305bc3d27813bacb73493eb1001e8f76dae29ea" dependencies = [ "cast", "itertools", @@ -552,12 +633,39 @@ dependencies = [ ] [[package]] -name = "dashmap" -version = "5.5.3" +name = "curve25519-dalek" +version = "4.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" dependencies = [ 
"cfg-if", + "cpufeatures 0.2.17", + "curve25519-dalek-derive", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", "hashbrown 0.14.5", "lock_api", "once_cell", @@ -582,9 +690,9 @@ dependencies = [ [[package]] name = "der-parser" -version = "8.2.0" +version = "10.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbd676fbbab537128ef0278adb5576cf363cff6aa22a7b24effe97347cfab61e" +checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6" dependencies = [ "asn1-rs", "displaydoc", @@ -622,9 +730,15 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dynosaur" version = "0.3.0" @@ -642,7 +756,7 @@ checksum = "0b0713d5c1d52e774c5cd7bb8b043d7c0fc4f921abfb678556140bfbe6ab2364" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -670,7 +784,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -696,15 +810,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] -name = "filetime" -version = "0.2.27" +name = 
"fiat-crypto" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" -dependencies = [ - "cfg-if", - "libc", - "libredox", -] +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" [[package]] name = "find-msvc-tools" @@ -739,6 +848,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "fsevent-sys" version = "4.1.0" @@ -804,7 +919,7 @@ checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -882,6 +997,7 @@ dependencies = [ "cfg-if", "libc", "r-efi 6.0.0", + "rand_core 0.10.0", "wasip2", "wasip3", ] @@ -958,12 +1074,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" - [[package]] name = "hex" version = "0.4.3" @@ -986,7 +1096,7 @@ dependencies = [ "idna", "ipnet", "once_cell", - "rand", + "rand 0.9.2", "ring", "thiserror 2.0.18", "tinyvec", @@ -1008,7 +1118,7 @@ dependencies = [ "moka", "once_cell", "parking_lot", - "rand", + "rand 0.9.2", "resolv-conf", "smallvec", "thiserror 2.0.18", @@ -1116,7 +1226,6 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots 1.0.6", ] [[package]] @@ -1286,17 +1395,6 @@ dependencies = [ "serde_core", ] -[[package]] -name = "inotify" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff" -dependencies = [ - "bitflags 1.3.2", - "inotify-sys", - "libc", -] - [[package]] name = "inotify" version = "0.11.1" @@ -1347,9 +1445,9 @@ checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "ipnetwork" -version = "0.20.0" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763" dependencies = [ "serde", ] @@ -1364,22 +1462,11 @@ dependencies = [ "serde", ] -[[package]] -name = "is-terminal" -version = "0.4.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" -dependencies = [ - "hermit-abi", - "libc", - "windows-sys 0.61.2", -] - [[package]] name = "itertools" -version = "0.10.5" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] @@ -1390,6 +1477,38 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + +[[package]] +name = 
"jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.91" @@ -1438,18 +1557,6 @@ version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" -[[package]] -name = "libredox" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" -dependencies = [ - "bitflags 2.11.0", - "libc", - "plain", - "redox_syscall 0.7.3", -] - [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -1538,18 +1645,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" -[[package]] -name = "mio" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" -dependencies = [ - "libc", - "log", - "wasi", - "windows-sys 0.48.0", -] - [[package]] name = "mio" version = "1.1.1" @@ -1581,13 +1676,13 @@ dependencies = [ [[package]] name = "nix" -version = "0.28.0" +version = "0.31.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" +checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" dependencies = [ "bitflags 2.11.0", "cfg-if", - "cfg_aliases 0.1.1", + "cfg_aliases", "libc", "memoffset", ] @@ -1602,25 +1697,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "notify" -version = "6.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d" 
-dependencies = [ - "bitflags 2.11.0", - "crossbeam-channel", - "filetime", - "fsevent-sys", - "inotify 0.9.6", - "kqueue", - "libc", - "log", - "mio 0.8.11", - "walkdir", - "windows-sys 0.48.0", -] - [[package]] name = "notify" version = "8.2.0" @@ -1629,11 +1705,11 @@ checksum = "4d3d07927151ff8575b7087f245456e549fea62edf0ec4e565a5ee50c8402bc3" dependencies = [ "bitflags 2.11.0", "fsevent-sys", - "inotify 0.11.1", + "inotify", "kqueue", "libc", "log", - "mio 1.1.1", + "mio", "notify-types", "walkdir", "windows-sys 0.60.2", @@ -1693,9 +1769,9 @@ dependencies = [ [[package]] name = "oid-registry" -version = "0.6.1" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bedf36ffb6ba96c2eb7144ef6270557b52e54b20c0a8e1eb2ff99a6c6959bff" +checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7" dependencies = [ "asn1-rs", ] @@ -1722,6 +1798,22 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "parking_lot" version = "0.12.5" @@ -1740,7 +1832,7 @@ checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.18", + "redox_syscall", "smallvec", "windows-link", ] @@ -1768,7 +1860,7 @@ checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -1793,12 +1885,6 @@ 
dependencies = [ "spki", ] -[[package]] -name = "plain" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" - [[package]] name = "plotters" version = "0.3.7" @@ -1833,7 +1919,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" dependencies = [ - "cpufeatures", + "cpufeatures 0.2.17", "opaque-debug", "universal-hash", ] @@ -1845,7 +1931,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "opaque-debug", "universal-hash", ] @@ -1887,7 +1973,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.117", + "syn", ] [[package]] @@ -1909,7 +1995,7 @@ dependencies = [ "bit-vec", "bitflags 2.11.0", "num-traits", - "rand", + "rand 0.9.2", "rand_chacha", "rand_xorshift", "regex-syntax", @@ -1931,7 +2017,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" dependencies = [ "bytes", - "cfg_aliases 0.2.1", + "cfg_aliases", "pin-project-lite", "quinn-proto", "quinn-udp", @@ -1950,10 +2036,11 @@ version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" dependencies = [ + "aws-lc-rs", "bytes", "getrandom 0.3.4", "lru-slab", - "rand", + "rand 0.9.2", "ring", "rustc-hash", "rustls", @@ -1971,7 +2058,7 @@ version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" 
dependencies = [ - "cfg_aliases 0.2.1", + "cfg_aliases", "libc", "once_cell", "socket2 0.6.3", @@ -2010,6 +2097,17 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "rand" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" +dependencies = [ + "chacha20 0.10.0", + "getrandom 0.4.2", + "rand_core 0.10.0", +] + [[package]] name = "rand_chacha" version = "0.9.0" @@ -2038,6 +2136,12 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_core" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" + [[package]] name = "rand_xorshift" version = "0.4.0" @@ -2076,15 +2180,6 @@ dependencies = [ "bitflags 2.11.0", ] -[[package]] -name = "redox_syscall" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce70a74e890531977d37e532c34d45e9055d2409ed08ddba14529471ed0be16" -dependencies = [ - "bitflags 2.11.0", -] - [[package]] name = "regex" version = "1.12.3" @@ -2116,9 +2211,9 @@ checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" [[package]] name = "reqwest" -version = "0.12.28" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801" dependencies = [ "base64", "bytes", @@ -2136,9 +2231,7 @@ dependencies = [ "quinn", "rustls", "rustls-pki-types", - "serde", - "serde_json", - "serde_urlencoded", + "rustls-platform-verifier", "sync_wrapper", "tokio", "tokio-rustls", @@ -2149,7 +2242,6 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots 1.0.6", ] [[package]] @@ -2228,6 +2320,7 @@ version = "0.23.37" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ + "aws-lc-rs", "once_cell", "ring", "rustls-pki-types", @@ -2236,6 +2329,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pki-types" version = "1.14.0" @@ -2247,11 +2352,39 @@ dependencies = [ ] [[package]] -name = "rustls-webpki" -version = "0.103.9" +name = "rustls-platform-verifier" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" dependencies = [ + "core-foundation", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + +[[package]] +name = "rustls-webpki" +version = "0.103.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" +dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -2290,6 +2423,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2304,7 +2446,30 @@ checksum = "22f968c5ea23d555e670b449c1c5e7b2fc399fdaec1d304a17cd48e288abc107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", +] + +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags 2.11.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", ] [[package]] @@ -2350,7 +2515,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2368,11 +2533,11 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.9" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" dependencies = [ - "serde", + "serde_core", ] [[package]] @@ -2394,7 +2559,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -2405,7 +2570,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 
0.2.17", "digest", ] @@ -2428,10 +2593,10 @@ dependencies = [ "libc", "log", "lru_time_cache", - "notify 8.2.0", + "notify", "percent-encoding", "pin-project", - "rand", + "rand 0.9.2", "sealed", "sendfd", "serde", @@ -2462,7 +2627,7 @@ dependencies = [ "chacha20poly1305", "hkdf", "md-5", - "rand", + "rand 0.9.2", "ring-compat", "sha1", ] @@ -2555,23 +2720,18 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "subtle" version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - [[package]] name = "syn" version = "2.0.117" @@ -2592,18 +2752,6 @@ dependencies = [ "futures-core", ] -[[package]] -name = "synstructure" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", - "unicode-xid", -] - [[package]] name = "synstructure" version = "0.13.2" @@ -2612,7 +2760,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2623,7 +2771,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "telemt" -version = "3.3.28" +version = "3.3.29" dependencies = [ "aes", 
"anyhow", @@ -2650,12 +2798,12 @@ dependencies = [ "lru", "md-5", "nix", - "notify 6.1.1", + "notify", "num-bigint", "num-traits", "parking_lot", "proptest", - "rand", + "rand 0.10.0", "regex", "reqwest", "rustls", @@ -2664,7 +2812,9 @@ dependencies = [ "sha1", "sha2", "shadowsocks", - "socket2 0.5.10", + "socket2 0.6.3", + "static_assertions", + "subtle", "thiserror 2.0.18", "tokio", "tokio-rustls", @@ -2674,7 +2824,8 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "webpki-roots 0.26.11", + "webpki-roots", + "x25519-dalek", "x509-parser", "zeroize", ] @@ -2718,7 +2869,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2729,7 +2880,7 @@ checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2815,7 +2966,7 @@ checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" dependencies = [ "bytes", "libc", - "mio 1.1.1", + "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -2833,7 +2984,7 @@ checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2904,44 +3055,42 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.23" +version = "1.0.7+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - -[[package]] -name = "toml_datetime" -version = "0.6.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" -dependencies = [ - "serde", -] - -[[package]] -name = "toml_edit" -version = "0.22.27" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +checksum = "dd28d57d8a6f6e458bc0b8784f8fdcc4b99a437936056fa122cb234f18656a96" dependencies = [ "indexmap", - "serde", + "serde_core", "serde_spanned", "toml_datetime", - "toml_write", + "toml_parser", + "toml_writer", "winnow", ] [[package]] -name = "toml_write" -version = "0.1.2" +name = "toml_datetime" +version = "1.0.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +checksum = "9b320e741db58cac564e26c607d3cc1fdc4a88fd36c879568c07856ed83ff3e9" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_parser" +version = "1.0.10+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7df25b4befd31c4816df190124375d5a20c6b6921e2cad937316de3fccd63420" +dependencies = [ + "winnow", +] + +[[package]] +name = "toml_writer" +version = "1.0.7+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17aaa1c6e3dc22b1da4b6bba97d066e354c7945cac2f7852d4e4e7ca7a6b56d" [[package]] name = "tower" @@ -3007,7 +3156,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3057,7 +3206,7 @@ checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3245,7 +3394,7 @@ dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.117", + "syn", "wasm-bindgen-shared", ] @@ -3313,12 +3462,12 @@ dependencies = [ ] [[package]] -name = "webpki-roots" -version = "0.26.11" +name = "webpki-root-certs" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" 
+checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" dependencies = [ - "webpki-roots 1.0.6", + "rustls-pki-types", ] [[package]] @@ -3336,6 +3485,22 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72069c3113ab32ab29e5584db3c6ec55d416895e60715417b5b883a357c3e471" +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.11" @@ -3345,6 +3510,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.62.2" @@ -3366,7 +3537,7 @@ checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3377,7 +3548,7 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3404,6 +3575,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -3440,6 +3620,21 @@ dependencies = [ 
"windows-link", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -3488,6 +3683,12 @@ dependencies = [ "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -3506,6 +3707,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -3524,6 +3731,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -3554,6 +3767,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" +[[package]] +name = 
"windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -3572,6 +3791,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -3590,6 +3815,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -3608,6 +3839,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -3628,12 +3865,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.15" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" -dependencies = [ - "memchr", -] +checksum = 
"a90e88e4667264a994d34e6d1ab2d26d398dcdca8b7f52bec8668957517fc7d8" [[package]] name = "winreg" @@ -3675,7 +3909,7 @@ dependencies = [ "heck", "indexmap", "prettyplease", - "syn 2.0.117", + "syn", "wasm-metadata", "wit-bindgen-core", "wit-component", @@ -3691,7 +3925,7 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.117", + "syn", "wit-bindgen-core", "wit-bindgen-rust", ] @@ -3740,10 +3974,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" [[package]] -name = "x509-parser" -version = "0.15.1" +name = "x25519-dalek" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7069fba5b66b9193bd2c5d3d4ff12b839118f6bcbef5328efafafb5395cf63da" +checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" +dependencies = [ + "curve25519-dalek", + "rand_core 0.6.4", + "serde", + "zeroize", +] + +[[package]] +name = "x509-parser" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202" dependencies = [ "asn1-rs", "data-encoding", @@ -3752,7 +3998,7 @@ dependencies = [ "nom", "oid-registry", "rusticata-macros", - "thiserror 1.0.69", + "thiserror 2.0.18", "time", ] @@ -3775,8 +4021,8 @@ checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", - "synstructure 0.13.2", + "syn", + "synstructure", ] [[package]] @@ -3796,7 +4042,7 @@ checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3816,8 +4062,8 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", - "synstructure 0.13.2", + "syn", + "synstructure", ] [[package]] @@ -3837,7 
+4083,7 @@ checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3870,7 +4116,7 @@ checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 97855f3..53082db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.3.28" +version = "3.3.29" edition = "2024" [dependencies] @@ -22,17 +22,19 @@ hmac = "0.12" crc32fast = "1.4" crc32c = "0.6" zeroize = { version = "1.8", features = ["derive"] } +subtle = "2.6" +static_assertions = "1.1" # Network -socket2 = { version = "0.5", features = ["all"] } -nix = { version = "0.28", default-features = false, features = ["net"] } +socket2 = { version = "0.6", features = ["all"] } +nix = { version = "0.31", default-features = false, features = ["net", "fs"] } shadowsocks = { version = "1.24", features = ["aead-cipher-2022"] } # Serialization serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -toml = "0.8" -x509-parser = "0.15" +toml = "1.0" +x509-parser = "0.18" # Utils bytes = "1.9" @@ -40,10 +42,10 @@ thiserror = "2.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } parking_lot = "0.12" -dashmap = "5.5" +dashmap = "6.1" arc-swap = "1.7" lru = "0.16" -rand = "0.9" +rand = "0.10" chrono = { version = "0.4", features = ["serde"] } hex = "0.4" base64 = "0.22" @@ -52,23 +54,24 @@ regex = "1.11" crossbeam-queue = "0.3" num-bigint = "0.4" num-traits = "0.2" +x25519-dalek = "2" anyhow = "1.0" # HTTP -reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false } -notify = { version = "6", features = ["macos_fsevent"] } -ipnetwork = "0.20" +reqwest = { version = "0.13", features = ["rustls"], default-features = false } +notify = "8.2" +ipnetwork = { version = "0.21", features 
= ["serde"] } hyper = { version = "1", features = ["server", "http1"] } hyper-util = { version = "0.1", features = ["tokio", "server-auto"] } http-body-util = "0.1" httpdate = "1.0" tokio-rustls = { version = "0.26", default-features = false, features = ["tls12"] } rustls = { version = "0.23", default-features = false, features = ["std", "tls12", "ring"] } -webpki-roots = "0.26" +webpki-roots = "1.0" [dev-dependencies] tokio-test = "0.4" -criterion = "0.5" +criterion = "0.8" proptest = "1.4" futures = "0.3" diff --git a/benches/crypto_bench.rs b/benches/crypto_bench.rs index 0089abe..940791c 100644 --- a/benches/crypto_bench.rs +++ b/benches/crypto_bench.rs @@ -1,5 +1,5 @@ // Cryptobench -use criterion::{black_box, criterion_group, Criterion}; +use criterion::{Criterion, black_box, criterion_group}; fn bench_aes_ctr(c: &mut Criterion) { c.bench_function("aes_ctr_encrypt_64kb", |b| { @@ -9,4 +9,4 @@ fn bench_aes_ctr(c: &mut Criterion) { black_box(enc.encrypt(&data)) }) }); -} \ No newline at end of file +} diff --git a/docs/CONFIG_PARAMS.en.md b/docs/CONFIG_PARAMS.en.md index 90da08a..3eee3a7 100644 --- a/docs/CONFIG_PARAMS.en.md +++ b/docs/CONFIG_PARAMS.en.md @@ -260,6 +260,129 @@ This document lists all configuration keys accepted by `config.toml`. | tls_full_cert_ttl_secs | `u64` | `90` | — | TTL for sending full cert payload per (domain, client IP) tuple. | | alpn_enforce | `bool` | `true` | — | Enforces ALPN echo behavior based on client preference. | | mask_proxy_protocol | `u8` | `0` | — | PROXY protocol mode for mask backend (`0` disabled, `1` v1, `2` v2). | +| mask_shape_hardening | `bool` | `true` | — | Enables client->mask shape-channel hardening by applying controlled tail padding to bucket boundaries on mask relay shutdown. | +| mask_shape_hardening_aggressive_mode | `bool` | `false` | Requires `mask_shape_hardening = true`. 
| Opt-in aggressive shaping profile: allows shaping on backend-silent non-EOF paths and switches above-cap blur to strictly positive random tail. | +| mask_shape_bucket_floor_bytes | `usize` | `512` | Must be `> 0`; should be `<= mask_shape_bucket_cap_bytes`. | Minimum bucket size used by shape-channel hardening. | +| mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. | +| mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. | +| mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. | +| mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. | +| mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. | +| mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. | + +### Shape-channel hardening notes (`[censorship]`) + +These parameters are designed to reduce one specific fingerprint source during masking: the exact number of bytes sent from proxy to `mask_host` for invalid or probing traffic. + +Without hardening, a censor can often correlate probe input length with backend-observed length very precisely (for example: `5 + body_sent` on early TLS reject paths). That creates a length-based classifier signal. 
+ +When `mask_shape_hardening = true`, Telemt pads the **client->mask** stream tail to a bucket boundary at relay shutdown: + +- Total bytes sent to mask are first measured. +- A bucket is selected using powers of two starting from `mask_shape_bucket_floor_bytes`. +- Padding is added only if total bytes are below `mask_shape_bucket_cap_bytes`. +- If bytes already exceed cap, no extra padding is added. + +This means multiple nearby probe sizes collapse into the same backend-observed size class, making active classification harder. + +What each parameter changes in practice: + +- `mask_shape_hardening` + Enables or disables this entire length-shaping stage on the fallback path. + When `false`, backend-observed length stays close to the real forwarded probe length. + When `true`, clean relay shutdown can append random padding bytes to move the total into a bucket. + +- `mask_shape_bucket_floor_bytes` + Sets the first bucket boundary used for small probes. + Example: with floor `512`, a malformed probe that would otherwise forward `37` bytes can be expanded to `512` bytes on clean EOF. + Larger floor values hide very small probes better, but increase egress cost. + +- `mask_shape_bucket_cap_bytes` + Sets the largest bucket Telemt will pad up to with bucket logic. + Example: with cap `4096`, a forwarded total of `1800` bytes may be padded to `2048` or `4096` depending on the bucket ladder, but a total already above `4096` will not be bucket-padded further. + Larger cap values increase the range over which size classes are collapsed, but also increase worst-case overhead. + +- Clean EOF matters in conservative mode + In the default profile, shape padding is intentionally conservative: it is applied on clean relay shutdown, not on every timeout/drip path. + This avoids introducing new timeout-tail artifacts that some backends or tests interpret as a separate fingerprint. + +Practical trade-offs: + +- Better anti-fingerprinting on size/shape channel. 
+- Slightly higher egress overhead for small probes due to padding. +- Behavior is intentionally conservative and enabled by default. + +Recommended starting profile: + +- `mask_shape_hardening = true` (default) +- `mask_shape_bucket_floor_bytes = 512` +- `mask_shape_bucket_cap_bytes = 4096` + +### Aggressive mode notes (`[censorship]`) + +`mask_shape_hardening_aggressive_mode` is an opt-in profile for higher anti-classifier pressure. + +- Default is `false` to preserve conservative timeout/no-tail behavior. +- Requires `mask_shape_hardening = true`. +- When enabled, backend-silent non-EOF masking paths may be shaped. +- When enabled together with above-cap blur, the random extra tail uses `[1, max]` instead of `[0, max]`. + +What changes when aggressive mode is enabled: + +- Backend-silent timeout paths can be shaped + In default mode, a client that keeps the socket half-open and times out will usually not receive shape padding on that path. + In aggressive mode, Telemt may still shape that backend-silent session if no backend bytes were returned. + This is specifically aimed at active probes that try to avoid EOF in order to preserve an exact backend-observed length. + +- Above-cap blur always adds at least one byte + In default mode, above-cap blur may choose `0`, so some oversized probes still land on their exact base forwarded length. + In aggressive mode, that exact-base sample is removed by construction. + +- Tradeoff + Aggressive mode improves resistance to active length classifiers, but it is more opinionated and less conservative. + If your deployment prioritizes strict compatibility with timeout/no-tail semantics, leave it disabled. + If your threat model includes repeated active probing by a censor, this mode is the stronger profile. + +Use this mode only when your threat model prioritizes classifier resistance over strict compatibility with conservative masking semantics. 
+ +### Above-cap blur notes (`[censorship]`) + +`mask_shape_above_cap_blur` adds a second-stage blur for very large probes that are already above `mask_shape_bucket_cap_bytes`. + +- A random tail in `[0, mask_shape_above_cap_blur_max_bytes]` is appended in default mode. +- In aggressive mode, the random tail becomes strictly positive: `[1, mask_shape_above_cap_blur_max_bytes]`. +- This reduces exact-size leakage above cap at bounded overhead. +- Keep `mask_shape_above_cap_blur_max_bytes` conservative to avoid unnecessary egress growth. + +Operational meaning: + +- Without above-cap blur + A probe that forwards `5005` bytes will still look like `5005` bytes to the backend if it is already above cap. + +- With above-cap blur enabled + That same probe may look like any value in a bounded window above its base length. + Example with `mask_shape_above_cap_blur_max_bytes = 64`: + backend-observed size becomes `5005..5069` in default mode, or `5006..5069` in aggressive mode. + +- Choosing `mask_shape_above_cap_blur_max_bytes` + Small values reduce cost but preserve more separability between far-apart oversized classes. + Larger values blur oversized classes more aggressively, but add more egress overhead and more output variance. + +### Timing normalization envelope notes (`[censorship]`) + +`mask_timing_normalization_enabled` smooths timing differences between masking outcomes by applying a target duration envelope. + +- A random target is selected in `[mask_timing_normalization_floor_ms, mask_timing_normalization_ceiling_ms]`. +- Fast paths are delayed up to the selected target. +- Slow paths are not forced to finish by the ceiling (the envelope is best-effort shaping, not truncation). + +Recommended starting profile for timing shaping: + +- `mask_timing_normalization_enabled = true` +- `mask_timing_normalization_floor_ms = 180` +- `mask_timing_normalization_ceiling_ms = 320` + +If your backend or network is very bandwidth-constrained, reduce cap first. 
If probes are still too distinguishable in your environment, increase floor gradually. ## [access] diff --git a/docs/model/FakeTLS.png b/docs/model/FakeTLS.png new file mode 100644 index 0000000..5f6782e Binary files /dev/null and b/docs/model/FakeTLS.png differ diff --git a/docs/model/architecture.png b/docs/model/architecture.png new file mode 100644 index 0000000..71d4a17 Binary files /dev/null and b/docs/model/architecture.png differ diff --git a/src/api/http_utils.rs b/src/api/http_utils.rs index e04bd04..9dfe526 100644 --- a/src/api/http_utils.rs +++ b/src/api/http_utils.rs @@ -24,10 +24,7 @@ pub(super) fn success_response( .unwrap() } -pub(super) fn error_response( - request_id: u64, - failure: ApiFailure, -) -> hyper::Response> { +pub(super) fn error_response(request_id: u64, failure: ApiFailure) -> hyper::Response> { let payload = ErrorResponse { ok: false, error: ErrorBody { diff --git a/src/api/mod.rs b/src/api/mod.rs index 0e2edd4..c1e3557 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; @@ -19,8 +21,8 @@ use crate::ip_tracker::UserIpTracker; use crate::proxy::route_mode::RouteRuntimeController; use crate::startup::StartupTracker; use crate::stats::Stats; -use crate::transport::middle_proxy::MePool; use crate::transport::UpstreamManager; +use crate::transport::middle_proxy::MePool; mod config_store; mod events; @@ -36,8 +38,8 @@ mod runtime_zero; mod users; use config_store::{current_revision, parse_if_match}; -use http_utils::{error_response, read_json, read_optional_json, success_response}; use events::ApiEventStore; +use http_utils::{error_response, read_json, read_optional_json, success_response}; use model::{ ApiFailure, CreateUserRequest, HealthData, PatchUserRequest, RotateSecretRequest, SummaryData, }; @@ -55,11 +57,11 @@ use runtime_stats::{ MinimalCacheEntry, build_dcs_data, build_me_writers_data, 
build_minimal_all_data, build_upstreams_data, build_zero_all_data, }; +use runtime_watch::spawn_runtime_watchers; use runtime_zero::{ build_limits_effective_data, build_runtime_gates_data, build_security_posture_data, build_system_info_data, }; -use runtime_watch::spawn_runtime_watchers; use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config}; pub(super) struct ApiRuntimeState { @@ -208,15 +210,15 @@ async fn handle( )); } - if !api_cfg.whitelist.is_empty() - && !api_cfg - .whitelist - .iter() - .any(|net| net.contains(peer.ip())) + if !api_cfg.whitelist.is_empty() && !api_cfg.whitelist.iter().any(|net| net.contains(peer.ip())) { return Ok(error_response( request_id, - ApiFailure::new(StatusCode::FORBIDDEN, "forbidden", "Source IP is not allowed"), + ApiFailure::new( + StatusCode::FORBIDDEN, + "forbidden", + "Source IP is not allowed", + ), )); } @@ -347,7 +349,8 @@ async fn handle( } ("GET", "/v1/runtime/connections/summary") => { let revision = current_revision(&shared.config_path).await?; - let data = build_runtime_connections_summary_data(shared.as_ref(), cfg.as_ref()).await; + let data = + build_runtime_connections_summary_data(shared.as_ref(), cfg.as_ref()).await; Ok(success_response(StatusCode::OK, data, revision)) } ("GET", "/v1/runtime/events/recent") => { @@ -389,13 +392,16 @@ async fn handle( let (data, revision) = match result { Ok(ok) => ok, Err(error) => { - shared.runtime_events.record("api.user.create.failed", error.code); + shared + .runtime_events + .record("api.user.create.failed", error.code); return Err(error); } }; - shared - .runtime_events - .record("api.user.create.ok", format!("username={}", data.user.username)); + shared.runtime_events.record( + "api.user.create.ok", + format!("username={}", data.user.username), + ); Ok(success_response(StatusCode::CREATED, data, revision)) } _ => { @@ -414,7 +420,8 @@ async fn handle( detected_ip_v6, ) .await; - if let Some(user_info) = users.into_iter().find(|entry| 
entry.username == user) + if let Some(user_info) = + users.into_iter().find(|entry| entry.username == user) { return Ok(success_response(StatusCode::OK, user_info, revision)); } @@ -435,7 +442,8 @@ async fn handle( )); } let expected_revision = parse_if_match(req.headers()); - let body = read_json::(req.into_body(), body_limit).await?; + let body = + read_json::(req.into_body(), body_limit).await?; let result = patch_user(user, body, expected_revision, &shared).await; let (data, revision) = match result { Ok(ok) => ok, @@ -475,10 +483,9 @@ async fn handle( return Err(error); } }; - shared.runtime_events.record( - "api.user.delete.ok", - format!("username={}", deleted_user), - ); + shared + .runtime_events + .record("api.user.delete.ok", format!("username={}", deleted_user)); return Ok(success_response(StatusCode::OK, deleted_user, revision)); } if method == Method::POST diff --git a/src/api/model.rs b/src/api/model.rs index 6578d35..8ae0c0b 100644 --- a/src/api/model.rs +++ b/src/api/model.rs @@ -1,10 +1,12 @@ use std::net::IpAddr; +use std::sync::OnceLock; use chrono::{DateTime, Utc}; use hyper::StatusCode; -use rand::Rng; use serde::{Deserialize, Serialize}; +use crate::crypto::SecureRandom; + const MAX_USERNAME_LEN: usize = 64; #[derive(Debug)] @@ -172,6 +174,24 @@ pub(super) struct ZeroMiddleProxyData { pub(super) route_drop_queue_full_total: u64, pub(super) route_drop_queue_full_base_total: u64, pub(super) route_drop_queue_full_high_total: u64, + pub(super) d2c_batches_total: u64, + pub(super) d2c_batch_frames_total: u64, + pub(super) d2c_batch_bytes_total: u64, + pub(super) d2c_flush_reason_queue_drain_total: u64, + pub(super) d2c_flush_reason_batch_frames_total: u64, + pub(super) d2c_flush_reason_batch_bytes_total: u64, + pub(super) d2c_flush_reason_max_delay_total: u64, + pub(super) d2c_flush_reason_ack_immediate_total: u64, + pub(super) d2c_flush_reason_close_total: u64, + pub(super) d2c_data_frames_total: u64, + pub(super) d2c_ack_frames_total: u64, + 
pub(super) d2c_payload_bytes_total: u64, + pub(super) d2c_write_mode_coalesced_total: u64, + pub(super) d2c_write_mode_split_total: u64, + pub(super) d2c_quota_reject_pre_write_total: u64, + pub(super) d2c_quota_reject_post_write_total: u64, + pub(super) d2c_frame_buf_shrink_total: u64, + pub(super) d2c_frame_buf_shrink_bytes_total: u64, pub(super) socks_kdf_strict_reject_total: u64, pub(super) socks_kdf_compat_fallback_total: u64, pub(super) endpoint_quarantine_total: u64, @@ -196,8 +216,6 @@ pub(super) struct ZeroPoolData { pub(super) pool_swap_total: u64, pub(super) pool_drain_active: u64, pub(super) pool_force_close_total: u64, - pub(super) pool_drain_soft_evict_total: u64, - pub(super) pool_drain_soft_evict_writer_total: u64, pub(super) pool_stale_pick_total: u64, pub(super) writer_removed_total: u64, pub(super) writer_removed_unexpected_total: u64, @@ -206,16 +224,6 @@ pub(super) struct ZeroPoolData { pub(super) refill_failed_total: u64, pub(super) writer_restored_same_endpoint_total: u64, pub(super) writer_restored_fallback_total: u64, - pub(super) teardown_attempt_total_normal: u64, - pub(super) teardown_attempt_total_hard_detach: u64, - pub(super) teardown_success_total_normal: u64, - pub(super) teardown_success_total_hard_detach: u64, - pub(super) teardown_timeout_total: u64, - pub(super) teardown_escalation_total: u64, - pub(super) teardown_noop_total: u64, - pub(super) teardown_cleanup_side_effect_failures_total: u64, - pub(super) teardown_duration_count_total: u64, - pub(super) teardown_duration_sum_seconds_total: f64, } #[derive(Serialize, Clone)] @@ -248,7 +256,6 @@ pub(super) struct MeWritersSummary { pub(super) available_pct: f64, pub(super) required_writers: usize, pub(super) alive_writers: usize, - pub(super) coverage_ratio: f64, pub(super) coverage_pct: f64, pub(super) fresh_alive_writers: usize, pub(super) fresh_coverage_pct: f64, @@ -297,7 +304,6 @@ pub(super) struct DcStatus { pub(super) floor_max: usize, pub(super) floor_capped: bool, 
pub(super) alive_writers: usize, - pub(super) coverage_ratio: f64, pub(super) coverage_pct: f64, pub(super) fresh_alive_writers: usize, pub(super) fresh_coverage_pct: f64, @@ -375,12 +381,6 @@ pub(super) struct MinimalMeRuntimeData { pub(super) me_reconnect_backoff_cap_ms: u64, pub(super) me_reconnect_fast_retry_count: u32, pub(super) me_pool_drain_ttl_secs: u64, - pub(super) me_instadrain: bool, - pub(super) me_pool_drain_soft_evict_enabled: bool, - pub(super) me_pool_drain_soft_evict_grace_secs: u64, - pub(super) me_pool_drain_soft_evict_per_writer: u8, - pub(super) me_pool_drain_soft_evict_budget_per_core: u16, - pub(super) me_pool_drain_soft_evict_cooldown_ms: u64, pub(super) me_pool_force_close_secs: u64, pub(super) me_pool_min_fresh_ratio: f32, pub(super) me_bind_stale_mode: &'static str, @@ -502,7 +502,9 @@ pub(super) fn is_valid_username(user: &str) -> bool { } pub(super) fn random_user_secret() -> String { + static API_SECRET_RNG: OnceLock = OnceLock::new(); + let rng = API_SECRET_RNG.get_or_init(SecureRandom::new); let mut bytes = [0u8; 16]; - rand::rng().fill(&mut bytes); + rng.fill(&mut bytes); hex::encode(bytes) } diff --git a/src/api/runtime_init.rs b/src/api/runtime_init.rs index 4bd8943..b7601f5 100644 --- a/src/api/runtime_init.rs +++ b/src/api/runtime_init.rs @@ -167,11 +167,7 @@ async fn current_me_pool_stage_progress(shared: &ApiShared) -> Option { let pool = shared.me_pool.read().await.clone()?; let status = pool.api_status_snapshot().await; let configured_dc_groups = status.configured_dc_groups; - let covered_dc_groups = status - .dcs - .iter() - .filter(|dc| dc.alive_writers > 0) - .count(); + let covered_dc_groups = status.dcs.iter().filter(|dc| dc.alive_writers > 0).count(); let dc_coverage = ratio_01(covered_dc_groups, configured_dc_groups); let writer_coverage = ratio_01(status.alive_writers, status.required_writers); diff --git a/src/api/runtime_min.rs b/src/api/runtime_min.rs index 047fd9c..986f138 100644 --- a/src/api/runtime_min.rs 
+++ b/src/api/runtime_min.rs @@ -4,9 +4,6 @@ use std::time::{SystemTime, UNIX_EPOCH}; use serde::Serialize; use crate::config::ProxyConfig; -use crate::stats::{ - MeWriterCleanupSideEffectStep, MeWriterTeardownMode, MeWriterTeardownReason, Stats, -}; use super::ApiShared; @@ -101,50 +98,6 @@ pub(super) struct RuntimeMeQualityCountersData { pub(super) reconnect_success_total: u64, } -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownAttemptData { - pub(super) reason: &'static str, - pub(super) mode: &'static str, - pub(super) total: u64, -} - -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownSuccessData { - pub(super) mode: &'static str, - pub(super) total: u64, -} - -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownSideEffectData { - pub(super) step: &'static str, - pub(super) total: u64, -} - -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownDurationBucketData { - pub(super) le_seconds: &'static str, - pub(super) total: u64, -} - -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownDurationData { - pub(super) mode: &'static str, - pub(super) count: u64, - pub(super) sum_seconds: f64, - pub(super) buckets: Vec, -} - -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownData { - pub(super) attempts: Vec, - pub(super) success: Vec, - pub(super) timeout_total: u64, - pub(super) escalation_total: u64, - pub(super) noop_total: u64, - pub(super) cleanup_side_effect_failures: Vec, - pub(super) duration: Vec, -} - #[derive(Serialize)] pub(super) struct RuntimeMeQualityRouteDropData { pub(super) no_conn_total: u64, @@ -179,14 +132,12 @@ pub(super) struct RuntimeMeQualityDcRttData { pub(super) rtt_ema_ms: Option, pub(super) alive_writers: usize, pub(super) required_writers: usize, - pub(super) coverage_ratio: f64, pub(super) coverage_pct: f64, } #[derive(Serialize)] pub(super) struct RuntimeMeQualityPayload { pub(super) counters: RuntimeMeQualityCountersData, - pub(super) teardown: 
RuntimeMeQualityTeardownData, pub(super) route_drops: RuntimeMeQualityRouteDropData, pub(super) family_states: Vec, pub(super) drain_gate: RuntimeMeQualityDrainGateData, @@ -457,7 +408,6 @@ pub(super) async fn build_runtime_me_quality_data(shared: &ApiShared) -> Runtime reconnect_attempt_total: shared.stats.get_me_reconnect_attempts(), reconnect_success_total: shared.stats.get_me_reconnect_success(), }, - teardown: build_runtime_me_teardown_data(shared), route_drops: RuntimeMeQualityRouteDropData { no_conn_total: shared.stats.get_me_route_drop_no_conn(), channel_closed_total: shared.stats.get_me_route_drop_channel_closed(), @@ -480,7 +430,6 @@ pub(super) async fn build_runtime_me_quality_data(shared: &ApiShared) -> Runtime rtt_ema_ms: dc.rtt_ms, alive_writers: dc.alive_writers, required_writers: dc.required_writers, - coverage_ratio: dc.coverage_ratio, coverage_pct: dc.coverage_pct, }) .collect(), @@ -488,81 +437,6 @@ pub(super) async fn build_runtime_me_quality_data(shared: &ApiShared) -> Runtime } } -fn build_runtime_me_teardown_data(shared: &ApiShared) -> RuntimeMeQualityTeardownData { - let attempts = MeWriterTeardownReason::ALL - .iter() - .copied() - .flat_map(|reason| { - MeWriterTeardownMode::ALL - .iter() - .copied() - .map(move |mode| RuntimeMeQualityTeardownAttemptData { - reason: reason.as_str(), - mode: mode.as_str(), - total: shared.stats.get_me_writer_teardown_attempt_total(reason, mode), - }) - }) - .collect(); - - let success = MeWriterTeardownMode::ALL - .iter() - .copied() - .map(|mode| RuntimeMeQualityTeardownSuccessData { - mode: mode.as_str(), - total: shared.stats.get_me_writer_teardown_success_total(mode), - }) - .collect(); - - let cleanup_side_effect_failures = MeWriterCleanupSideEffectStep::ALL - .iter() - .copied() - .map(|step| RuntimeMeQualityTeardownSideEffectData { - step: step.as_str(), - total: shared - .stats - .get_me_writer_cleanup_side_effect_failures_total(step), - }) - .collect(); - - let duration = MeWriterTeardownMode::ALL 
- .iter() - .copied() - .map(|mode| { - let count = shared.stats.get_me_writer_teardown_duration_count(mode); - let mut buckets: Vec = Stats::me_writer_teardown_duration_bucket_labels() - .iter() - .enumerate() - .map(|(bucket_idx, label)| RuntimeMeQualityTeardownDurationBucketData { - le_seconds: label, - total: shared - .stats - .get_me_writer_teardown_duration_bucket_total(mode, bucket_idx), - }) - .collect(); - buckets.push(RuntimeMeQualityTeardownDurationBucketData { - le_seconds: "+Inf", - total: count, - }); - RuntimeMeQualityTeardownDurationData { - mode: mode.as_str(), - count, - sum_seconds: shared.stats.get_me_writer_teardown_duration_sum_seconds(mode), - buckets, - } - }) - .collect(); - - RuntimeMeQualityTeardownData { - attempts, - success, - timeout_total: shared.stats.get_me_writer_teardown_timeout_total(), - escalation_total: shared.stats.get_me_writer_teardown_escalation_total(), - noop_total: shared.stats.get_me_writer_teardown_noop_total(), - cleanup_side_effect_failures, - duration, - } -} - pub(super) async fn build_runtime_upstream_quality_data( shared: &ApiShared, ) -> RuntimeUpstreamQualityData { diff --git a/src/api/runtime_stats.rs b/src/api/runtime_stats.rs index 999e2cf..b66d1a5 100644 --- a/src/api/runtime_stats.rs +++ b/src/api/runtime_stats.rs @@ -1,9 +1,9 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use crate::config::ApiConfig; -use crate::stats::{MeWriterTeardownMode, Stats}; -use crate::transport::upstream::IpPreference; +use crate::stats::Stats; use crate::transport::UpstreamRouteKind; +use crate::transport::upstream::IpPreference; use super::ApiShared; use super::model::{ @@ -68,6 +68,25 @@ pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer route_drop_queue_full_total: stats.get_me_route_drop_queue_full(), route_drop_queue_full_base_total: stats.get_me_route_drop_queue_full_base(), route_drop_queue_full_high_total: stats.get_me_route_drop_queue_full_high(), + d2c_batches_total: 
stats.get_me_d2c_batches_total(), + d2c_batch_frames_total: stats.get_me_d2c_batch_frames_total(), + d2c_batch_bytes_total: stats.get_me_d2c_batch_bytes_total(), + d2c_flush_reason_queue_drain_total: stats.get_me_d2c_flush_reason_queue_drain_total(), + d2c_flush_reason_batch_frames_total: stats.get_me_d2c_flush_reason_batch_frames_total(), + d2c_flush_reason_batch_bytes_total: stats.get_me_d2c_flush_reason_batch_bytes_total(), + d2c_flush_reason_max_delay_total: stats.get_me_d2c_flush_reason_max_delay_total(), + d2c_flush_reason_ack_immediate_total: stats + .get_me_d2c_flush_reason_ack_immediate_total(), + d2c_flush_reason_close_total: stats.get_me_d2c_flush_reason_close_total(), + d2c_data_frames_total: stats.get_me_d2c_data_frames_total(), + d2c_ack_frames_total: stats.get_me_d2c_ack_frames_total(), + d2c_payload_bytes_total: stats.get_me_d2c_payload_bytes_total(), + d2c_write_mode_coalesced_total: stats.get_me_d2c_write_mode_coalesced_total(), + d2c_write_mode_split_total: stats.get_me_d2c_write_mode_split_total(), + d2c_quota_reject_pre_write_total: stats.get_me_d2c_quota_reject_pre_write_total(), + d2c_quota_reject_post_write_total: stats.get_me_d2c_quota_reject_post_write_total(), + d2c_frame_buf_shrink_total: stats.get_me_d2c_frame_buf_shrink_total(), + d2c_frame_buf_shrink_bytes_total: stats.get_me_d2c_frame_buf_shrink_bytes_total(), socks_kdf_strict_reject_total: stats.get_me_socks_kdf_strict_reject(), socks_kdf_compat_fallback_total: stats.get_me_socks_kdf_compat_fallback(), endpoint_quarantine_total: stats.get_me_endpoint_quarantine_total(), @@ -96,8 +115,6 @@ pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer pool_swap_total: stats.get_pool_swap_total(), pool_drain_active: stats.get_pool_drain_active(), pool_force_close_total: stats.get_pool_force_close_total(), - pool_drain_soft_evict_total: stats.get_pool_drain_soft_evict_total(), - pool_drain_soft_evict_writer_total: stats.get_pool_drain_soft_evict_writer_total(), 
pool_stale_pick_total: stats.get_pool_stale_pick_total(), writer_removed_total: stats.get_me_writer_removed_total(), writer_removed_unexpected_total: stats.get_me_writer_removed_unexpected_total(), @@ -106,29 +123,6 @@ pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer refill_failed_total: stats.get_me_refill_failed_total(), writer_restored_same_endpoint_total: stats.get_me_writer_restored_same_endpoint_total(), writer_restored_fallback_total: stats.get_me_writer_restored_fallback_total(), - teardown_attempt_total_normal: stats - .get_me_writer_teardown_attempt_total_by_mode(MeWriterTeardownMode::Normal), - teardown_attempt_total_hard_detach: stats - .get_me_writer_teardown_attempt_total_by_mode(MeWriterTeardownMode::HardDetach), - teardown_success_total_normal: stats - .get_me_writer_teardown_success_total(MeWriterTeardownMode::Normal), - teardown_success_total_hard_detach: stats - .get_me_writer_teardown_success_total(MeWriterTeardownMode::HardDetach), - teardown_timeout_total: stats.get_me_writer_teardown_timeout_total(), - teardown_escalation_total: stats.get_me_writer_teardown_escalation_total(), - teardown_noop_total: stats.get_me_writer_teardown_noop_total(), - teardown_cleanup_side_effect_failures_total: stats - .get_me_writer_cleanup_side_effect_failures_total_all(), - teardown_duration_count_total: stats - .get_me_writer_teardown_duration_count(MeWriterTeardownMode::Normal) - .saturating_add( - stats.get_me_writer_teardown_duration_count(MeWriterTeardownMode::HardDetach), - ), - teardown_duration_sum_seconds_total: stats - .get_me_writer_teardown_duration_sum_seconds(MeWriterTeardownMode::Normal) - + stats.get_me_writer_teardown_duration_sum_seconds( - MeWriterTeardownMode::HardDetach, - ), }, desync: ZeroDesyncData { secure_padding_invalid_total: stats.get_secure_padding_invalid(), @@ -340,7 +334,6 @@ async fn get_minimal_payload_cached( available_pct: status.available_pct, required_writers: status.required_writers, 
alive_writers: status.alive_writers, - coverage_ratio: status.coverage_ratio, coverage_pct: status.coverage_pct, fresh_alive_writers: status.fresh_alive_writers, fresh_coverage_pct: status.fresh_coverage_pct, @@ -398,7 +391,6 @@ async fn get_minimal_payload_cached( floor_max: entry.floor_max, floor_capped: entry.floor_capped, alive_writers: entry.alive_writers, - coverage_ratio: entry.coverage_ratio, coverage_pct: entry.coverage_pct, fresh_alive_writers: entry.fresh_alive_writers, fresh_coverage_pct: entry.fresh_coverage_pct, @@ -452,12 +444,6 @@ async fn get_minimal_payload_cached( me_reconnect_backoff_cap_ms: runtime.me_reconnect_backoff_cap_ms, me_reconnect_fast_retry_count: runtime.me_reconnect_fast_retry_count, me_pool_drain_ttl_secs: runtime.me_pool_drain_ttl_secs, - me_instadrain: runtime.me_instadrain, - me_pool_drain_soft_evict_enabled: runtime.me_pool_drain_soft_evict_enabled, - me_pool_drain_soft_evict_grace_secs: runtime.me_pool_drain_soft_evict_grace_secs, - me_pool_drain_soft_evict_per_writer: runtime.me_pool_drain_soft_evict_per_writer, - me_pool_drain_soft_evict_budget_per_core: runtime.me_pool_drain_soft_evict_budget_per_core, - me_pool_drain_soft_evict_cooldown_ms: runtime.me_pool_drain_soft_evict_cooldown_ms, me_pool_force_close_secs: runtime.me_pool_force_close_secs, me_pool_min_fresh_ratio: runtime.me_pool_min_fresh_ratio, me_bind_stale_mode: runtime.me_bind_stale_mode, @@ -526,7 +512,6 @@ fn disabled_me_writers(now_epoch_secs: u64, reason: &'static str) -> MeWritersDa available_pct: 0.0, required_writers: 0, alive_writers: 0, - coverage_ratio: 0.0, coverage_pct: 0.0, fresh_alive_writers: 0, fresh_coverage_pct: 0.0, diff --git a/src/api/runtime_zero.rs b/src/api/runtime_zero.rs index ba89302..a6eb163 100644 --- a/src/api/runtime_zero.rs +++ b/src/api/runtime_zero.rs @@ -128,7 +128,8 @@ pub(super) fn build_system_info_data( .runtime_state .last_config_reload_epoch_secs .load(Ordering::Relaxed); - let last_config_reload_epoch_secs = 
(last_reload_epoch_secs > 0).then_some(last_reload_epoch_secs); + let last_config_reload_epoch_secs = + (last_reload_epoch_secs > 0).then_some(last_reload_epoch_secs); let git_commit = option_env!("TELEMT_GIT_COMMIT") .or(option_env!("VERGEN_GIT_SHA")) @@ -153,7 +154,10 @@ pub(super) fn build_system_info_data( uptime_seconds: shared.stats.uptime_secs(), config_path: shared.config_path.display().to_string(), config_hash: revision.to_string(), - config_reload_count: shared.runtime_state.config_reload_count.load(Ordering::Relaxed), + config_reload_count: shared + .runtime_state + .config_reload_count + .load(Ordering::Relaxed), last_config_reload_epoch_secs, } } @@ -233,9 +237,7 @@ pub(super) fn build_limits_effective_data(cfg: &ProxyConfig) -> EffectiveLimitsD adaptive_floor_writers_per_core_total: cfg .general .me_adaptive_floor_writers_per_core_total, - adaptive_floor_cpu_cores_override: cfg - .general - .me_adaptive_floor_cpu_cores_override, + adaptive_floor_cpu_cores_override: cfg.general.me_adaptive_floor_cpu_cores_override, adaptive_floor_max_extra_writers_single_per_core: cfg .general .me_adaptive_floor_max_extra_writers_single_per_core, diff --git a/src/api/users.rs b/src/api/users.rs index f339806..2ee8b98 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -46,7 +46,9 @@ pub(super) async fn create_user( None => random_user_secret(), }; - if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) { + if let Some(ad_tag) = body.user_ad_tag.as_ref() + && !is_valid_ad_tag(ad_tag) + { return Err(ApiFailure::bad_request( "user_ad_tag must be exactly 32 hex characters", )); @@ -65,12 +67,18 @@ pub(super) async fn create_user( )); } - cfg.access.users.insert(body.username.clone(), secret.clone()); + cfg.access + .users + .insert(body.username.clone(), secret.clone()); if let Some(ad_tag) = body.user_ad_tag { - cfg.access.user_ad_tags.insert(body.username.clone(), ad_tag); + cfg.access + .user_ad_tags + .insert(body.username.clone(), ad_tag); } 
if let Some(limit) = body.max_tcp_conns { - cfg.access.user_max_tcp_conns.insert(body.username.clone(), limit); + cfg.access + .user_max_tcp_conns + .insert(body.username.clone(), limit); } if let Some(expiration) = expiration { cfg.access @@ -78,7 +86,9 @@ pub(super) async fn create_user( .insert(body.username.clone(), expiration); } if let Some(quota) = body.data_quota_bytes { - cfg.access.user_data_quota.insert(body.username.clone(), quota); + cfg.access + .user_data_quota + .insert(body.username.clone(), quota); } let updated_limit = body.max_unique_ips; @@ -108,11 +118,15 @@ pub(super) async fn create_user( touched_sections.push(AccessSection::UserMaxUniqueIps); } - let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; + let revision = + save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; drop(_guard); if let Some(limit) = updated_limit { - shared.ip_tracker.set_user_limit(&body.username, limit).await; + shared + .ip_tracker + .set_user_limit(&body.username, limit) + .await; } let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); @@ -140,12 +154,7 @@ pub(super) async fn create_user( recent_unique_ips: 0, recent_unique_ips_list: Vec::new(), total_octets: 0, - links: build_user_links( - &cfg, - &secret, - detected_ip_v4, - detected_ip_v6, - ), + links: build_user_links(&cfg, &secret, detected_ip_v4, detected_ip_v6), }); Ok((CreateUserResponse { user, secret }, revision)) @@ -157,12 +166,16 @@ pub(super) async fn patch_user( expected_revision: Option, shared: &ApiShared, ) -> Result<(UserInfo, String), ApiFailure> { - if let Some(secret) = body.secret.as_ref() && !is_valid_user_secret(secret) { + if let Some(secret) = body.secret.as_ref() + && !is_valid_user_secret(secret) + { return Err(ApiFailure::bad_request( "secret must be exactly 32 hex characters", )); } - if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) { + if let Some(ad_tag) = 
body.user_ad_tag.as_ref() + && !is_valid_ad_tag(ad_tag) + { return Err(ApiFailure::bad_request( "user_ad_tag must be exactly 32 hex characters", )); @@ -187,10 +200,14 @@ pub(super) async fn patch_user( cfg.access.user_ad_tags.insert(user.to_string(), ad_tag); } if let Some(limit) = body.max_tcp_conns { - cfg.access.user_max_tcp_conns.insert(user.to_string(), limit); + cfg.access + .user_max_tcp_conns + .insert(user.to_string(), limit); } if let Some(expiration) = expiration { - cfg.access.user_expirations.insert(user.to_string(), expiration); + cfg.access + .user_expirations + .insert(user.to_string(), expiration); } if let Some(quota) = body.data_quota_bytes { cfg.access.user_data_quota.insert(user.to_string(), quota); @@ -198,7 +215,9 @@ pub(super) async fn patch_user( let mut updated_limit = None; if let Some(limit) = body.max_unique_ips { - cfg.access.user_max_unique_ips.insert(user.to_string(), limit); + cfg.access + .user_max_unique_ips + .insert(user.to_string(), limit); updated_limit = Some(limit); } @@ -263,7 +282,8 @@ pub(super) async fn rotate_secret( AccessSection::UserDataQuota, AccessSection::UserMaxUniqueIps, ]; - let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; + let revision = + save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; drop(_guard); let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); @@ -330,7 +350,8 @@ pub(super) async fn delete_user( AccessSection::UserDataQuota, AccessSection::UserMaxUniqueIps, ]; - let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; + let revision = + save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; drop(_guard); shared.ip_tracker.remove_user_limit(user).await; shared.ip_tracker.clear_user_ips(user).await; @@ -365,12 +386,7 @@ pub(super) async fn users_from_config( .users .get(&username) .map(|secret| { - build_user_links( - cfg, - secret, - 
startup_detected_ip_v4, - startup_detected_ip_v6, - ) + build_user_links(cfg, secret, startup_detected_ip_v4, startup_detected_ip_v6) }) .unwrap_or(UserLinks { classic: Vec::new(), @@ -392,10 +408,8 @@ pub(super) async fn users_from_config( .get(&username) .copied() .filter(|limit| *limit > 0) - .or( - (cfg.access.user_max_unique_ips_global_each > 0) - .then_some(cfg.access.user_max_unique_ips_global_each), - ), + .or((cfg.access.user_max_unique_ips_global_each > 0) + .then_some(cfg.access.user_max_unique_ips_global_each)), current_connections: stats.get_user_curr_connects(&username), active_unique_ips: active_ip_list.len(), active_unique_ips_list: active_ip_list, @@ -481,11 +495,11 @@ fn resolve_link_hosts( push_unique_host(&mut hosts, host); continue; } - if let Some(ip) = listener.announce_ip { - if !ip.is_unspecified() { - push_unique_host(&mut hosts, &ip.to_string()); - continue; - } + if let Some(ip) = listener.announce_ip + && !ip.is_unspecified() + { + push_unique_host(&mut hosts, &ip.to_string()); + continue; } if listener.ip.is_unspecified() { let detected_ip = if listener.ip.is_ipv4() { diff --git a/src/cli.rs b/src/cli.rs index 87dcfb5..6dc0e2a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,9 +1,9 @@ //! CLI commands: --init (fire-and-forget setup) +use rand::RngExt; use std::fs; use std::path::{Path, PathBuf}; use std::process::Command; -use rand::Rng; /// Options for the init command pub struct InitOptions { @@ -35,10 +35,10 @@ pub fn parse_init_args(args: &[String]) -> Option { if !args.iter().any(|a| a == "--init") { return None; } - + let mut opts = InitOptions::default(); let mut i = 0; - + while i < args.len() { match args[i].as_str() { "--port" => { @@ -78,7 +78,7 @@ pub fn parse_init_args(args: &[String]) -> Option { } i += 1; } - + Some(opts) } @@ -86,7 +86,7 @@ pub fn parse_init_args(args: &[String]) -> Option { pub fn run_init(opts: InitOptions) -> Result<(), Box> { eprintln!("[telemt] Fire-and-forget setup"); eprintln!(); - + // 1. 
Generate or validate secret let secret = match opts.secret { Some(s) => { @@ -98,28 +98,28 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box> { } None => generate_secret(), }; - + eprintln!("[+] Secret: {}", secret); eprintln!("[+] User: {}", opts.username); eprintln!("[+] Port: {}", opts.port); eprintln!("[+] Domain: {}", opts.domain); - + // 2. Create config directory fs::create_dir_all(&opts.config_dir)?; let config_path = opts.config_dir.join("config.toml"); - + // 3. Write config let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain); fs::write(&config_path, &config_content)?; eprintln!("[+] Config written to {}", config_path.display()); - + // 4. Write systemd unit - let exe_path = std::env::current_exe() - .unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt")); - + let exe_path = + std::env::current_exe().unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt")); + let unit_path = Path::new("/etc/systemd/system/telemt.service"); let unit_content = generate_systemd_unit(&exe_path, &config_path); - + match fs::write(unit_path, &unit_content) { Ok(()) => { eprintln!("[+] Systemd unit written to {}", unit_path.display()); @@ -128,31 +128,31 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box> { eprintln!("[!] Cannot write systemd unit (run as root?): {}", e); eprintln!("[!] Manual unit file content:"); eprintln!("{}", unit_content); - + // Still print links and config print_links(&opts.username, &secret, opts.port, &opts.domain); return Ok(()); } } - + // 5. Reload systemd run_cmd("systemctl", &["daemon-reload"]); - + // 6. Enable service run_cmd("systemctl", &["enable", "telemt.service"]); eprintln!("[+] Service enabled"); - + // 7. 
Start service (unless --no-start) if !opts.no_start { run_cmd("systemctl", &["start", "telemt.service"]); eprintln!("[+] Service started"); - + // Brief delay then check status std::thread::sleep(std::time::Duration::from_secs(1)); let status = Command::new("systemctl") .args(["is-active", "telemt.service"]) .output(); - + match status { Ok(out) if out.status.success() => { eprintln!("[+] Service is running"); @@ -166,12 +166,12 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box> { eprintln!("[+] Service not started (--no-start)"); eprintln!("[+] Start manually: systemctl start telemt.service"); } - + eprintln!(); - + // 8. Print links print_links(&opts.username, &secret, opts.port, &opts.domain); - + Ok(()) } @@ -183,7 +183,7 @@ fn generate_secret() -> String { fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String { format!( -r#"# Telemt MTProxy — auto-generated config + r#"# Telemt MTProxy — auto-generated config # Re-run `telemt --init` to regenerate show_link = ["{username}"] @@ -246,7 +246,7 @@ tls_full_cert_ttl_secs = 90 [access] replay_check_len = 65536 -replay_window_secs = 1800 +replay_window_secs = 120 ignore_time_skew = false [access.users] @@ -266,7 +266,7 @@ weight = 10 fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String { format!( -r#"[Unit] + r#"[Unit] Description=Telemt MTProxy Documentation=https://github.com/telemt/telemt After=network-online.target @@ -309,11 +309,13 @@ fn run_cmd(cmd: &str, args: &[&str]) { fn print_links(username: &str, secret: &str, port: u16, domain: &str) { let domain_hex = hex::encode(domain); - + println!("=== Proxy Links ==="); println!("[{}]", username); - println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}", - port, secret, domain_hex); + println!( + " EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}", + port, secret, domain_hex + ); println!(); println!("Replace YOUR_SERVER_IP with your server's public IP."); println!("The proxy 
will auto-detect and display the correct link on startup."); diff --git a/src/config/defaults.rs b/src/config/defaults.rs index be540b0..66ffeda 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; use ipnetwork::IpNetwork; use serde::Deserialize; +use std::collections::HashMap; // Helper defaults kept private to the config module. const DEFAULT_NETWORK_IPV6: Option = Some(false); @@ -29,6 +29,8 @@ const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_FRAMES: usize = 32; const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_BYTES: usize = 128 * 1024; const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_DELAY_US: u64 = 500; const DEFAULT_ME_D2C_ACK_FLUSH_IMMEDIATE: bool = true; +const DEFAULT_ME_QUOTA_SOFT_OVERSHOOT_BYTES: u64 = 64 * 1024; +const DEFAULT_ME_D2C_FRAME_BUF_SHRINK_THRESHOLD_BYTES: usize = 256 * 1024; const DEFAULT_DIRECT_RELAY_COPY_BUF_C2S_BYTES: usize = 64 * 1024; const DEFAULT_DIRECT_RELAY_COPY_BUF_S2C_BYTES: usize = 256 * 1024; const DEFAULT_ME_WRITER_PICK_SAMPLE_SIZE: u8 = 3; @@ -86,13 +88,31 @@ pub(crate) fn default_replay_check_len() -> usize { } pub(crate) fn default_replay_window_secs() -> u64 { - 1800 + // Keep replay cache TTL tight by default to reduce replay surface. + // Deployments with higher RTT or longer reconnect jitter can override this in config. 
+ 120 } pub(crate) fn default_handshake_timeout() -> u64 { 30 } +pub(crate) fn default_relay_idle_policy_v2_enabled() -> bool { + true +} + +pub(crate) fn default_relay_client_idle_soft_secs() -> u64 { + 120 +} + +pub(crate) fn default_relay_client_idle_hard_secs() -> u64 { + 360 +} + +pub(crate) fn default_relay_idle_grace_after_downstream_activity_secs() -> u64 { + 30 +} + pub(crate) fn default_connect_timeout() -> u64 { 10 } @@ -125,10 +145,7 @@ pub(crate) fn default_weight() -> u16 { } pub(crate) fn default_metrics_whitelist() -> Vec { - vec![ - "127.0.0.1/32".parse().unwrap(), - "::1/128".parse().unwrap(), - ] + vec!["127.0.0.1/32".parse().unwrap(), "::1/128".parse().unwrap()] } pub(crate) fn default_api_listen() -> String { @@ -151,10 +168,18 @@ pub(crate) fn default_api_minimal_runtime_cache_ttl_ms() -> u64 { 1000 } -pub(crate) fn default_api_runtime_edge_enabled() -> bool { false } -pub(crate) fn default_api_runtime_edge_cache_ttl_ms() -> u64 { 1000 } -pub(crate) fn default_api_runtime_edge_top_n() -> usize { 10 } -pub(crate) fn default_api_runtime_edge_events_capacity() -> usize { 256 } +pub(crate) fn default_api_runtime_edge_enabled() -> bool { + false +} +pub(crate) fn default_api_runtime_edge_cache_ttl_ms() -> u64 { + 1000 +} +pub(crate) fn default_api_runtime_edge_top_n() -> usize { + 10 +} +pub(crate) fn default_api_runtime_edge_events_capacity() -> usize { + 256 +} pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 { 500 @@ -364,6 +389,14 @@ pub(crate) fn default_me_d2c_ack_flush_immediate() -> bool { DEFAULT_ME_D2C_ACK_FLUSH_IMMEDIATE } +pub(crate) fn default_me_quota_soft_overshoot_bytes() -> u64 { + DEFAULT_ME_QUOTA_SOFT_OVERSHOOT_BYTES +} + +pub(crate) fn default_me_d2c_frame_buf_shrink_threshold_bytes() -> usize { + DEFAULT_ME_D2C_FRAME_BUF_SHRINK_THRESHOLD_BYTES +} + pub(crate) fn default_direct_relay_copy_buf_c2s_bytes() -> usize { DEFAULT_DIRECT_RELAY_COPY_BUF_C2S_BYTES } @@ -485,17 +518,53 @@ pub(crate) fn 
default_tls_full_cert_ttl_secs() -> u64 { } pub(crate) fn default_server_hello_delay_min_ms() -> u64 { - 0 + 8 } pub(crate) fn default_server_hello_delay_max_ms() -> u64 { - 0 + 24 } pub(crate) fn default_alpn_enforce() -> bool { true } +pub(crate) fn default_mask_shape_hardening() -> bool { + true +} + +pub(crate) fn default_mask_shape_hardening_aggressive_mode() -> bool { + false +} + +pub(crate) fn default_mask_shape_bucket_floor_bytes() -> usize { + 512 +} + +pub(crate) fn default_mask_shape_bucket_cap_bytes() -> usize { + 4096 +} + +pub(crate) fn default_mask_shape_above_cap_blur() -> bool { + false +} + +pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize { + 512 +} + +pub(crate) fn default_mask_timing_normalization_enabled() -> bool { + false +} + +pub(crate) fn default_mask_timing_normalization_floor_ms() -> u64 { + 0 +} + +pub(crate) fn default_mask_timing_normalization_ceiling_ms() -> u64 { + 0 +} + pub(crate) fn default_stun_servers() -> Vec { vec![ "stun.l.google.com:5349".to_string(), diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 4cf7676..e580b7f 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -31,38 +31,30 @@ use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher}; use tokio::sync::{mpsc, watch}; use tracing::{error, info, warn}; -use crate::config::{ - LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel, - MeWriterPickMode, -}; use super::load::{LoadedConfig, ProxyConfig}; +use crate::config::{ + LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel, MeWriterPickMode, +}; -const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2; const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50); -const HOT_RELOAD_STABLE_RECHECK: Duration = Duration::from_millis(75); // ── Hot fields ──────────────────────────────────────────────────────────────── /// Fields that are safe to swap without restarting listeners. 
#[derive(Debug, Clone, PartialEq)] pub struct HotFields { - pub log_level: LogLevel, - pub ad_tag: Option, - pub dns_overrides: Vec, - pub desync_all_full: bool, - pub update_every_secs: u64, - pub me_reinit_every_secs: u64, - pub me_reinit_singleflight: bool, + pub log_level: LogLevel, + pub ad_tag: Option, + pub dns_overrides: Vec, + pub desync_all_full: bool, + pub update_every_secs: u64, + pub me_reinit_every_secs: u64, + pub me_reinit_singleflight: bool, pub me_reinit_coalesce_window_ms: u64, - pub hardswap: bool, - pub me_pool_drain_ttl_secs: u64, + pub hardswap: bool, + pub me_pool_drain_ttl_secs: u64, pub me_instadrain: bool, pub me_pool_drain_threshold: u64, - pub me_pool_drain_soft_evict_enabled: bool, - pub me_pool_drain_soft_evict_grace_secs: u64, - pub me_pool_drain_soft_evict_per_writer: u8, - pub me_pool_drain_soft_evict_budget_per_core: u16, - pub me_pool_drain_soft_evict_cooldown_ms: u64, pub me_pool_min_fresh_ratio: f32, pub me_reinit_drain_timeout_secs: u64, pub me_hardswap_warmup_delay_min_ms: u64, @@ -114,18 +106,20 @@ pub struct HotFields { pub me_d2c_flush_batch_max_bytes: usize, pub me_d2c_flush_batch_max_delay_us: u64, pub me_d2c_ack_flush_immediate: bool, + pub me_quota_soft_overshoot_bytes: u64, + pub me_d2c_frame_buf_shrink_threshold_bytes: usize, pub direct_relay_copy_buf_c2s_bytes: usize, pub direct_relay_copy_buf_s2c_bytes: usize, pub me_health_interval_ms_unhealthy: u64, pub me_health_interval_ms_healthy: u64, pub me_admission_poll_ms: u64, pub me_warn_rate_limit_ms: u64, - pub users: std::collections::HashMap, - pub user_ad_tags: std::collections::HashMap, - pub user_max_tcp_conns: std::collections::HashMap, - pub user_expirations: std::collections::HashMap>, - pub user_data_quota: std::collections::HashMap, - pub user_max_unique_ips: std::collections::HashMap, + pub users: std::collections::HashMap, + pub user_ad_tags: std::collections::HashMap, + pub user_max_tcp_conns: std::collections::HashMap, + pub user_expirations: 
std::collections::HashMap>, + pub user_data_quota: std::collections::HashMap, + pub user_max_unique_ips: std::collections::HashMap, pub user_max_unique_ips_global_each: usize, pub user_max_unique_ips_mode: crate::config::UserMaxUniqueIpsMode, pub user_max_unique_ips_window_secs: u64, @@ -134,27 +128,18 @@ pub struct HotFields { impl HotFields { pub fn from_config(cfg: &ProxyConfig) -> Self { Self { - log_level: cfg.general.log_level.clone(), - ad_tag: cfg.general.ad_tag.clone(), - dns_overrides: cfg.network.dns_overrides.clone(), - desync_all_full: cfg.general.desync_all_full, - update_every_secs: cfg.general.effective_update_every_secs(), - me_reinit_every_secs: cfg.general.me_reinit_every_secs, - me_reinit_singleflight: cfg.general.me_reinit_singleflight, + log_level: cfg.general.log_level.clone(), + ad_tag: cfg.general.ad_tag.clone(), + dns_overrides: cfg.network.dns_overrides.clone(), + desync_all_full: cfg.general.desync_all_full, + update_every_secs: cfg.general.effective_update_every_secs(), + me_reinit_every_secs: cfg.general.me_reinit_every_secs, + me_reinit_singleflight: cfg.general.me_reinit_singleflight, me_reinit_coalesce_window_ms: cfg.general.me_reinit_coalesce_window_ms, - hardswap: cfg.general.hardswap, - me_pool_drain_ttl_secs: cfg.general.me_pool_drain_ttl_secs, + hardswap: cfg.general.hardswap, + me_pool_drain_ttl_secs: cfg.general.me_pool_drain_ttl_secs, me_instadrain: cfg.general.me_instadrain, me_pool_drain_threshold: cfg.general.me_pool_drain_threshold, - me_pool_drain_soft_evict_enabled: cfg.general.me_pool_drain_soft_evict_enabled, - me_pool_drain_soft_evict_grace_secs: cfg.general.me_pool_drain_soft_evict_grace_secs, - me_pool_drain_soft_evict_per_writer: cfg.general.me_pool_drain_soft_evict_per_writer, - me_pool_drain_soft_evict_budget_per_core: cfg - .general - .me_pool_drain_soft_evict_budget_per_core, - me_pool_drain_soft_evict_cooldown_ms: cfg - .general - .me_pool_drain_soft_evict_cooldown_ms, me_pool_min_fresh_ratio: 
cfg.general.me_pool_min_fresh_ratio, me_reinit_drain_timeout_secs: cfg.general.me_reinit_drain_timeout_secs, me_hardswap_warmup_delay_min_ms: cfg.general.me_hardswap_warmup_delay_min_ms, @@ -205,15 +190,11 @@ impl HotFields { me_adaptive_floor_min_writers_multi_endpoint: cfg .general .me_adaptive_floor_min_writers_multi_endpoint, - me_adaptive_floor_recover_grace_secs: cfg - .general - .me_adaptive_floor_recover_grace_secs, + me_adaptive_floor_recover_grace_secs: cfg.general.me_adaptive_floor_recover_grace_secs, me_adaptive_floor_writers_per_core_total: cfg .general .me_adaptive_floor_writers_per_core_total, - me_adaptive_floor_cpu_cores_override: cfg - .general - .me_adaptive_floor_cpu_cores_override, + me_adaptive_floor_cpu_cores_override: cfg.general.me_adaptive_floor_cpu_cores_override, me_adaptive_floor_max_extra_writers_single_per_core: cfg .general .me_adaptive_floor_max_extra_writers_single_per_core, @@ -232,26 +213,34 @@ impl HotFields { me_adaptive_floor_max_warm_writers_global: cfg .general .me_adaptive_floor_max_warm_writers_global, - me_route_backpressure_base_timeout_ms: cfg.general.me_route_backpressure_base_timeout_ms, - me_route_backpressure_high_timeout_ms: cfg.general.me_route_backpressure_high_timeout_ms, - me_route_backpressure_high_watermark_pct: cfg.general.me_route_backpressure_high_watermark_pct, + me_route_backpressure_base_timeout_ms: cfg + .general + .me_route_backpressure_base_timeout_ms, + me_route_backpressure_high_timeout_ms: cfg + .general + .me_route_backpressure_high_timeout_ms, + me_route_backpressure_high_watermark_pct: cfg + .general + .me_route_backpressure_high_watermark_pct, me_reader_route_data_wait_ms: cfg.general.me_reader_route_data_wait_ms, me_d2c_flush_batch_max_frames: cfg.general.me_d2c_flush_batch_max_frames, me_d2c_flush_batch_max_bytes: cfg.general.me_d2c_flush_batch_max_bytes, me_d2c_flush_batch_max_delay_us: cfg.general.me_d2c_flush_batch_max_delay_us, me_d2c_ack_flush_immediate: 
cfg.general.me_d2c_ack_flush_immediate, + me_quota_soft_overshoot_bytes: cfg.general.me_quota_soft_overshoot_bytes, + me_d2c_frame_buf_shrink_threshold_bytes: cfg.general.me_d2c_frame_buf_shrink_threshold_bytes, direct_relay_copy_buf_c2s_bytes: cfg.general.direct_relay_copy_buf_c2s_bytes, direct_relay_copy_buf_s2c_bytes: cfg.general.direct_relay_copy_buf_s2c_bytes, me_health_interval_ms_unhealthy: cfg.general.me_health_interval_ms_unhealthy, me_health_interval_ms_healthy: cfg.general.me_health_interval_ms_healthy, me_admission_poll_ms: cfg.general.me_admission_poll_ms, me_warn_rate_limit_ms: cfg.general.me_warn_rate_limit_ms, - users: cfg.access.users.clone(), - user_ad_tags: cfg.access.user_ad_tags.clone(), - user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(), - user_expirations: cfg.access.user_expirations.clone(), - user_data_quota: cfg.access.user_data_quota.clone(), - user_max_unique_ips: cfg.access.user_max_unique_ips.clone(), + users: cfg.access.users.clone(), + user_ad_tags: cfg.access.user_ad_tags.clone(), + user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(), + user_expirations: cfg.access.user_expirations.clone(), + user_data_quota: cfg.access.user_data_quota.clone(), + user_max_unique_ips: cfg.access.user_max_unique_ips.clone(), user_max_unique_ips_global_each: cfg.access.user_max_unique_ips_global_each, user_max_unique_ips_mode: cfg.access.user_max_unique_ips_mode, user_max_unique_ips_window_secs: cfg.access.user_max_unique_ips_window_secs, @@ -346,16 +335,12 @@ impl WatchManifest { #[derive(Debug, Default)] struct ReloadState { applied_snapshot_hash: Option, - candidate_snapshot_hash: Option, - candidate_hits: u8, } impl ReloadState { fn new(applied_snapshot_hash: Option) -> Self { Self { applied_snapshot_hash, - candidate_snapshot_hash: None, - candidate_hits: 0, } } @@ -363,32 +348,8 @@ impl ReloadState { self.applied_snapshot_hash == Some(hash) } - fn observe_candidate(&mut self, hash: u64) -> u8 { - if self.candidate_snapshot_hash == 
Some(hash) { - self.candidate_hits = self.candidate_hits.saturating_add(1); - } else { - self.candidate_snapshot_hash = Some(hash); - self.candidate_hits = 1; - } - self.candidate_hits - } - - fn reset_candidate(&mut self) { - self.candidate_snapshot_hash = None; - self.candidate_hits = 0; - } - fn mark_applied(&mut self, hash: u64) { self.applied_snapshot_hash = Some(hash); - self.reset_candidate(); - } - - fn pending_candidate(&self) -> Option<(u64, u8)> { - let hash = self.candidate_snapshot_hash?; - if self.candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS { - return Some((hash, self.candidate_hits)); - } - None } } @@ -481,15 +442,6 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { cfg.general.me_pool_drain_ttl_secs = new.general.me_pool_drain_ttl_secs; cfg.general.me_instadrain = new.general.me_instadrain; cfg.general.me_pool_drain_threshold = new.general.me_pool_drain_threshold; - cfg.general.me_pool_drain_soft_evict_enabled = new.general.me_pool_drain_soft_evict_enabled; - cfg.general.me_pool_drain_soft_evict_grace_secs = - new.general.me_pool_drain_soft_evict_grace_secs; - cfg.general.me_pool_drain_soft_evict_per_writer = - new.general.me_pool_drain_soft_evict_per_writer; - cfg.general.me_pool_drain_soft_evict_budget_per_core = - new.general.me_pool_drain_soft_evict_budget_per_core; - cfg.general.me_pool_drain_soft_evict_cooldown_ms = - new.general.me_pool_drain_soft_evict_cooldown_ms; cfg.general.me_pool_min_fresh_ratio = new.general.me_pool_min_fresh_ratio; cfg.general.me_reinit_drain_timeout_secs = new.general.me_reinit_drain_timeout_secs; cfg.general.me_hardswap_warmup_delay_min_ms = new.general.me_hardswap_warmup_delay_min_ms; @@ -536,10 +488,14 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { new.general.me_adaptive_floor_writers_per_core_total; cfg.general.me_adaptive_floor_cpu_cores_override = new.general.me_adaptive_floor_cpu_cores_override; - 
cfg.general.me_adaptive_floor_max_extra_writers_single_per_core = - new.general.me_adaptive_floor_max_extra_writers_single_per_core; - cfg.general.me_adaptive_floor_max_extra_writers_multi_per_core = - new.general.me_adaptive_floor_max_extra_writers_multi_per_core; + cfg.general + .me_adaptive_floor_max_extra_writers_single_per_core = new + .general + .me_adaptive_floor_max_extra_writers_single_per_core; + cfg.general + .me_adaptive_floor_max_extra_writers_multi_per_core = new + .general + .me_adaptive_floor_max_extra_writers_multi_per_core; cfg.general.me_adaptive_floor_max_active_writers_per_core = new.general.me_adaptive_floor_max_active_writers_per_core; cfg.general.me_adaptive_floor_max_warm_writers_per_core = @@ -559,6 +515,9 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { cfg.general.me_d2c_flush_batch_max_bytes = new.general.me_d2c_flush_batch_max_bytes; cfg.general.me_d2c_flush_batch_max_delay_us = new.general.me_d2c_flush_batch_max_delay_us; cfg.general.me_d2c_ack_flush_immediate = new.general.me_d2c_ack_flush_immediate; + cfg.general.me_quota_soft_overshoot_bytes = new.general.me_quota_soft_overshoot_bytes; + cfg.general.me_d2c_frame_buf_shrink_threshold_bytes = + new.general.me_d2c_frame_buf_shrink_threshold_bytes; cfg.general.direct_relay_copy_buf_c2s_bytes = new.general.direct_relay_copy_buf_c2s_bytes; cfg.general.direct_relay_copy_buf_s2c_bytes = new.general.direct_relay_copy_buf_s2c_bytes; cfg.general.me_health_interval_ms_unhealthy = new.general.me_health_interval_ms_unhealthy; @@ -598,8 +557,7 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.server.api.minimal_runtime_cache_ttl_ms != new.server.api.minimal_runtime_cache_ttl_ms || old.server.api.runtime_edge_enabled != new.server.api.runtime_edge_enabled - || old.server.api.runtime_edge_cache_ttl_ms - != new.server.api.runtime_edge_cache_ttl_ms + || old.server.api.runtime_edge_cache_ttl_ms != 
new.server.api.runtime_edge_cache_ttl_ms || old.server.api.runtime_edge_top_n != new.server.api.runtime_edge_top_n || old.server.api.runtime_edge_events_capacity != new.server.api.runtime_edge_events_capacity @@ -615,8 +573,6 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.server.listen_tcp != new.server.listen_tcp || old.server.listen_unix_sock != new.server.listen_unix_sock || old.server.listen_unix_sock_perm != new.server.listen_unix_sock_perm - || old.server.max_connections != new.server.max_connections - || old.server.accept_permit_timeout_ms != new.server.accept_permit_timeout_ms { warned = true; warn!("config reload: server listener settings changed; restart required"); @@ -637,6 +593,19 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.censorship.tls_full_cert_ttl_secs != new.censorship.tls_full_cert_ttl_secs || old.censorship.alpn_enforce != new.censorship.alpn_enforce || old.censorship.mask_proxy_protocol != new.censorship.mask_proxy_protocol + || old.censorship.mask_shape_hardening != new.censorship.mask_shape_hardening + || old.censorship.mask_shape_bucket_floor_bytes + != new.censorship.mask_shape_bucket_floor_bytes + || old.censorship.mask_shape_bucket_cap_bytes != new.censorship.mask_shape_bucket_cap_bytes + || old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur + || old.censorship.mask_shape_above_cap_blur_max_bytes + != new.censorship.mask_shape_above_cap_blur_max_bytes + || old.censorship.mask_timing_normalization_enabled + != new.censorship.mask_timing_normalization_enabled + || old.censorship.mask_timing_normalization_floor_ms + != new.censorship.mask_timing_normalization_floor_ms + || old.censorship.mask_timing_normalization_ceiling_ms + != new.censorship.mask_timing_normalization_ceiling_ms { warned = true; warn!("config reload: censorship settings changed; restart required"); @@ -677,9 +646,6 @@ fn warn_non_hot_changes(old: 
&ProxyConfig, new: &ProxyConfig, non_hot_changed: b } if old.general.me_route_no_writer_mode != new.general.me_route_no_writer_mode || old.general.me_route_no_writer_wait_ms != new.general.me_route_no_writer_wait_ms - || old.general.me_route_hybrid_max_wait_ms != new.general.me_route_hybrid_max_wait_ms - || old.general.me_route_blocking_send_timeout_ms - != new.general.me_route_blocking_send_timeout_ms || old.general.me_route_inline_recovery_attempts != new.general.me_route_inline_recovery_attempts || old.general.me_route_inline_recovery_wait_ms @@ -688,10 +654,6 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b warned = true; warn!("config reload: general.me_route_no_writer_* changed; restart required"); } - if old.general.me_c2me_send_timeout_ms != new.general.me_c2me_send_timeout_ms { - warned = true; - warn!("config reload: general.me_c2me_send_timeout_ms changed; restart required"); - } if old.general.unknown_dc_log_path != new.general.unknown_dc_log_path || old.general.unknown_dc_file_log_enabled != new.general.unknown_dc_file_log_enabled { @@ -886,25 +848,6 @@ fn log_changes( old_hot.me_pool_drain_threshold, new_hot.me_pool_drain_threshold, ); } - if old_hot.me_pool_drain_soft_evict_enabled != new_hot.me_pool_drain_soft_evict_enabled - || old_hot.me_pool_drain_soft_evict_grace_secs - != new_hot.me_pool_drain_soft_evict_grace_secs - || old_hot.me_pool_drain_soft_evict_per_writer - != new_hot.me_pool_drain_soft_evict_per_writer - || old_hot.me_pool_drain_soft_evict_budget_per_core - != new_hot.me_pool_drain_soft_evict_budget_per_core - || old_hot.me_pool_drain_soft_evict_cooldown_ms - != new_hot.me_pool_drain_soft_evict_cooldown_ms - { - info!( - "config reload: me_pool_drain_soft_evict: enabled={} grace={}s per_writer={} budget_per_core={} cooldown={}ms", - new_hot.me_pool_drain_soft_evict_enabled, - new_hot.me_pool_drain_soft_evict_grace_secs, - new_hot.me_pool_drain_soft_evict_per_writer, - 
new_hot.me_pool_drain_soft_evict_budget_per_core, - new_hot.me_pool_drain_soft_evict_cooldown_ms - ); - } if (old_hot.me_pool_min_fresh_ratio - new_hot.me_pool_min_fresh_ratio).abs() > f32::EPSILON { info!( @@ -938,8 +881,7 @@ fn log_changes( { info!( "config reload: me_bind_stale: mode={:?} ttl={}s", - new_hot.me_bind_stale_mode, - new_hot.me_bind_stale_ttl_secs + new_hot.me_bind_stale_mode, new_hot.me_bind_stale_ttl_secs ); } if old_hot.me_secret_atomic_snapshot != new_hot.me_secret_atomic_snapshot @@ -1019,8 +961,7 @@ fn log_changes( if old_hot.me_socks_kdf_policy != new_hot.me_socks_kdf_policy { info!( "config reload: me_socks_kdf_policy: {:?} → {:?}", - old_hot.me_socks_kdf_policy, - new_hot.me_socks_kdf_policy, + old_hot.me_socks_kdf_policy, new_hot.me_socks_kdf_policy, ); } @@ -1074,8 +1015,7 @@ fn log_changes( || old_hot.me_route_backpressure_high_watermark_pct != new_hot.me_route_backpressure_high_watermark_pct || old_hot.me_reader_route_data_wait_ms != new_hot.me_reader_route_data_wait_ms - || old_hot.me_health_interval_ms_unhealthy - != new_hot.me_health_interval_ms_unhealthy + || old_hot.me_health_interval_ms_unhealthy != new_hot.me_health_interval_ms_unhealthy || old_hot.me_health_interval_ms_healthy != new_hot.me_health_interval_ms_healthy || old_hot.me_admission_poll_ms != new_hot.me_admission_poll_ms || old_hot.me_warn_rate_limit_ms != new_hot.me_warn_rate_limit_ms @@ -1097,34 +1037,47 @@ fn log_changes( || old_hot.me_d2c_flush_batch_max_bytes != new_hot.me_d2c_flush_batch_max_bytes || old_hot.me_d2c_flush_batch_max_delay_us != new_hot.me_d2c_flush_batch_max_delay_us || old_hot.me_d2c_ack_flush_immediate != new_hot.me_d2c_ack_flush_immediate + || old_hot.me_quota_soft_overshoot_bytes != new_hot.me_quota_soft_overshoot_bytes + || old_hot.me_d2c_frame_buf_shrink_threshold_bytes + != new_hot.me_d2c_frame_buf_shrink_threshold_bytes || old_hot.direct_relay_copy_buf_c2s_bytes != new_hot.direct_relay_copy_buf_c2s_bytes || 
old_hot.direct_relay_copy_buf_s2c_bytes != new_hot.direct_relay_copy_buf_s2c_bytes { info!( - "config reload: relay_tuning: me_d2c_frames={} me_d2c_bytes={} me_d2c_delay_us={} me_ack_flush_immediate={} direct_buf_c2s={} direct_buf_s2c={}", + "config reload: relay_tuning: me_d2c_frames={} me_d2c_bytes={} me_d2c_delay_us={} me_ack_flush_immediate={} me_quota_soft_overshoot_bytes={} me_d2c_frame_buf_shrink_threshold_bytes={} direct_buf_c2s={} direct_buf_s2c={}", new_hot.me_d2c_flush_batch_max_frames, new_hot.me_d2c_flush_batch_max_bytes, new_hot.me_d2c_flush_batch_max_delay_us, new_hot.me_d2c_ack_flush_immediate, + new_hot.me_quota_soft_overshoot_bytes, + new_hot.me_d2c_frame_buf_shrink_threshold_bytes, new_hot.direct_relay_copy_buf_c2s_bytes, new_hot.direct_relay_copy_buf_s2c_bytes, ); } if old_hot.users != new_hot.users { - let mut added: Vec<&String> = new_hot.users.keys() + let mut added: Vec<&String> = new_hot + .users + .keys() .filter(|u| !old_hot.users.contains_key(*u)) .collect(); added.sort(); - let mut removed: Vec<&String> = old_hot.users.keys() + let mut removed: Vec<&String> = old_hot + .users + .keys() .filter(|u| !new_hot.users.contains_key(*u)) .collect(); removed.sort(); - let mut changed: Vec<&String> = new_hot.users.keys() + let mut changed: Vec<&String> = new_hot + .users + .keys() .filter(|u| { - old_hot.users.get(*u) + old_hot + .users + .get(*u) .map(|s| s != &new_hot.users[*u]) .unwrap_or(false) }) @@ -1134,10 +1087,18 @@ fn log_changes( if !added.is_empty() { info!( "config reload: users added: [{}]", - added.iter().map(|s| s.as_str()).collect::>().join(", ") + added + .iter() + .map(|s| s.as_str()) + .collect::>() + .join(", ") ); let host = resolve_link_host(new_cfg, detected_ip_v4, detected_ip_v6); - let port = new_cfg.general.links.public_port.unwrap_or(new_cfg.server.port); + let port = new_cfg + .general + .links + .public_port + .unwrap_or(new_cfg.server.port); for user in &added { if let Some(secret) = new_hot.users.get(*user) { 
print_user_links(user, secret, &host, port, new_cfg); @@ -1147,13 +1108,21 @@ fn log_changes( if !removed.is_empty() { info!( "config reload: users removed: [{}]", - removed.iter().map(|s| s.as_str()).collect::>().join(", ") + removed + .iter() + .map(|s| s.as_str()) + .collect::>() + .join(", ") ); } if !changed.is_empty() { info!( "config reload: users secret changed: [{}]", - changed.iter().map(|s| s.as_str()).collect::>().join(", ") + changed + .iter() + .map(|s| s.as_str()) + .collect::>() + .join(", ") ); } } @@ -1184,8 +1153,7 @@ fn log_changes( } if old_hot.user_max_unique_ips_global_each != new_hot.user_max_unique_ips_global_each || old_hot.user_max_unique_ips_mode != new_hot.user_max_unique_ips_mode - || old_hot.user_max_unique_ips_window_secs - != new_hot.user_max_unique_ips_window_secs + || old_hot.user_max_unique_ips_window_secs != new_hot.user_max_unique_ips_window_secs { info!( "config reload: user_max_unique_ips policy global_each={} mode={:?} window={}s", @@ -1208,7 +1176,6 @@ fn reload_config( let loaded = match ProxyConfig::load_with_metadata(config_path) { Ok(loaded) => loaded, Err(e) => { - reload_state.reset_candidate(); error!("config reload: failed to parse {:?}: {}", config_path, e); return None; } @@ -1221,8 +1188,10 @@ fn reload_config( let next_manifest = WatchManifest::from_source_files(&source_files); if let Err(e) = new_cfg.validate() { - reload_state.reset_candidate(); - error!("config reload: validation failed: {}; keeping old config", e); + error!( + "config reload: validation failed: {}; keeping old config", + e + ); return Some(next_manifest); } @@ -1230,17 +1199,6 @@ fn reload_config( return Some(next_manifest); } - let candidate_hits = reload_state.observe_candidate(rendered_hash); - if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS { - info!( - snapshot_hash = rendered_hash, - candidate_hits, - required_hits = HOT_RELOAD_STABLE_SNAPSHOTS, - "config reload: candidate snapshot observed but not stable yet" - ); - return 
Some(next_manifest); - } - let old_cfg = config_tx.borrow().clone(); let applied_cfg = overlay_hot_fields(&old_cfg, &new_cfg); let old_hot = HotFields::from_config(&old_cfg); @@ -1260,7 +1218,6 @@ fn reload_config( if old_hot.dns_overrides != applied_hot.dns_overrides && let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides) { - reload_state.reset_candidate(); error!( "config reload: invalid network.dns_overrides: {}; keeping old config", e @@ -1281,73 +1238,6 @@ fn reload_config( Some(next_manifest) } -async fn reload_with_internal_stable_rechecks( - config_path: &PathBuf, - config_tx: &watch::Sender>, - log_tx: &watch::Sender, - detected_ip_v4: Option, - detected_ip_v6: Option, - reload_state: &mut ReloadState, -) -> Option { - let mut next_manifest = reload_config( - config_path, - config_tx, - log_tx, - detected_ip_v4, - detected_ip_v6, - reload_state, - ); - let mut rechecks_left = HOT_RELOAD_STABLE_SNAPSHOTS.saturating_sub(1); - - while rechecks_left > 0 { - let Some((snapshot_hash, candidate_hits)) = reload_state.pending_candidate() else { - break; - }; - - info!( - snapshot_hash, - candidate_hits, - required_hits = HOT_RELOAD_STABLE_SNAPSHOTS, - rechecks_left, - recheck_delay_ms = HOT_RELOAD_STABLE_RECHECK.as_millis(), - "config reload: scheduling internal stable recheck" - ); - tokio::time::sleep(HOT_RELOAD_STABLE_RECHECK).await; - - let recheck_manifest = reload_config( - config_path, - config_tx, - log_tx, - detected_ip_v4, - detected_ip_v6, - reload_state, - ); - if recheck_manifest.is_some() { - next_manifest = recheck_manifest; - } - - if reload_state.is_applied(snapshot_hash) { - info!( - snapshot_hash, - "config reload: applied after internal stable recheck" - ); - break; - } - - if reload_state.pending_candidate().is_none() { - info!( - snapshot_hash, - "config reload: internal stable recheck aborted" - ); - break; - } - - rechecks_left = rechecks_left.saturating_sub(1); - } - - next_manifest -} - // ── Public API 
──────────────────────────────────────────────────────────────── /// Spawn the hot-reload watcher task. @@ -1366,7 +1256,7 @@ pub fn spawn_config_watcher( ) -> (watch::Receiver>, watch::Receiver) { let initial_level = initial.general.log_level.clone(); let (config_tx, config_rx) = watch::channel(initial); - let (log_tx, log_rx) = watch::channel(initial_level); + let (log_tx, log_rx) = watch::channel(initial_level); let config_path = normalize_watch_path(&config_path); let initial_loaded = ProxyConfig::load_with_metadata(&config_path).ok(); @@ -1383,25 +1273,29 @@ pub fn spawn_config_watcher( let tx_inotify = notify_tx.clone(); let manifest_for_inotify = manifest_state.clone(); - let mut inotify_watcher = match recommended_watcher(move |res: notify::Result| { - let Ok(event) = res else { return }; - if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) { - return; - } - let is_our_file = manifest_for_inotify - .read() - .map(|manifest| manifest.matches_event_paths(&event.paths)) - .unwrap_or(false); - if is_our_file { - let _ = tx_inotify.try_send(()); - } - }) { - Ok(watcher) => Some(watcher), - Err(e) => { - warn!("config watcher: inotify unavailable: {}", e); - None - } - }; + let mut inotify_watcher = + match recommended_watcher(move |res: notify::Result| { + let Ok(event) = res else { return }; + if !matches!( + event.kind, + EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) + ) { + return; + } + let is_our_file = manifest_for_inotify + .read() + .map(|manifest| manifest.matches_event_paths(&event.paths)) + .unwrap_or(false); + if is_our_file { + let _ = tx_inotify.try_send(()); + } + }) { + Ok(watcher) => Some(watcher), + Err(e) => { + warn!("config watcher: inotify unavailable: {}", e); + None + } + }; apply_watch_manifest( inotify_watcher.as_mut(), Option::<&mut notify::poll::PollWatcher>::None, @@ -1417,7 +1311,10 @@ pub fn spawn_config_watcher( let mut poll_watcher = match 
notify::poll::PollWatcher::new( move |res: notify::Result| { let Ok(event) = res else { return }; - if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) { + if !matches!( + event.kind, + EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) + ) { return; } let is_our_file = manifest_for_poll @@ -1465,22 +1362,36 @@ pub fn spawn_config_watcher( } } #[cfg(not(unix))] - if notify_rx.recv().await.is_none() { break; } + if notify_rx.recv().await.is_none() { + break; + } // Debounce: drain extra events that arrive within a short quiet window. tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; while notify_rx.try_recv().is_ok() {} - if let Some(next_manifest) = reload_with_internal_stable_rechecks( + let mut next_manifest = reload_config( &config_path, &config_tx, &log_tx, detected_ip_v4, detected_ip_v6, &mut reload_state, - ) - .await - { + ); + if next_manifest.is_none() { + tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; + while notify_rx.try_recv().is_ok() {} + next_manifest = reload_config( + &config_path, + &config_tx, + &log_tx, + detected_ip_v4, + detected_ip_v6, + &mut reload_state, + ); + } + + if let Some(next_manifest) = next_manifest { apply_watch_manifest( inotify_watcher.as_mut(), poll_watcher.as_mut(), @@ -1555,7 +1466,10 @@ mod tests { new.server.port = old.server.port.saturating_add(1); let applied = overlay_hot_fields(&old, &new); - assert_eq!(HotFields::from_config(&old), HotFields::from_config(&applied)); + assert_eq!( + HotFields::from_config(&old), + HotFields::from_config(&applied) + ); assert_eq!(applied.server.port, old.server.port); } @@ -1574,7 +1488,10 @@ mod tests { applied.general.me_bind_stale_mode, new.general.me_bind_stale_mode ); - assert_ne!(HotFields::from_config(&old), HotFields::from_config(&applied)); + assert_ne!( + HotFields::from_config(&old), + HotFields::from_config(&applied) + ); } #[test] @@ -1588,7 +1505,10 @@ mod tests { applied.general.me_keepalive_interval_secs, 
old.general.me_keepalive_interval_secs ); - assert_eq!(HotFields::from_config(&old), HotFields::from_config(&applied)); + assert_eq!( + HotFields::from_config(&old), + HotFields::from_config(&applied) + ); } #[test] @@ -1600,69 +1520,35 @@ mod tests { let applied = overlay_hot_fields(&old, &new); assert_eq!(applied.general.hardswap, new.general.hardswap); - assert_eq!(applied.general.use_middle_proxy, old.general.use_middle_proxy); + assert_eq!( + applied.general.use_middle_proxy, + old.general.use_middle_proxy + ); assert!(!config_equal(&applied, &new)); } #[test] - fn reload_requires_stable_snapshot_before_hot_apply() { + fn reload_applies_hot_change_on_first_observed_snapshot() { let initial_tag = "11111111111111111111111111111111"; let final_tag = "22222222222222222222222222222222"; let path = temp_config_path("telemt_hot_reload_stable"); write_reload_config(&path, Some(initial_tag), None); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); - let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; - let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); - let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); - let mut reload_state = ReloadState::new(Some(initial_hash)); - - write_reload_config(&path, None, None); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!( - config_tx.borrow().general.ad_tag.as_deref(), - Some(initial_tag) - ); - - write_reload_config(&path, Some(final_tag), None); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!( - config_tx.borrow().general.ad_tag.as_deref(), - Some(initial_tag) - ); - - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); - - let _ = std::fs::remove_file(path); - } - - #[tokio::test] - async fn reload_cycle_applies_after_single_external_event() 
{ - let initial_tag = "10101010101010101010101010101010"; - let final_tag = "20202020202020202020202020202020"; - let path = temp_config_path("telemt_hot_reload_single_event"); - - write_reload_config(&path, Some(initial_tag), None); - let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); - let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; + let initial_hash = ProxyConfig::load_with_metadata(&path) + .unwrap() + .rendered_hash; let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let mut reload_state = ReloadState::new(Some(initial_hash)); write_reload_config(&path, Some(final_tag), None); - reload_with_internal_stable_rechecks( - &path, - &config_tx, - &log_tx, - None, - None, - &mut reload_state, - ) - .await - .unwrap(); + reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); + assert_eq!( + config_tx.borrow().general.ad_tag.as_deref(), + Some(final_tag) + ); - assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); let _ = std::fs::remove_file(path); } @@ -1674,14 +1560,15 @@ mod tests { write_reload_config(&path, Some(initial_tag), None); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); - let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; + let initial_hash = ProxyConfig::load_with_metadata(&path) + .unwrap() + .rendered_hash; let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let mut reload_state = ReloadState::new(Some(initial_hash)); write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1)); reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); let applied = config_tx.borrow().clone(); 
assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag)); @@ -1689,4 +1576,36 @@ mod tests { let _ = std::fs::remove_file(path); } + + #[test] + fn reload_recovers_after_parse_error_on_next_attempt() { + let initial_tag = "cccccccccccccccccccccccccccccccc"; + let final_tag = "dddddddddddddddddddddddddddddddd"; + let path = temp_config_path("telemt_hot_reload_parse_recovery"); + + write_reload_config(&path, Some(initial_tag), None); + let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); + let initial_hash = ProxyConfig::load_with_metadata(&path) + .unwrap() + .rendered_hash; + let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); + let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); + let mut reload_state = ReloadState::new(Some(initial_hash)); + + std::fs::write(&path, "[access.users\nuser = \"broken\"\n").unwrap(); + assert!(reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).is_none()); + assert_eq!( + config_tx.borrow().general.ad_tag.as_deref(), + Some(initial_tag) + ); + + write_reload_config(&path, Some(final_tag), None); + reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); + assert_eq!( + config_tx.borrow().general.ad_tag.as_deref(), + Some(final_tag) + ); + + let _ = std::fs::remove_file(path); + } } diff --git a/src/config/load.rs b/src/config/load.rs index c797637..bf6d036 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -5,7 +5,7 @@ use std::hash::{DefaultHasher, Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; -use rand::Rng; +use rand::RngExt; use serde::{Deserialize, Serialize}; use shadowsocks::config::ServerConfig as ShadowsocksServerConfig; use tracing::warn; @@ -360,6 +360,131 @@ impl ProxyConfig { )); } + if config.timeouts.client_handshake == 0 { + return Err(ProxyError::Config( + "timeouts.client_handshake must be > 0".to_string(), + )); + } + + let handshake_timeout_ms = config + .timeouts + 
.client_handshake + .checked_mul(1000) + .ok_or_else(|| { + ProxyError::Config( + "timeouts.client_handshake is too large to validate milliseconds budget" + .to_string(), + ) + })?; + + if config.censorship.server_hello_delay_max_ms >= handshake_timeout_ms { + return Err(ProxyError::Config( + "censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000" + .to_string(), + )); + } + + if config.censorship.mask_shape_bucket_floor_bytes == 0 { + return Err(ProxyError::Config( + "censorship.mask_shape_bucket_floor_bytes must be > 0".to_string(), + )); + } + + if config.censorship.mask_shape_bucket_cap_bytes + < config.censorship.mask_shape_bucket_floor_bytes + { + return Err(ProxyError::Config( + "censorship.mask_shape_bucket_cap_bytes must be >= censorship.mask_shape_bucket_floor_bytes" + .to_string(), + )); + } + + if config.censorship.mask_shape_above_cap_blur && !config.censorship.mask_shape_hardening { + return Err(ProxyError::Config( + "censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true" + .to_string(), + )); + } + + if config.censorship.mask_shape_hardening_aggressive_mode + && !config.censorship.mask_shape_hardening + { + return Err(ProxyError::Config( + "censorship.mask_shape_hardening_aggressive_mode requires censorship.mask_shape_hardening = true" + .to_string(), + )); + } + + if config.censorship.mask_shape_above_cap_blur + && config.censorship.mask_shape_above_cap_blur_max_bytes == 0 + { + return Err(ProxyError::Config( + "censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled" + .to_string(), + )); + } + + if config.censorship.mask_shape_above_cap_blur_max_bytes > 1_048_576 { + return Err(ProxyError::Config( + "censorship.mask_shape_above_cap_blur_max_bytes must be <= 1048576".to_string(), + )); + } + + if config.censorship.mask_timing_normalization_ceiling_ms + < config.censorship.mask_timing_normalization_floor_ms + { + return 
Err(ProxyError::Config( + "censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms" + .to_string(), + )); + } + + if config.censorship.mask_timing_normalization_enabled + && config.censorship.mask_timing_normalization_floor_ms == 0 + { + return Err(ProxyError::Config( + "censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true" + .to_string(), + )); + } + + if config.censorship.mask_timing_normalization_ceiling_ms > 60_000 { + return Err(ProxyError::Config( + "censorship.mask_timing_normalization_ceiling_ms must be <= 60000".to_string(), + )); + } + + if config.timeouts.relay_client_idle_soft_secs == 0 { + return Err(ProxyError::Config( + "timeouts.relay_client_idle_soft_secs must be > 0".to_string(), + )); + } + + if config.timeouts.relay_client_idle_hard_secs == 0 { + return Err(ProxyError::Config( + "timeouts.relay_client_idle_hard_secs must be > 0".to_string(), + )); + } + + if config.timeouts.relay_client_idle_hard_secs < config.timeouts.relay_client_idle_soft_secs + { + return Err(ProxyError::Config( + "timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs" + .to_string(), + )); + } + + if config + .timeouts + .relay_idle_grace_after_downstream_activity_secs + > config.timeouts.relay_client_idle_hard_secs + { + return Err(ProxyError::Config( + "timeouts.relay_idle_grace_after_downstream_activity_secs must be <= timeouts.relay_client_idle_hard_secs" + .to_string(), + )); + } + if config.general.me_writer_cmd_channel_capacity == 0 { return Err(ProxyError::Config( "general.me_writer_cmd_channel_capacity must be > 0".to_string(), @@ -408,6 +533,19 @@ impl ProxyConfig { )); } + if config.general.me_quota_soft_overshoot_bytes > 16 * 1024 * 1024 { + return Err(ProxyError::Config( + "general.me_quota_soft_overshoot_bytes must be within [0, 16777216]".to_string(), + )); + } + + if !(4096..=16 * 1024 * 
1024).contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) { + return Err(ProxyError::Config( + "general.me_d2c_frame_buf_shrink_threshold_bytes must be within [4096, 16777216]" + .to_string(), + )); + } + if !(4096..=1024 * 1024).contains(&config.general.direct_relay_copy_buf_c2s_bytes) { return Err(ProxyError::Config( "general.direct_relay_copy_buf_c2s_bytes must be within [4096, 1048576]" @@ -648,7 +786,8 @@ impl ProxyConfig { } if config.general.me_route_backpressure_base_timeout_ms > 5000 { return Err(ProxyError::Config( - "general.me_route_backpressure_base_timeout_ms must be within [1, 5000]".to_string(), + "general.me_route_backpressure_base_timeout_ms must be within [1, 5000]" + .to_string(), )); } @@ -661,7 +800,8 @@ impl ProxyConfig { } if config.general.me_route_backpressure_high_timeout_ms > 5000 { return Err(ProxyError::Config( - "general.me_route_backpressure_high_timeout_ms must be within [1, 5000]".to_string(), + "general.me_route_backpressure_high_timeout_ms must be within [1, 5000]" + .to_string(), )); } @@ -860,7 +1000,7 @@ impl ProxyConfig { if !config.censorship.tls_emulation && config.censorship.fake_cert_len == default_fake_cert_len() { - config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); + config.censorship.fake_cert_len = rand::rng().random_range(1024..4096); } // Resolve listen_tcp: explicit value wins, otherwise auto-detect. 
@@ -982,6 +1122,18 @@ impl ProxyConfig { } } +#[cfg(test)] +#[path = "tests/load_idle_policy_tests.rs"] +mod load_idle_policy_tests; + +#[cfg(test)] +#[path = "tests/load_security_tests.rs"] +mod load_security_tests; + +#[cfg(test)] +#[path = "tests/load_mask_shape_security_tests.rs"] +mod load_mask_shape_security_tests; + #[cfg(test)] mod tests { use super::*; @@ -1697,7 +1849,9 @@ mod tests { let path = dir.join("telemt_me_route_backpressure_base_timeout_ms_out_of_range_test.toml"); std::fs::write(&path, toml).unwrap(); let err = ProxyConfig::load(&path).unwrap_err().to_string(); - assert!(err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]")); + assert!( + err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]") + ); let _ = std::fs::remove_file(path); } @@ -1718,7 +1872,9 @@ mod tests { let path = dir.join("telemt_me_route_backpressure_high_timeout_ms_out_of_range_test.toml"); std::fs::write(&path, toml).unwrap(); let err = ProxyConfig::load(&path).unwrap_err().to_string(); - assert!(err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]")); + assert!( + err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]") + ); let _ = std::fs::remove_file(path); } diff --git a/src/config/mod.rs b/src/config/mod.rs index c7187ad..dcb3bec 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,9 +1,9 @@ //! Configuration. 
pub(crate) mod defaults; -mod types; -mod load; pub mod hot_reload; +mod load; +mod types; pub use load::ProxyConfig; pub use types::*; diff --git a/src/config/tests/load_idle_policy_tests.rs b/src/config/tests/load_idle_policy_tests.rs new file mode 100644 index 0000000..c6a4e86 --- /dev/null +++ b/src/config/tests/load_idle_policy_tests.rs @@ -0,0 +1,80 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!("telemt-idle-policy-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_relay_hard_idle_smaller_than_soft_idle_with_clear_error() { + let path = write_temp_config( + r#" +[timeouts] +relay_client_idle_soft_secs = 120 +relay_client_idle_hard_secs = 60 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("config with hard= timeouts.relay_client_idle_soft_secs" + ), + "error must explain the violated hard>=soft invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_relay_grace_larger_than_hard_idle_with_clear_error() { + let path = write_temp_config( + r#" +[timeouts] +relay_client_idle_soft_secs = 60 +relay_client_idle_hard_secs = 120 +relay_idle_grace_after_downstream_activity_secs = 121 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("config with grace>hard must fail"); + let msg = err.to_string(); + assert!( + msg.contains("timeouts.relay_idle_grace_after_downstream_activity_secs must be <= timeouts.relay_client_idle_hard_secs"), + "error must explain the violated grace<=hard invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn 
load_rejects_zero_handshake_timeout_with_clear_error() { + let path = write_temp_config( + r#" +[timeouts] +client_handshake = 0 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("config with zero handshake timeout must fail"); + let msg = err.to_string(); + assert!( + msg.contains("timeouts.client_handshake must be > 0"), + "error must explain that handshake timeout must be positive, got: {msg}" + ); + + remove_temp_config(&path); +} diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs new file mode 100644 index 0000000..8986a49 --- /dev/null +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -0,0 +1,238 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!("telemt-load-mask-shape-security-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_zero_mask_shape_bucket_floor_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_bucket_floor_bytes = 0 +mask_shape_bucket_cap_bytes = 4096 +"#, + ); + + let err = + ProxyConfig::load(&path).expect_err("zero mask_shape_bucket_floor_bytes must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_shape_bucket_floor_bytes must be > 0"), + "error must explain floor>0 invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_mask_shape_bucket_cap_less_than_floor() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_bucket_floor_bytes = 1024 +mask_shape_bucket_cap_bytes = 512 +"#, + ); + + let err = + 
ProxyConfig::load(&path).expect_err("mask_shape_bucket_cap_bytes < floor must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains( + "censorship.mask_shape_bucket_cap_bytes must be >= censorship.mask_shape_bucket_floor_bytes" + ), + "error must explain cap>=floor invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_mask_shape_bucket_cap_equal_to_floor() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_hardening = true +mask_shape_bucket_floor_bytes = 1024 +mask_shape_bucket_cap_bytes = 1024 +"#, + ); + + let cfg = ProxyConfig::load(&path).expect("equal cap and floor must be accepted"); + assert!(cfg.censorship.mask_shape_hardening); + assert_eq!(cfg.censorship.mask_shape_bucket_floor_bytes, 1024); + assert_eq!(cfg.censorship.mask_shape_bucket_cap_bytes, 1024); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_above_cap_blur_when_shape_hardening_disabled() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_hardening = false +mask_shape_above_cap_blur = true +mask_shape_above_cap_blur_max_bytes = 64 +"#, + ); + + let err = + ProxyConfig::load(&path).expect_err("above-cap blur must require shape hardening enabled"); + let msg = err.to_string(); + assert!( + msg.contains( + "censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true" + ), + "error must explain blur prerequisite, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_above_cap_blur_with_zero_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_hardening = true +mask_shape_above_cap_blur = true +mask_shape_above_cap_blur_max_bytes = 0 +"#, + ); + + let err = + ProxyConfig::load(&path).expect_err("above-cap blur max bytes must be > 0 when enabled"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled"), + "error must 
explain blur max bytes invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_timing_normalization_floor_zero_when_enabled() { + let path = write_temp_config( + r#" +[censorship] +mask_timing_normalization_enabled = true +mask_timing_normalization_floor_ms = 0 +mask_timing_normalization_ceiling_ms = 200 +"#, + ); + + let err = + ProxyConfig::load(&path).expect_err("timing normalization floor must be > 0 when enabled"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true"), + "error must explain timing floor invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_timing_normalization_ceiling_below_floor() { + let path = write_temp_config( + r#" +[censorship] +mask_timing_normalization_enabled = true +mask_timing_normalization_floor_ms = 220 +mask_timing_normalization_ceiling_ms = 200 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("timing normalization ceiling must be >= floor"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms"), + "error must explain timing ceiling/floor invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_valid_timing_normalization_and_above_cap_blur_config() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_hardening = true +mask_shape_above_cap_blur = true +mask_shape_above_cap_blur_max_bytes = 128 +mask_timing_normalization_enabled = true +mask_timing_normalization_floor_ms = 150 +mask_timing_normalization_ceiling_ms = 240 +"#, + ); + + let cfg = ProxyConfig::load(&path) + .expect("valid blur and timing normalization settings must be accepted"); + assert!(cfg.censorship.mask_shape_hardening); + assert!(cfg.censorship.mask_shape_above_cap_blur); + 
assert_eq!(cfg.censorship.mask_shape_above_cap_blur_max_bytes, 128); + assert!(cfg.censorship.mask_timing_normalization_enabled); + assert_eq!(cfg.censorship.mask_timing_normalization_floor_ms, 150); + assert_eq!(cfg.censorship.mask_timing_normalization_ceiling_ms, 240); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_aggressive_shape_mode_when_shape_hardening_disabled() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_hardening = false +mask_shape_hardening_aggressive_mode = true +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("aggressive shape hardening mode must require shape hardening enabled"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_shape_hardening_aggressive_mode requires censorship.mask_shape_hardening = true"), + "error must explain aggressive-mode prerequisite, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_aggressive_shape_mode_when_shape_hardening_enabled() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_hardening = true +mask_shape_hardening_aggressive_mode = true +mask_shape_above_cap_blur = true +mask_shape_above_cap_blur_max_bytes = 8 +"#, + ); + + let cfg = ProxyConfig::load(&path) + .expect("aggressive shape hardening mode should be accepted when prerequisites are met"); + assert!(cfg.censorship.mask_shape_hardening); + assert!(cfg.censorship.mask_shape_hardening_aggressive_mode); + assert!(cfg.censorship.mask_shape_above_cap_blur); + + remove_temp_config(&path); +} diff --git a/src/config/tests/load_security_tests.rs b/src/config/tests/load_security_tests.rs new file mode 100644 index 0000000..654a9c0 --- /dev/null +++ b/src/config/tests/load_security_tests.rs @@ -0,0 +1,88 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after 
unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!("telemt-load-security-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_server_hello_delay_equal_to_handshake_timeout_budget() { + let path = write_temp_config( + r#" +[timeouts] +client_handshake = 1 + +[censorship] +server_hello_delay_max_ms = 1000 +"#, + ); + + let err = + ProxyConfig::load(&path).expect_err("delay equal to handshake timeout must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains( + "censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000" + ), + "error must explain delay0 enable bounded wait for compatibility. #[serde(default = "default_me_reader_route_data_wait_ms")] pub me_reader_route_data_wait_ms: u64, @@ -489,6 +489,14 @@ pub struct GeneralConfig { #[serde(default = "default_me_d2c_ack_flush_immediate")] pub me_d2c_ack_flush_immediate: bool, + /// Additional bytes above strict per-user quota allowed in hot-path soft mode. + #[serde(default = "default_me_quota_soft_overshoot_bytes")] + pub me_quota_soft_overshoot_bytes: u64, + + /// Shrink threshold for reusable ME->Client frame assembly buffer. + #[serde(default = "default_me_d2c_frame_buf_shrink_threshold_bytes")] + pub me_d2c_frame_buf_shrink_threshold_bytes: usize, + /// Copy buffer size for client->DC direction in direct relay. 
#[serde(default = "default_direct_relay_copy_buf_c2s_bytes")] pub direct_relay_copy_buf_c2s_bytes: usize, @@ -945,6 +953,8 @@ impl Default for GeneralConfig { me_d2c_flush_batch_max_bytes: default_me_d2c_flush_batch_max_bytes(), me_d2c_flush_batch_max_delay_us: default_me_d2c_flush_batch_max_delay_us(), me_d2c_ack_flush_immediate: default_me_d2c_ack_flush_immediate(), + me_quota_soft_overshoot_bytes: default_me_quota_soft_overshoot_bytes(), + me_d2c_frame_buf_shrink_threshold_bytes: default_me_d2c_frame_buf_shrink_threshold_bytes(), direct_relay_copy_buf_c2s_bytes: default_direct_relay_copy_buf_c2s_bytes(), direct_relay_copy_buf_s2c_bytes: default_direct_relay_copy_buf_s2c_bytes(), me_warmup_stagger_enabled: default_true(), @@ -1047,8 +1057,7 @@ impl Default for GeneralConfig { me_pool_drain_soft_evict_per_writer: default_me_pool_drain_soft_evict_per_writer(), me_pool_drain_soft_evict_budget_per_core: default_me_pool_drain_soft_evict_budget_per_core(), - me_pool_drain_soft_evict_cooldown_ms: - default_me_pool_drain_soft_evict_cooldown_ms(), + me_pool_drain_soft_evict_cooldown_ms: default_me_pool_drain_soft_evict_cooldown_ms(), me_bind_stale_mode: MeBindStaleMode::default(), me_bind_stale_ttl_secs: default_me_bind_stale_ttl_secs(), me_pool_min_fresh_ratio: default_me_pool_min_fresh_ratio(), @@ -1228,6 +1237,13 @@ pub struct ServerConfig { #[serde(default = "default_proxy_protocol_header_timeout_ms")] pub proxy_protocol_header_timeout_ms: u64, + /// Trusted source CIDRs allowed to send incoming PROXY protocol headers. + /// + /// When non-empty, connections from addresses outside this allowlist are + /// rejected before `src_addr` is applied. + #[serde(default)] + pub proxy_protocol_trusted_cidrs: Vec, + /// Port for the Prometheus-compatible metrics endpoint. /// Enables metrics when set; binds on all interfaces (dual-stack) by default. 
#[serde(default)] @@ -1270,6 +1286,7 @@ impl Default for ServerConfig { listen_tcp: None, proxy_protocol: false, proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), + proxy_protocol_trusted_cidrs: Vec::new(), metrics_port: None, metrics_listen: None, metrics_whitelist: default_metrics_whitelist(), @@ -1286,6 +1303,24 @@ pub struct TimeoutsConfig { #[serde(default = "default_handshake_timeout")] pub client_handshake: u64, + /// Enables soft/hard relay client idle policy for middle-relay sessions. + #[serde(default = "default_relay_idle_policy_v2_enabled")] + pub relay_idle_policy_v2_enabled: bool, + + /// Soft idle threshold for middle-relay client uplink activity in seconds. + /// Hitting this threshold marks the session as idle-candidate, but does not close it. + #[serde(default = "default_relay_client_idle_soft_secs")] + pub relay_client_idle_soft_secs: u64, + + /// Hard idle threshold for middle-relay client uplink activity in seconds. + /// Hitting this threshold closes the session. + #[serde(default = "default_relay_client_idle_hard_secs")] + pub relay_client_idle_hard_secs: u64, + + /// Additional grace in seconds added to hard idle window after recent downstream activity. 
+ #[serde(default = "default_relay_idle_grace_after_downstream_activity_secs")] + pub relay_idle_grace_after_downstream_activity_secs: u64, + #[serde(default = "default_connect_timeout")] pub tg_connect: u64, @@ -1308,6 +1343,11 @@ impl Default for TimeoutsConfig { fn default() -> Self { Self { client_handshake: default_handshake_timeout(), + relay_idle_policy_v2_enabled: default_relay_idle_policy_v2_enabled(), + relay_client_idle_soft_secs: default_relay_client_idle_soft_secs(), + relay_client_idle_hard_secs: default_relay_client_idle_hard_secs(), + relay_idle_grace_after_downstream_activity_secs: + default_relay_idle_grace_after_downstream_activity_secs(), tg_connect: default_connect_timeout(), client_keepalive: default_keepalive(), client_ack: default_ack_timeout(), @@ -1381,6 +1421,46 @@ pub struct AntiCensorshipConfig { /// Allows the backend to see the real client IP. #[serde(default)] pub mask_proxy_protocol: u8, + + /// Enable shape-channel hardening on mask backend path by padding + /// client->mask stream tail to configured buckets on stream end. + #[serde(default = "default_mask_shape_hardening")] + pub mask_shape_hardening: bool, + + /// Opt-in aggressive shape hardening mode. + /// When enabled, masking may shape some backend-silent timeout paths and + /// enforces strictly positive above-cap blur when blur is enabled. + #[serde(default = "default_mask_shape_hardening_aggressive_mode")] + pub mask_shape_hardening_aggressive_mode: bool, + + /// Minimum bucket size for mask shape hardening padding. + #[serde(default = "default_mask_shape_bucket_floor_bytes")] + pub mask_shape_bucket_floor_bytes: usize, + + /// Maximum bucket size for mask shape hardening padding. + #[serde(default = "default_mask_shape_bucket_cap_bytes")] + pub mask_shape_bucket_cap_bytes: usize, + + /// Add bounded random tail bytes even when total bytes already exceed + /// mask_shape_bucket_cap_bytes. 
+ #[serde(default = "default_mask_shape_above_cap_blur")] + pub mask_shape_above_cap_blur: bool, + + /// Maximum random bytes appended above cap when above-cap blur is enabled. + #[serde(default = "default_mask_shape_above_cap_blur_max_bytes")] + pub mask_shape_above_cap_blur_max_bytes: usize, + + /// Enable outcome-time normalization envelope for masking fallback. + #[serde(default = "default_mask_timing_normalization_enabled")] + pub mask_timing_normalization_enabled: bool, + + /// Lower bound (ms) for masking outcome timing envelope. + #[serde(default = "default_mask_timing_normalization_floor_ms")] + pub mask_timing_normalization_floor_ms: u64, + + /// Upper bound (ms) for masking outcome timing envelope. + #[serde(default = "default_mask_timing_normalization_ceiling_ms")] + pub mask_timing_normalization_ceiling_ms: u64, } impl Default for AntiCensorshipConfig { @@ -1402,6 +1482,15 @@ impl Default for AntiCensorshipConfig { tls_full_cert_ttl_secs: default_tls_full_cert_ttl_secs(), alpn_enforce: default_alpn_enforce(), mask_proxy_protocol: 0, + mask_shape_hardening: default_mask_shape_hardening(), + mask_shape_hardening_aggressive_mode: default_mask_shape_hardening_aggressive_mode(), + mask_shape_bucket_floor_bytes: default_mask_shape_bucket_floor_bytes(), + mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(), + mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(), + mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(), + mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(), + mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(), + mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(), } } } diff --git a/src/crypto/aes.rs b/src/crypto/aes.rs index deda730..0726298 100644 --- a/src/crypto/aes.rs +++ b/src/crypto/aes.rs @@ -13,10 +13,13 @@ #![allow(dead_code)] -use aes::Aes256; -use ctr::{Ctr128BE, cipher::{KeyIvInit, 
StreamCipher}}; -use zeroize::Zeroize; use crate::error::{ProxyError, Result}; +use aes::Aes256; +use ctr::{ + Ctr128BE, + cipher::{KeyIvInit, StreamCipher}, +}; +use zeroize::Zeroize; type Aes256Ctr = Ctr128BE; @@ -42,33 +45,39 @@ impl AesCtr { cipher: Aes256Ctr::new(key.into(), (&iv_bytes).into()), } } - + /// Create from key and IV slices pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result { if key.len() != 32 { - return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); + return Err(ProxyError::InvalidKeyLength { + expected: 32, + got: key.len(), + }); } if iv.len() != 16 { - return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() }); + return Err(ProxyError::InvalidKeyLength { + expected: 16, + got: iv.len(), + }); } - + let key: [u8; 32] = key.try_into().unwrap(); let iv = u128::from_be_bytes(iv.try_into().unwrap()); Ok(Self::new(&key, iv)) } - + /// Encrypt/decrypt data in-place (CTR mode is symmetric) pub fn apply(&mut self, data: &mut [u8]) { self.cipher.apply_keystream(data); } - + /// Encrypt data, returning new buffer pub fn encrypt(&mut self, data: &[u8]) -> Vec { let mut output = data.to_vec(); self.apply(&mut output); output } - + /// Decrypt data (for CTR, identical to encrypt) pub fn decrypt(&mut self, data: &[u8]) -> Vec { self.encrypt(data) @@ -99,27 +108,33 @@ impl Drop for AesCbc { impl AesCbc { /// AES block size const BLOCK_SIZE: usize = 16; - + /// Create new AES-CBC cipher with key and IV pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self { Self { key, iv } } - + /// Create from slices pub fn from_slices(key: &[u8], iv: &[u8]) -> Result { if key.len() != 32 { - return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); + return Err(ProxyError::InvalidKeyLength { + expected: 32, + got: key.len(), + }); } if iv.len() != 16 { - return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() }); + return Err(ProxyError::InvalidKeyLength { + expected: 16, + got: iv.len(), + }); } - + Ok(Self { 
key: key.try_into().unwrap(), iv: iv.try_into().unwrap(), }) } - + /// Encrypt a single block using raw AES (no chaining) fn encrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] { use aes::cipher::BlockEncrypt; @@ -127,7 +142,7 @@ impl AesCbc { key_schedule.encrypt_block((&mut output).into()); output } - + /// Decrypt a single block using raw AES (no chaining) fn decrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] { use aes::cipher::BlockDecrypt; @@ -135,7 +150,7 @@ impl AesCbc { key_schedule.decrypt_block((&mut output).into()); output } - + /// XOR two 16-byte blocks fn xor_blocks(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] { let mut result = [0u8; 16]; @@ -144,27 +159,28 @@ impl AesCbc { } result } - + /// Encrypt data using CBC mode with proper chaining /// /// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV pub fn encrypt(&self, data: &[u8]) -> Result> { if !data.len().is_multiple_of(Self::BLOCK_SIZE) { - return Err(ProxyError::Crypto( - format!("CBC data must be aligned to 16 bytes, got {}", data.len()) - )); + return Err(ProxyError::Crypto(format!( + "CBC data must be aligned to 16 bytes, got {}", + data.len() + ))); } - + if data.is_empty() { return Ok(Vec::new()); } - + use aes::cipher::KeyInit; let key_schedule = aes::Aes256::new((&self.key).into()); - + let mut result = Vec::with_capacity(data.len()); let mut prev_ciphertext = self.iv; - + for chunk in data.chunks(Self::BLOCK_SIZE) { let plaintext: [u8; 16] = chunk.try_into().unwrap(); let xored = Self::xor_blocks(&plaintext, &prev_ciphertext); @@ -172,30 +188,31 @@ impl AesCbc { prev_ciphertext = ciphertext; result.extend_from_slice(&ciphertext); } - + Ok(result) } - + /// Decrypt data using CBC mode with proper chaining /// /// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV pub fn decrypt(&self, data: &[u8]) -> Result> { if !data.len().is_multiple_of(Self::BLOCK_SIZE) { - return Err(ProxyError::Crypto( - 
format!("CBC data must be aligned to 16 bytes, got {}", data.len()) - )); + return Err(ProxyError::Crypto(format!( + "CBC data must be aligned to 16 bytes, got {}", + data.len() + ))); } - + if data.is_empty() { return Ok(Vec::new()); } - + use aes::cipher::KeyInit; let key_schedule = aes::Aes256::new((&self.key).into()); - + let mut result = Vec::with_capacity(data.len()); let mut prev_ciphertext = self.iv; - + for chunk in data.chunks(Self::BLOCK_SIZE) { let ciphertext: [u8; 16] = chunk.try_into().unwrap(); let decrypted = self.decrypt_block(&ciphertext, &key_schedule); @@ -203,75 +220,77 @@ impl AesCbc { prev_ciphertext = ciphertext; result.extend_from_slice(&plaintext); } - + Ok(result) } - + /// Encrypt data in-place pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> { if !data.len().is_multiple_of(Self::BLOCK_SIZE) { - return Err(ProxyError::Crypto( - format!("CBC data must be aligned to 16 bytes, got {}", data.len()) - )); + return Err(ProxyError::Crypto(format!( + "CBC data must be aligned to 16 bytes, got {}", + data.len() + ))); } - + if data.is_empty() { return Ok(()); } - + use aes::cipher::KeyInit; let key_schedule = aes::Aes256::new((&self.key).into()); - + let mut prev_ciphertext = self.iv; - + for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { let block = &mut data[i..i + Self::BLOCK_SIZE]; - + for j in 0..Self::BLOCK_SIZE { block[j] ^= prev_ciphertext[j]; } - + let block_array: &mut [u8; 16] = block.try_into().unwrap(); *block_array = self.encrypt_block(block_array, &key_schedule); - + prev_ciphertext = *block_array; } - + Ok(()) } - + /// Decrypt data in-place pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> { if !data.len().is_multiple_of(Self::BLOCK_SIZE) { - return Err(ProxyError::Crypto( - format!("CBC data must be aligned to 16 bytes, got {}", data.len()) - )); + return Err(ProxyError::Crypto(format!( + "CBC data must be aligned to 16 bytes, got {}", + data.len() + ))); } - + if data.is_empty() { return Ok(()); } - + 
use aes::cipher::KeyInit; let key_schedule = aes::Aes256::new((&self.key).into()); - + let mut prev_ciphertext = self.iv; - + for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { let block = &mut data[i..i + Self::BLOCK_SIZE]; - + let current_ciphertext: [u8; 16] = block.try_into().unwrap(); - + let block_array: &mut [u8; 16] = block.try_into().unwrap(); *block_array = self.decrypt_block(block_array, &key_schedule); - + for j in 0..Self::BLOCK_SIZE { block[j] ^= prev_ciphertext[j]; } - + prev_ciphertext = current_ciphertext; } - + Ok(()) } } @@ -318,227 +337,227 @@ impl Decryptor for PassthroughEncryptor { #[cfg(test)] mod tests { use super::*; - + // ============= AES-CTR Tests ============= - + #[test] fn test_aes_ctr_roundtrip() { let key = [0u8; 32]; let iv = 12345u128; - + let original = b"Hello, MTProto!"; - + let mut enc = AesCtr::new(&key, iv); let encrypted = enc.encrypt(original); - + let mut dec = AesCtr::new(&key, iv); let decrypted = dec.decrypt(&encrypted); - + assert_eq!(original.as_slice(), decrypted.as_slice()); } - + #[test] fn test_aes_ctr_in_place() { let key = [0x42u8; 32]; let iv = 999u128; - + let original = b"Test data for in-place encryption"; let mut data = original.to_vec(); - + let mut cipher = AesCtr::new(&key, iv); cipher.apply(&mut data); - + assert_ne!(&data[..], original); - + let mut cipher = AesCtr::new(&key, iv); cipher.apply(&mut data); - + assert_eq!(&data[..], original); } - + // ============= AES-CBC Tests ============= - + #[test] fn test_aes_cbc_roundtrip() { let key = [0u8; 32]; let iv = [0u8; 16]; - + let original = [0u8; 32]; - + let cipher = AesCbc::new(key, iv); let encrypted = cipher.encrypt(&original).unwrap(); let decrypted = cipher.decrypt(&encrypted).unwrap(); - + assert_eq!(original.as_slice(), decrypted.as_slice()); } - + #[test] fn test_aes_cbc_chaining_works() { let key = [0x42u8; 32]; let iv = [0x00u8; 16]; - + let plaintext = [0xAAu8; 32]; - + let cipher = AesCbc::new(key, iv); let ciphertext = 
cipher.encrypt(&plaintext).unwrap(); - + let block1 = &ciphertext[0..16]; let block2 = &ciphertext[16..32]; - + assert_ne!( block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext" ); } - + #[test] fn test_aes_cbc_known_vector() { let key = [0u8; 32]; let iv = [0u8; 16]; let plaintext = [0u8; 16]; - + let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); - + let decrypted = cipher.decrypt(&ciphertext).unwrap(); assert_eq!(plaintext.as_slice(), decrypted.as_slice()); - + assert_ne!(ciphertext.as_slice(), plaintext.as_slice()); } - + #[test] fn test_aes_cbc_multi_block() { let key = [0x12u8; 32]; let iv = [0x34u8; 16]; - + let plaintext: Vec = (0..80).collect(); - + let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); let decrypted = cipher.decrypt(&ciphertext).unwrap(); - + assert_eq!(plaintext, decrypted); } - + #[test] fn test_aes_cbc_in_place() { let key = [0x12u8; 32]; let iv = [0x34u8; 16]; - + let original = [0x56u8; 48]; let mut buffer = original; - + let cipher = AesCbc::new(key, iv); - + cipher.encrypt_in_place(&mut buffer).unwrap(); assert_ne!(&buffer[..], &original[..]); - + cipher.decrypt_in_place(&mut buffer).unwrap(); assert_eq!(&buffer[..], &original[..]); } - + #[test] fn test_aes_cbc_empty_data() { let cipher = AesCbc::new([0u8; 32], [0u8; 16]); - + let encrypted = cipher.encrypt(&[]).unwrap(); assert!(encrypted.is_empty()); - + let decrypted = cipher.decrypt(&[]).unwrap(); assert!(decrypted.is_empty()); } - + #[test] fn test_aes_cbc_unaligned_error() { let cipher = AesCbc::new([0u8; 32], [0u8; 16]); - + let result = cipher.encrypt(&[0u8; 15]); assert!(result.is_err()); - + let result = cipher.encrypt(&[0u8; 17]); assert!(result.is_err()); } - + #[test] fn test_aes_cbc_avalanche_effect() { let key = [0xAB; 32]; let iv = [0xCD; 16]; - + let plaintext1 = [0u8; 32]; let mut plaintext2 = [0u8; 32]; plaintext2[0] = 0x01; - + let cipher = 
AesCbc::new(key, iv); - + let ciphertext1 = cipher.encrypt(&plaintext1).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext2).unwrap(); - + assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]); assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]); } - + #[test] fn test_aes_cbc_iv_matters() { let key = [0x55; 32]; let plaintext = [0x77u8; 16]; - + let cipher1 = AesCbc::new(key, [0u8; 16]); let cipher2 = AesCbc::new(key, [1u8; 16]); - + let ciphertext1 = cipher1.encrypt(&plaintext).unwrap(); let ciphertext2 = cipher2.encrypt(&plaintext).unwrap(); - + assert_ne!(ciphertext1, ciphertext2); } - + #[test] fn test_aes_cbc_deterministic() { let key = [0x99; 32]; let iv = [0x88; 16]; let plaintext = [0x77u8; 32]; - + let cipher = AesCbc::new(key, iv); - + let ciphertext1 = cipher.encrypt(&plaintext).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext).unwrap(); - + assert_eq!(ciphertext1, ciphertext2); } - + // ============= Zeroize Tests ============= - + #[test] fn test_aes_cbc_zeroize_on_drop() { let key = [0xAA; 32]; let iv = [0xBB; 16]; - + let cipher = AesCbc::new(key, iv); // Verify key/iv are set assert_eq!(cipher.key, [0xAA; 32]); assert_eq!(cipher.iv, [0xBB; 16]); - + drop(cipher); // After drop, key/iv are zeroized (can't observe directly, // but the Drop impl runs without panic) } - + // ============= Error Handling Tests ============= - + #[test] fn test_invalid_key_length() { let result = AesCtr::from_key_iv(&[0u8; 16], &[0u8; 16]); assert!(result.is_err()); - + let result = AesCbc::from_slices(&[0u8; 16], &[0u8; 16]); assert!(result.is_err()); } - + #[test] fn test_invalid_iv_length() { let result = AesCtr::from_key_iv(&[0u8; 32], &[0u8; 8]); assert!(result.is_err()); - + let result = AesCbc::from_slices(&[0u8; 32], &[0u8; 8]); assert!(result.is_err()); } -} \ No newline at end of file +} diff --git a/src/crypto/hash.rs b/src/crypto/hash.rs index fa3e441..9e1fa16 100644 --- a/src/crypto/hash.rs +++ b/src/crypto/hash.rs @@ -12,10 +12,10 @@ //! 
usages are intentional and protocol-mandated. use hmac::{Hmac, Mac}; -use sha2::Sha256; use md5::Md5; use sha1::Sha1; use sha2::Digest; +use sha2::Sha256; type HmacSha256 = Hmac; @@ -28,8 +28,7 @@ pub fn sha256(data: &[u8]) -> [u8; 32] { /// SHA-256 HMAC pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] { - let mut mac = HmacSha256::new_from_slice(key) - .expect("HMAC accepts any key length"); + let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length"); mac.update(data); mac.finalize().into_bytes().into() } @@ -124,27 +123,18 @@ pub fn derive_middleproxy_keys( srv_ipv6: Option<&[u8; 16]>, ) -> ([u8; 32], [u8; 16]) { let s = build_middleproxy_prekey( - nonce_srv, - nonce_clt, - clt_ts, - srv_ip, - clt_port, - purpose, - clt_ip, - srv_port, - secret, - clt_ipv6, - srv_ipv6, + nonce_srv, nonce_clt, clt_ts, srv_ip, clt_port, purpose, clt_ip, srv_port, secret, + clt_ipv6, srv_ipv6, ); let md5_1 = md5(&s[1..]); let sha1_sum = sha1(&s); let md5_2 = md5(&s[2..]); - + let mut key = [0u8; 32]; key[..12].copy_from_slice(&md5_1[..12]); key[12..].copy_from_slice(&sha1_sum); - + (key, md5_2) } @@ -164,17 +154,8 @@ mod tests { let secret = vec![0x55u8; 128]; let prekey = build_middleproxy_prekey( - &nonce_srv, - &nonce_clt, - &clt_ts, - srv_ip, - &clt_port, - b"CLIENT", - clt_ip, - &srv_port, - &secret, - None, - None, + &nonce_srv, &nonce_clt, &clt_ts, srv_ip, &clt_port, b"CLIENT", clt_ip, &srv_port, + &secret, None, None, ); let digest = sha256(&prekey); assert_eq!( diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 9108f34..cf2dcd2 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -4,7 +4,7 @@ pub mod aes; pub mod hash; pub mod random; -pub use aes::{AesCtr, AesCbc}; +pub use aes::{AesCbc, AesCtr}; pub use hash::{ build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, sha256, sha256_hmac, }; diff --git a/src/crypto/random.rs b/src/crypto/random.rs index a88efc6..760f120 100644 --- a/src/crypto/random.rs +++ 
b/src/crypto/random.rs @@ -3,11 +3,11 @@ #![allow(deprecated)] #![allow(dead_code)] -use rand::{Rng, RngCore, SeedableRng}; -use rand::rngs::StdRng; -use parking_lot::Mutex; -use zeroize::Zeroize; use crate::crypto::AesCtr; +use parking_lot::Mutex; +use rand::rngs::StdRng; +use rand::{Rng, RngExt, SeedableRng}; +use zeroize::Zeroize; /// Cryptographically secure PRNG with AES-CTR pub struct SecureRandom { @@ -34,16 +34,16 @@ impl SecureRandom { pub fn new() -> Self { let mut seed_source = rand::rng(); let mut rng = StdRng::from_rng(&mut seed_source); - + let mut key = [0u8; 32]; rng.fill_bytes(&mut key); let iv: u128 = rng.random(); - + let cipher = AesCtr::new(&key, iv); - + // Zeroize local key copy — cipher already consumed it key.zeroize(); - + Self { inner: Mutex::new(SecureRandomInner { rng, @@ -53,7 +53,7 @@ impl SecureRandom { }), } } - + /// Fill a caller-provided buffer with random bytes. pub fn fill(&self, out: &mut [u8]) { let mut inner = self.inner.lock(); @@ -94,25 +94,25 @@ impl SecureRandom { self.fill(&mut out); out } - + /// Generate random number in range [0, max) pub fn range(&self, max: usize) -> usize { if max == 0 { return 0; } let mut inner = self.inner.lock(); - inner.rng.gen_range(0..max) + inner.rng.random_range(0..max) } - + /// Generate random bits pub fn bits(&self, k: usize) -> u64 { if k == 0 { return 0; } - + let bytes_needed = k.div_ceil(8); let bytes = self.bytes(bytes_needed.min(8)); - + let mut result = 0u64; for (i, &b) in bytes.iter().enumerate() { if i >= 8 { @@ -120,14 +120,14 @@ impl SecureRandom { } result |= (b as u64) << (i * 8); } - + if k < 64 { result &= (1u64 << k) - 1; } - + result } - + /// Choose random element from slice pub fn choose<'a, T>(&self, slice: &'a [T]) -> Option<&'a T> { if slice.is_empty() { @@ -136,22 +136,22 @@ impl SecureRandom { Some(&slice[self.range(slice.len())]) } } - + /// Shuffle slice in place pub fn shuffle(&self, slice: &mut [T]) { let mut inner = self.inner.lock(); for i in 
(1..slice.len()).rev() { - let j = inner.rng.gen_range(0..=i); + let j = inner.rng.random_range(0..=i); slice.swap(i, j); } } - + /// Generate random u32 pub fn u32(&self) -> u32 { let mut inner = self.inner.lock(); inner.rng.random() } - + /// Generate random u64 pub fn u64(&self) -> u64 { let mut inner = self.inner.lock(); @@ -169,7 +169,7 @@ impl Default for SecureRandom { mod tests { use super::*; use std::collections::HashSet; - + #[test] fn test_bytes_uniqueness() { let rng = SecureRandom::new(); @@ -177,7 +177,7 @@ mod tests { let b = rng.bytes(32); assert_ne!(a, b); } - + #[test] fn test_bytes_length() { let rng = SecureRandom::new(); @@ -186,63 +186,63 @@ mod tests { assert_eq!(rng.bytes(100).len(), 100); assert_eq!(rng.bytes(1000).len(), 1000); } - + #[test] fn test_range() { let rng = SecureRandom::new(); - + for _ in 0..1000 { let n = rng.range(10); assert!(n < 10); } - + assert_eq!(rng.range(1), 0); assert_eq!(rng.range(0), 0); } - + #[test] fn test_bits() { let rng = SecureRandom::new(); - + for _ in 0..100 { assert!(rng.bits(1) <= 1); } - + for _ in 0..100 { assert!(rng.bits(8) <= 255); } } - + #[test] fn test_choose() { let rng = SecureRandom::new(); let items = vec![1, 2, 3, 4, 5]; - + let mut seen = HashSet::new(); for _ in 0..1000 { if let Some(&item) = rng.choose(&items) { seen.insert(item); } } - + assert_eq!(seen.len(), 5); - + let empty: Vec = vec![]; assert!(rng.choose(&empty).is_none()); } - + #[test] fn test_shuffle() { let rng = SecureRandom::new(); let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; - + let mut shuffled = original.clone(); rng.shuffle(&mut shuffled); - + let mut sorted = shuffled.clone(); sorted.sort(); assert_eq!(sorted, original); - + assert_ne!(shuffled, original); } } diff --git a/src/error.rs b/src/error.rs index e4d66b9..d9aeb22 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,28 +12,15 @@ use thiserror::Error; #[derive(Debug)] pub enum StreamError { /// Partial read: got fewer bytes than expected - PartialRead 
{ - expected: usize, - got: usize, - }, + PartialRead { expected: usize, got: usize }, /// Partial write: wrote fewer bytes than expected - PartialWrite { - expected: usize, - written: usize, - }, + PartialWrite { expected: usize, written: usize }, /// Stream is in poisoned state and cannot be used - Poisoned { - reason: String, - }, + Poisoned { reason: String }, /// Buffer overflow: attempted to buffer more than allowed - BufferOverflow { - limit: usize, - attempted: usize, - }, + BufferOverflow { limit: usize, attempted: usize }, /// Invalid frame format - InvalidFrame { - details: String, - }, + InvalidFrame { details: String }, /// Unexpected end of stream UnexpectedEof, /// Underlying I/O error @@ -47,13 +34,21 @@ impl fmt::Display for StreamError { write!(f, "partial read: expected {} bytes, got {}", expected, got) } Self::PartialWrite { expected, written } => { - write!(f, "partial write: expected {} bytes, wrote {}", expected, written) + write!( + f, + "partial write: expected {} bytes, wrote {}", + expected, written + ) } Self::Poisoned { reason } => { write!(f, "stream poisoned: {}", reason) } Self::BufferOverflow { limit, attempted } => { - write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted) + write!( + f, + "buffer overflow: limit {}, attempted {}", + limit, attempted + ) } Self::InvalidFrame { details } => { write!(f, "invalid frame: {}", details) @@ -90,9 +85,7 @@ impl From for std::io::Error { StreamError::UnexpectedEof => { std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err) } - StreamError::Poisoned { .. } => { - std::io::Error::other(err) - } + StreamError::Poisoned { .. } => std::io::Error::other(err), StreamError::BufferOverflow { .. 
} => { std::io::Error::new(std::io::ErrorKind::OutOfMemory, err) } @@ -112,7 +105,7 @@ impl From for std::io::Error { pub trait Recoverable { /// Check if error is recoverable (can retry operation) fn is_recoverable(&self) -> bool; - + /// Check if connection can continue after this error fn can_continue(&self) -> bool; } @@ -123,19 +116,22 @@ impl Recoverable for StreamError { Self::PartialRead { .. } | Self::PartialWrite { .. } => true, Self::Io(e) => matches!( e.kind(), - std::io::ErrorKind::WouldBlock - | std::io::ErrorKind::Interrupted - | std::io::ErrorKind::TimedOut + std::io::ErrorKind::WouldBlock + | std::io::ErrorKind::Interrupted + | std::io::ErrorKind::TimedOut ), - Self::Poisoned { .. } + Self::Poisoned { .. } | Self::BufferOverflow { .. } | Self::InvalidFrame { .. } | Self::UnexpectedEof => false, } } - + fn can_continue(&self) -> bool { - !matches!(self, Self::Poisoned { .. } | Self::UnexpectedEof | Self::BufferOverflow { .. }) + !matches!( + self, + Self::Poisoned { .. } | Self::UnexpectedEof | Self::BufferOverflow { .. 
} + ) } } @@ -143,19 +139,19 @@ impl Recoverable for std::io::Error { fn is_recoverable(&self) -> bool { matches!( self.kind(), - std::io::ErrorKind::WouldBlock - | std::io::ErrorKind::Interrupted - | std::io::ErrorKind::TimedOut + std::io::ErrorKind::WouldBlock + | std::io::ErrorKind::Interrupted + | std::io::ErrorKind::TimedOut ) } - + fn can_continue(&self) -> bool { !matches!( self.kind(), std::io::ErrorKind::BrokenPipe - | std::io::ErrorKind::ConnectionReset - | std::io::ErrorKind::ConnectionAborted - | std::io::ErrorKind::NotConnected + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::NotConnected ) } } @@ -165,96 +161,88 @@ impl Recoverable for std::io::Error { #[derive(Error, Debug)] pub enum ProxyError { // ============= Crypto Errors ============= - #[error("Crypto error: {0}")] Crypto(String), - + #[error("Invalid key length: expected {expected}, got {got}")] InvalidKeyLength { expected: usize, got: usize }, - + // ============= Stream Errors ============= - #[error("Stream error: {0}")] Stream(#[from] StreamError), - + // ============= Protocol Errors ============= - #[error("Invalid handshake: {0}")] InvalidHandshake(String), - + #[error("Invalid protocol tag: {0:02x?}")] InvalidProtoTag([u8; 4]), - + #[error("Invalid TLS record: type={record_type}, version={version:02x?}")] InvalidTlsRecord { record_type: u8, version: [u8; 2] }, - + #[error("Replay attack detected from {addr}")] ReplayAttack { addr: SocketAddr }, - + #[error("Time skew detected: client={client_time}, server={server_time}")] TimeSkew { client_time: u32, server_time: u32 }, - + #[error("Invalid message length: {len} (min={min}, max={max})")] InvalidMessageLength { len: usize, min: usize, max: usize }, - + #[error("Checksum mismatch: expected={expected:08x}, got={got:08x}")] ChecksumMismatch { expected: u32, got: u32 }, - + #[error("Sequence number mismatch: expected={expected}, got={got}")] SeqNoMismatch { expected: i32, got: i32 }, 
- + #[error("TLS handshake failed: {reason}")] TlsHandshakeFailed { reason: String }, - + #[error("Telegram handshake timeout")] TgHandshakeTimeout, - + // ============= Network Errors ============= - #[error("Connection timeout to {addr}")] ConnectionTimeout { addr: String }, - + #[error("Connection refused by {addr}")] ConnectionRefused { addr: String }, - + #[error("IO error: {0}")] Io(#[from] std::io::Error), - + // ============= Proxy Protocol Errors ============= - #[error("Invalid proxy protocol header")] InvalidProxyProtocol, - + #[error("Proxy error: {0}")] Proxy(String), - + // ============= Config Errors ============= - #[error("Config error: {0}")] Config(String), - + #[error("Invalid secret for user {user}: {reason}")] InvalidSecret { user: String, reason: String }, - + // ============= User Errors ============= - #[error("User {user} expired")] UserExpired { user: String }, - + #[error("User {user} exceeded connection limit")] ConnectionLimitExceeded { user: String }, - + #[error("User {user} exceeded data quota")] DataQuotaExceeded { user: String }, - + #[error("Unknown user")] UnknownUser, - + #[error("Rate limited")] RateLimited, - + // ============= General Errors ============= - #[error("Internal error: {0}")] Internal(String), } @@ -269,7 +257,7 @@ impl Recoverable for ProxyError { _ => false, } } - + fn can_continue(&self) -> bool { match self { Self::Stream(e) => e.can_continue(), @@ -301,17 +289,19 @@ impl HandshakeResult { pub fn is_success(&self) -> bool { matches!(self, HandshakeResult::Success(_)) } - + /// Check if bad client pub fn is_bad_client(&self) -> bool { matches!(self, HandshakeResult::BadClient { .. 
}) } - + /// Map the success value pub fn map U>(self, f: F) -> HandshakeResult { match self { HandshakeResult::Success(v) => HandshakeResult::Success(f(v)), - HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer }, + HandshakeResult::BadClient { reader, writer } => { + HandshakeResult::BadClient { reader, writer } + } HandshakeResult::Error(e) => HandshakeResult::Error(e), } } @@ -338,76 +328,104 @@ impl From for HandshakeResult { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_stream_error_display() { - let err = StreamError::PartialRead { expected: 100, got: 50 }; + let err = StreamError::PartialRead { + expected: 100, + got: 50, + }; assert!(err.to_string().contains("100")); assert!(err.to_string().contains("50")); - - let err = StreamError::Poisoned { reason: "test".into() }; + + let err = StreamError::Poisoned { + reason: "test".into(), + }; assert!(err.to_string().contains("test")); } - + #[test] fn test_stream_error_recoverable() { - assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable()); - assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable()); + assert!( + StreamError::PartialRead { + expected: 10, + got: 5 + } + .is_recoverable() + ); + assert!( + StreamError::PartialWrite { + expected: 10, + written: 5 + } + .is_recoverable() + ); assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable()); assert!(!StreamError::UnexpectedEof.is_recoverable()); } - + #[test] fn test_stream_error_can_continue() { assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue()); assert!(!StreamError::UnexpectedEof.can_continue()); - assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue()); + assert!( + StreamError::PartialRead { + expected: 10, + got: 5 + } + .can_continue() + ); } - + #[test] fn test_stream_error_to_io_error() { let stream_err = StreamError::UnexpectedEof; let io_err: std::io::Error = stream_err.into(); 
assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof); } - + #[test] fn test_handshake_result() { let success: HandshakeResult = HandshakeResult::Success(42); assert!(success.is_success()); assert!(!success.is_bad_client()); - - let bad: HandshakeResult = HandshakeResult::BadClient { reader: (), writer: () }; + + let bad: HandshakeResult = HandshakeResult::BadClient { + reader: (), + writer: (), + }; assert!(!bad.is_success()); assert!(bad.is_bad_client()); } - + #[test] fn test_handshake_result_map() { let success: HandshakeResult = HandshakeResult::Success(42); let mapped = success.map(|x| x * 2); - + match mapped { HandshakeResult::Success(v) => assert_eq!(v, 84), _ => panic!("Expected success"), } } - + #[test] fn test_proxy_error_recoverable() { let err = ProxyError::RateLimited; assert!(err.is_recoverable()); - + let err = ProxyError::InvalidHandshake("bad".into()); assert!(!err.is_recoverable()); } - + #[test] fn test_error_display() { - let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() }; + let err = ProxyError::ConnectionTimeout { + addr: "1.2.3.4:443".into(), + }; assert!(err.to_string().contains("1.2.3.4:443")); - + let err = ProxyError::InvalidProxyProtocol; assert!(err.to_string().contains("proxy protocol")); } -} \ No newline at end of file +} diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs index fce20b6..76ea424 100644 --- a/src/ip_tracker.rs +++ b/src/ip_tracker.rs @@ -5,10 +5,11 @@ use std::collections::HashMap; use std::net::IpAddr; use std::sync::Arc; +use std::sync::Mutex; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, Instant}; -use tokio::sync::RwLock; +use tokio::sync::{Mutex as AsyncMutex, RwLock}; use crate::config::UserMaxUniqueIpsMode; @@ -21,6 +22,8 @@ pub struct UserIpTracker { limit_mode: Arc>, limit_window: Arc>, last_compact_epoch_secs: Arc, + cleanup_queue: Arc>>, + cleanup_drain_lock: Arc>, } impl UserIpTracker { @@ -33,6 +36,79 @@ impl UserIpTracker { limit_mode: 
Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)), limit_window: Arc::new(RwLock::new(Duration::from_secs(30))), last_compact_epoch_secs: Arc::new(AtomicU64::new(0)), + cleanup_queue: Arc::new(Mutex::new(Vec::new())), + cleanup_drain_lock: Arc::new(AsyncMutex::new(())), + } + } + + pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) { + match self.cleanup_queue.lock() { + Ok(mut queue) => queue.push((user, ip)), + Err(poisoned) => { + let mut queue = poisoned.into_inner(); + queue.push((user.clone(), ip)); + self.cleanup_queue.clear_poison(); + tracing::warn!( + "UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})", + user, + ip + ); + } + } + } + + #[cfg(test)] + pub(crate) fn cleanup_queue_len_for_tests(&self) -> usize { + self.cleanup_queue + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .len() + } + + #[cfg(test)] + pub(crate) fn cleanup_queue_mutex_for_tests(&self) -> Arc>> { + Arc::clone(&self.cleanup_queue) + } + + pub(crate) async fn drain_cleanup_queue(&self) { + // Serialize queue draining and active-IP mutation so check-and-add cannot + // observe stale active entries that are already queued for removal. 
+ let _drain_guard = self.cleanup_drain_lock.lock().await; + let to_remove = { + match self.cleanup_queue.lock() { + Ok(mut queue) => { + if queue.is_empty() { + return; + } + std::mem::take(&mut *queue) + } + Err(poisoned) => { + let mut queue = poisoned.into_inner(); + if queue.is_empty() { + self.cleanup_queue.clear_poison(); + return; + } + let drained = std::mem::take(&mut *queue); + self.cleanup_queue.clear_poison(); + drained + } + } + }; + + let mut active_ips = self.active_ips.write().await; + for (user, ip) in to_remove { + if let Some(user_ips) = active_ips.get_mut(&user) { + if let Some(count) = user_ips.get_mut(&ip) { + if *count > 1 { + *count -= 1; + } else { + user_ips.remove(&ip); + } + } + if user_ips.is_empty() { + active_ips.remove(&user); + } + } } } @@ -65,7 +141,8 @@ impl UserIpTracker { let mut active_ips = self.active_ips.write().await; let mut recent_ips = self.recent_ips.write().await; - let mut users = Vec::::with_capacity(active_ips.len().saturating_add(recent_ips.len())); + let mut users = + Vec::::with_capacity(active_ips.len().saturating_add(recent_ips.len())); users.extend(active_ips.keys().cloned()); for user in recent_ips.keys() { if !active_ips.contains_key(user) { @@ -74,8 +151,14 @@ impl UserIpTracker { } for user in users { - let active_empty = active_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true); - let recent_empty = recent_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true); + let active_empty = active_ips + .get(&user) + .map(|ips| ips.is_empty()) + .unwrap_or(true); + let recent_empty = recent_ips + .get(&user) + .map(|ips| ips.is_empty()) + .unwrap_or(true); if active_empty && recent_empty { active_ips.remove(&user); recent_ips.remove(&user); @@ -118,6 +201,7 @@ impl UserIpTracker { } pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> { + self.drain_cleanup_queue().await; self.maybe_compact_empty_users().await; let default_max_ips = *self.default_max_ips.read().await; let 
limit = { @@ -194,6 +278,7 @@ impl UserIpTracker { } pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap { + self.drain_cleanup_queue().await; let window = *self.limit_window.read().await; let now = Instant::now(); let recent_ips = self.recent_ips.read().await; @@ -214,6 +299,7 @@ impl UserIpTracker { } pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap> { + self.drain_cleanup_queue().await; let active_ips = self.active_ips.read().await; let mut out = HashMap::with_capacity(users.len()); for user in users { @@ -228,6 +314,7 @@ impl UserIpTracker { } pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap> { + self.drain_cleanup_queue().await; let window = *self.limit_window.read().await; let now = Instant::now(); let recent_ips = self.recent_ips.read().await; @@ -250,11 +337,13 @@ impl UserIpTracker { } pub async fn get_active_ip_count(&self, username: &str) -> usize { + self.drain_cleanup_queue().await; let active_ips = self.active_ips.read().await; active_ips.get(username).map(|ips| ips.len()).unwrap_or(0) } pub async fn get_active_ips(&self, username: &str) -> Vec { + self.drain_cleanup_queue().await; let active_ips = self.active_ips.read().await; active_ips .get(username) @@ -263,6 +352,7 @@ impl UserIpTracker { } pub async fn get_stats(&self) -> Vec<(String, usize, usize)> { + self.drain_cleanup_queue().await; let active_ips = self.active_ips.read().await; let max_ips = self.max_ips.read().await; let default_max_ips = *self.default_max_ips.read().await; @@ -301,6 +391,7 @@ impl UserIpTracker { } pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool { + self.drain_cleanup_queue().await; let active_ips = self.active_ips.read().await; active_ips .get(username) diff --git a/src/maestro/connectivity.rs b/src/maestro/connectivity.rs index c843223..0cb561d 100644 --- a/src/maestro/connectivity.rs +++ b/src/maestro/connectivity.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + 
use std::sync::Arc; use std::time::Instant; @@ -11,10 +13,10 @@ use crate::startup::{ COMPONENT_DC_CONNECTIVITY_PING, COMPONENT_ME_CONNECTIVITY_PING, COMPONENT_RUNTIME_READY, StartupTracker, }; +use crate::transport::UpstreamManager; use crate::transport::middle_proxy::{ MePingFamily, MePingSample, MePool, format_me_route, format_sample_line, run_me_ping, }; -use crate::transport::UpstreamManager; pub(crate) async fn run_startup_connectivity( config: &Arc, @@ -47,11 +49,15 @@ pub(crate) async fn run_startup_connectivity( let v4_ok = me_results.iter().any(|r| { matches!(r.family, MePingFamily::V4) - && r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some()) + && r.samples + .iter() + .any(|s| s.error.is_none() && s.handshake_ms.is_some()) }); let v6_ok = me_results.iter().any(|r| { matches!(r.family, MePingFamily::V6) - && r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some()) + && r.samples + .iter() + .any(|s| s.error.is_none() && s.handshake_ms.is_some()) }); info!("================= Telegram ME Connectivity ================="); @@ -131,8 +137,14 @@ pub(crate) async fn run_startup_connectivity( .await; for upstream_result in &ping_results { - let v6_works = upstream_result.v6_results.iter().any(|r| r.rtt_ms.is_some()); - let v4_works = upstream_result.v4_results.iter().any(|r| r.rtt_ms.is_some()); + let v6_works = upstream_result + .v6_results + .iter() + .any(|r| r.rtt_ms.is_some()); + let v4_works = upstream_result + .v4_results + .iter() + .any(|r| r.rtt_ms.is_some()); if upstream_result.both_available { if prefer_ipv6 { diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index f43e308..35f796f 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -1,5 +1,7 @@ -use std::time::Duration; +#![allow(clippy::items_after_test_module)] + use std::path::PathBuf; +use std::time::Duration; use tokio::sync::watch; use tracing::{debug, error, info, warn}; @@ -10,6 +12,19 @@ use crate::transport::middle_proxy::{ 
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, }; +pub(crate) fn resolve_runtime_config_path( + config_path_cli: &str, + startup_cwd: &std::path::Path, +) -> PathBuf { + let raw = PathBuf::from(config_path_cli); + let absolute = if raw.is_absolute() { + raw + } else { + startup_cwd.join(raw) + }; + absolute.canonicalize().unwrap_or(absolute) +} + pub(crate) fn parse_cli() -> (String, Option, bool, Option) { let mut config_path = "config.toml".to_string(); let mut data_path: Option = None; @@ -40,7 +55,9 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { } } s if s.starts_with("--data-path=") => { - data_path = Some(PathBuf::from(s.trim_start_matches("--data-path=").to_string())); + data_path = Some(PathBuf::from( + s.trim_start_matches("--data-path=").to_string(), + )); } "--silent" | "-s" => { silent = true; @@ -58,7 +75,9 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { eprintln!("Usage: telemt [config.toml] [OPTIONS]"); eprintln!(); eprintln!("Options:"); - eprintln!(" --data-path Set data directory (absolute path; overrides config value)"); + eprintln!( + " --data-path Set data directory (absolute path; overrides config value)" + ); eprintln!(" --silent, -s Suppress info logs"); eprintln!(" --log-level debug|verbose|normal|silent"); eprintln!(" --help, -h Show this help"); @@ -96,9 +115,52 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { (config_path, data_path, silent, log_level) } +#[cfg(test)] +mod tests { + use super::resolve_runtime_config_path; + + #[test] + fn resolve_runtime_config_path_anchors_relative_to_startup_cwd() { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_{nonce}")); + std::fs::create_dir_all(&startup_cwd).unwrap(); + let target = startup_cwd.join("config.toml"); + std::fs::write(&target, " ").unwrap(); + + let 
resolved = resolve_runtime_config_path("config.toml", &startup_cwd); + assert_eq!(resolved, target.canonicalize().unwrap()); + + let _ = std::fs::remove_file(&target); + let _ = std::fs::remove_dir(&startup_cwd); + } + + #[test] + fn resolve_runtime_config_path_keeps_absolute_for_missing_file() { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_missing_{nonce}")); + std::fs::create_dir_all(&startup_cwd).unwrap(); + + let resolved = resolve_runtime_config_path("missing.toml", &startup_cwd); + assert_eq!(resolved, startup_cwd.join("missing.toml")); + + let _ = std::fs::remove_dir(&startup_cwd); + } +} + pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { info!(target: "telemt::links", "--- Proxy Links ({}) ---", host); - for user_name in config.general.links.show.resolve_users(&config.access.users) { + for user_name in config + .general + .links + .show + .resolve_users(&config.access.users) + { if let Some(secret) = config.access.users.get(user_name) { info!(target: "telemt::links", "User: {}", user_name); if config.general.modes.classic { @@ -239,7 +301,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot( return Some(cfg); } - warn!(snapshot = label, url, "Startup proxy-config is empty; trying disk cache"); + warn!( + snapshot = label, + url, "Startup proxy-config is empty; trying disk cache" + ); if let Some(path) = cache_path { match load_proxy_config_cache(path).await { Ok(cached) if !cached.map.is_empty() => { @@ -254,8 +319,7 @@ pub(crate) async fn load_startup_proxy_config_snapshot( Ok(_) => { warn!( snapshot = label, - path, - "Startup proxy-config cache is empty; ignoring cache file" + path, "Startup proxy-config cache is empty; ignoring cache file" ); } Err(cache_err) => { @@ -299,8 +363,7 @@ pub(crate) async fn load_startup_proxy_config_snapshot( Ok(_) => { warn!( snapshot = label, - path, 
- "Startup proxy-config cache is empty; ignoring cache file" + path, "Startup proxy-config cache is empty; ignoring cache file" ); } Err(cache_err) => { diff --git a/src/maestro/listeners.rs b/src/maestro/listeners.rs index fe041d9..effaff8 100644 --- a/src/maestro/listeners.rs +++ b/src/maestro/listeners.rs @@ -12,17 +12,15 @@ use tracing::{debug, error, info, warn}; use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; -use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController}; use crate::proxy::ClientHandler; +use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController}; use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker}; use crate::stats::beobachten::BeobachtenStore; use crate::stats::{ReplayChecker, Stats}; use crate::stream::BufferPool; use crate::tls_front::TlsFrontCache; use crate::transport::middle_proxy::MePool; -use crate::transport::{ - ListenOptions, UpstreamManager, create_listener, find_listener_processes, -}; +use crate::transport::{ListenOptions, UpstreamManager, create_listener, find_listener_processes}; use super::helpers::{is_expected_handshake_eof, print_proxy_links}; @@ -81,8 +79,9 @@ pub(crate) async fn bind_listeners( Ok(socket) => { let listener = TcpListener::from_std(socket.into())?; info!("Listening on {}", addr); - let listener_proxy_protocol = - listener_conf.proxy_protocol.unwrap_or(config.server.proxy_protocol); + let listener_proxy_protocol = listener_conf + .proxy_protocol + .unwrap_or(config.server.proxy_protocol); let public_host = if let Some(ref announce) = listener_conf.announce { announce.clone() @@ -100,8 +99,14 @@ pub(crate) async fn bind_listeners( listener_conf.ip.to_string() }; - if config.general.links.public_host.is_none() && !config.general.links.show.is_empty() { - let link_port = config.general.links.public_port.unwrap_or(config.server.port); + if config.general.links.public_host.is_none() + && 
!config.general.links.show.is_empty() + { + let link_port = config + .general + .links + .public_port + .unwrap_or(config.server.port); print_proxy_links(&public_host, link_port, config); } @@ -145,12 +150,14 @@ pub(crate) async fn bind_listeners( let (host, port) = if let Some(ref h) = config.general.links.public_host { ( h.clone(), - config.general.links.public_port.unwrap_or(config.server.port), + config + .general + .links + .public_port + .unwrap_or(config.server.port), ) } else { - let ip = detected_ip_v4 - .or(detected_ip_v6) - .map(|ip| ip.to_string()); + let ip = detected_ip_v4.or(detected_ip_v6).map(|ip| ip.to_string()); if ip.is_none() { warn!( "show_link is configured but public IP could not be detected. Set public_host in config." @@ -158,7 +165,11 @@ pub(crate) async fn bind_listeners( } ( ip.unwrap_or_else(|| "UNKNOWN".to_string()), - config.general.links.public_port.unwrap_or(config.server.port), + config + .general + .links + .public_port + .unwrap_or(config.server.port), ) }; @@ -178,13 +189,19 @@ pub(crate) async fn bind_listeners( use std::os::unix::fs::PermissionsExt; let perms = std::fs::Permissions::from_mode(mode); if let Err(e) = std::fs::set_permissions(unix_path, perms) { - error!("Failed to set unix socket permissions to {}: {}", perm_str, e); + error!( + "Failed to set unix socket permissions to {}: {}", + perm_str, e + ); } else { info!("Listening on unix:{} (mode {})", unix_path, perm_str); } } Err(e) => { - warn!("Invalid listen_unix_sock_perm '{}': {}. Ignoring.", perm_str, e); + warn!( + "Invalid listen_unix_sock_perm '{}': {}. 
Ignoring.", + perm_str, e + ); info!("Listening on unix:{}", unix_path); } } @@ -218,10 +235,8 @@ pub(crate) async fn bind_listeners( drop(stream); continue; } - let accept_permit_timeout_ms = config_rx_unix - .borrow() - .server - .accept_permit_timeout_ms; + let accept_permit_timeout_ms = + config_rx_unix.borrow().server.accept_permit_timeout_ms; let permit = if accept_permit_timeout_ms == 0 { match max_connections_unix.clone().acquire_owned().await { Ok(permit) => permit, @@ -361,10 +376,8 @@ pub(crate) fn spawn_tcp_accept_loops( drop(stream); continue; } - let accept_permit_timeout_ms = config_rx - .borrow() - .server - .accept_permit_timeout_ms; + let accept_permit_timeout_ms = + config_rx.borrow().server.accept_permit_timeout_ms; let permit = if accept_permit_timeout_ms == 0 { match max_connections_tcp.clone().acquire_owned().await { Ok(permit) => permit, diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index eb45cc4..022f8ae 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use std::sync::Arc; use std::time::Duration; @@ -12,8 +14,8 @@ use crate::startup::{ COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH, StartupMeStatus, StartupTracker, }; use crate::stats::Stats; -use crate::transport::middle_proxy::MePool; use crate::transport::UpstreamManager; +use crate::transport::middle_proxy::MePool; use super::helpers::load_startup_proxy_config_snapshot; @@ -229,8 +231,12 @@ pub(crate) async fn initialize_me_pool( config.general.me_adaptive_floor_recover_grace_secs, config.general.me_adaptive_floor_writers_per_core_total, config.general.me_adaptive_floor_cpu_cores_override, - config.general.me_adaptive_floor_max_extra_writers_single_per_core, - config.general.me_adaptive_floor_max_extra_writers_multi_per_core, + config + .general + .me_adaptive_floor_max_extra_writers_single_per_core, + config + .general + .me_adaptive_floor_max_extra_writers_multi_per_core, 
config.general.me_adaptive_floor_max_active_writers_per_core, config.general.me_adaptive_floor_max_warm_writers_per_core, config.general.me_adaptive_floor_max_active_writers_global, @@ -268,8 +274,6 @@ pub(crate) async fn initialize_me_pool( config.general.me_warn_rate_limit_ms, config.general.me_route_no_writer_mode, config.general.me_route_no_writer_wait_ms, - config.general.me_route_hybrid_max_wait_ms, - config.general.me_route_blocking_send_timeout_ms, config.general.me_route_inline_recovery_attempts, config.general.me_route_inline_recovery_wait_ms, ); @@ -459,65 +463,71 @@ pub(crate) async fn initialize_me_pool( "Middle-End pool initialized successfully" ); - // ── Supervised background tasks ────────────────── - let pool_clone = pool.clone(); - let rng_clone = rng.clone(); - let min_conns = pool_size; - tokio::spawn(async move { - loop { - let p = pool_clone.clone(); - let r = rng_clone.clone(); - let res = tokio::spawn(async move { - crate::transport::middle_proxy::me_health_monitor( - p, r, min_conns, - ) - .await; - }) + // ── Supervised background tasks ────────────────── + let pool_clone = pool.clone(); + let rng_clone = rng.clone(); + let min_conns = pool_size; + tokio::spawn(async move { + loop { + let p = pool_clone.clone(); + let r = rng_clone.clone(); + let res = tokio::spawn(async move { + crate::transport::middle_proxy::me_health_monitor( + p, r, min_conns, + ) .await; - match res { - Ok(()) => warn!("me_health_monitor exited unexpectedly, restarting"), - Err(e) => { - error!(error = %e, "me_health_monitor panicked, restarting in 1s"); - tokio::time::sleep(Duration::from_secs(1)).await; - } + }) + .await; + match res { + Ok(()) => warn!( + "me_health_monitor exited unexpectedly, restarting" + ), + Err(e) => { + error!(error = %e, "me_health_monitor panicked, restarting in 1s"); + tokio::time::sleep(Duration::from_secs(1)).await; } } - }); - let pool_drain_enforcer = pool.clone(); - tokio::spawn(async move { - loop { - let p = 
pool_drain_enforcer.clone(); - let res = tokio::spawn(async move { + } + }); + let pool_drain_enforcer = pool.clone(); + tokio::spawn(async move { + loop { + let p = pool_drain_enforcer.clone(); + let res = tokio::spawn(async move { crate::transport::middle_proxy::me_drain_timeout_enforcer(p).await; }) .await; - match res { - Ok(()) => warn!("me_drain_timeout_enforcer exited unexpectedly, restarting"), - Err(e) => { - error!(error = %e, "me_drain_timeout_enforcer panicked, restarting in 1s"); - tokio::time::sleep(Duration::from_secs(1)).await; - } + match res { + Ok(()) => warn!( + "me_drain_timeout_enforcer exited unexpectedly, restarting" + ), + Err(e) => { + error!(error = %e, "me_drain_timeout_enforcer panicked, restarting in 1s"); + tokio::time::sleep(Duration::from_secs(1)).await; } } - }); - let pool_watchdog = pool.clone(); - tokio::spawn(async move { - loop { - let p = pool_watchdog.clone(); - let res = tokio::spawn(async move { + } + }); + let pool_watchdog = pool.clone(); + tokio::spawn(async move { + loop { + let p = pool_watchdog.clone(); + let res = tokio::spawn(async move { crate::transport::middle_proxy::me_zombie_writer_watchdog(p).await; }) .await; - match res { - Ok(()) => warn!("me_zombie_writer_watchdog exited unexpectedly, restarting"), - Err(e) => { - error!(error = %e, "me_zombie_writer_watchdog panicked, restarting in 1s"); - tokio::time::sleep(Duration::from_secs(1)).await; - } + match res { + Ok(()) => warn!( + "me_zombie_writer_watchdog exited unexpectedly, restarting" + ), + Err(e) => { + error!(error = %e, "me_zombie_writer_watchdog panicked, restarting in 1s"); + tokio::time::sleep(Duration::from_secs(1)).await; } } - }); - + } + }); + break Some(pool); } Err(e) => { diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index dce421c..7d3b168 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -11,9 +11,9 @@ // - admission: conditional-cast gate and route mode switching. 
// - listeners: TCP/Unix listener bind and accept-loop orchestration. // - shutdown: graceful shutdown sequence and uptime logging. -mod helpers; mod admission; mod connectivity; +mod helpers; mod listeners; mod me_startup; mod runtime_tasks; @@ -33,19 +33,19 @@ use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe}; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; +use crate::startup::{ + COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD, COMPONENT_ME_POOL_CONSTRUCT, + COMPONENT_ME_POOL_INIT_STAGE1, COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6, + COMPONENT_ME_SECRET_FETCH, COMPONENT_NETWORK_PROBE, COMPONENT_TRACING_INIT, StartupMeStatus, + StartupTracker, +}; use crate::stats::beobachten::BeobachtenStore; use crate::stats::telemetry::TelemetryPolicy; use crate::stats::{ReplayChecker, Stats}; -use crate::startup::{ - COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD, - COMPONENT_ME_POOL_CONSTRUCT, COMPONENT_ME_POOL_INIT_STAGE1, - COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH, - COMPONENT_NETWORK_PROBE, COMPONENT_TRACING_INIT, StartupMeStatus, StartupTracker, -}; use crate::stream::BufferPool; -use crate::transport::middle_proxy::MePool; use crate::transport::UpstreamManager; -use helpers::parse_cli; +use crate::transport::middle_proxy::MePool; +use helpers::{parse_cli, resolve_runtime_config_path}; /// Runs the full telemt runtime startup pipeline and blocks until shutdown. 
pub async fn run() -> std::result::Result<(), Box> { @@ -56,20 +56,34 @@ pub async fn run() -> std::result::Result<(), Box> { .as_secs(); let startup_tracker = Arc::new(StartupTracker::new(process_started_at_epoch_secs)); startup_tracker - .start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string())) + .start_component( + COMPONENT_CONFIG_LOAD, + Some("load and validate config".to_string()), + ) .await; - let (config_path, data_path, cli_silent, cli_log_level) = parse_cli(); + let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli(); + let startup_cwd = match std::env::current_dir() { + Ok(cwd) => cwd, + Err(e) => { + eprintln!("[telemt] Can't read current_dir: {}", e); + std::process::exit(1); + } + }; + let config_path = resolve_runtime_config_path(&config_path_cli, &startup_cwd); let mut config = match ProxyConfig::load(&config_path) { Ok(c) => c, Err(e) => { - if std::path::Path::new(&config_path).exists() { + if config_path.exists() { eprintln!("[telemt] Error: {}", e); std::process::exit(1); } else { let default = ProxyConfig::default(); std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap(); - eprintln!("[telemt] Created default config at {}", config_path); + eprintln!( + "[telemt] Created default config at {}", + config_path.display() + ); default } } @@ -86,24 +100,38 @@ pub async fn run() -> std::result::Result<(), Box> { if let Some(ref data_path) = config.general.data_path { if !data_path.is_absolute() { - eprintln!("[telemt] data_path must be absolute: {}", data_path.display()); + eprintln!( + "[telemt] data_path must be absolute: {}", + data_path.display() + ); std::process::exit(1); } if data_path.exists() { if !data_path.is_dir() { - eprintln!("[telemt] data_path exists but is not a directory: {}", data_path.display()); + eprintln!( + "[telemt] data_path exists but is not a directory: {}", + data_path.display() + ); std::process::exit(1); } } else { if let Err(e) = 
std::fs::create_dir_all(data_path) { - eprintln!("[telemt] Can't create data_path {}: {}", data_path.display(), e); + eprintln!( + "[telemt] Can't create data_path {}: {}", + data_path.display(), + e + ); std::process::exit(1); } } if let Err(e) = std::env::set_current_dir(data_path) { - eprintln!("[telemt] Can't use data_path {}: {}", data_path.display(), e); + eprintln!( + "[telemt] Can't use data_path {}: {}", + data_path.display(), + e + ); std::process::exit(1); } } @@ -127,7 +155,10 @@ pub async fn run() -> std::result::Result<(), Box> { let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info")); startup_tracker - .start_component(COMPONENT_TRACING_INIT, Some("initialize tracing subscriber".to_string())) + .start_component( + COMPONENT_TRACING_INIT, + Some("initialize tracing subscriber".to_string()), + ) .await; // Configure color output based on config @@ -142,7 +173,10 @@ pub async fn run() -> std::result::Result<(), Box> { .with(fmt_layer) .init(); startup_tracker - .complete_component(COMPONENT_TRACING_INIT, Some("tracing initialized".to_string())) + .complete_component( + COMPONENT_TRACING_INIT, + Some("tracing initialized".to_string()), + ) .await; info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION")); @@ -208,7 +242,8 @@ pub async fn run() -> std::result::Result<(), Box> { config.access.user_max_unique_ips_window_secs, ) .await; - if config.access.user_max_unique_ips_global_each > 0 || !config.access.user_max_unique_ips.is_empty() + if config.access.user_max_unique_ips_global_each > 0 + || !config.access.user_max_unique_ips.is_empty() { info!( global_each_limit = config.access.user_max_unique_ips_global_each, @@ -235,7 +270,10 @@ pub async fn run() -> std::result::Result<(), Box> { let route_runtime = Arc::new(RouteRuntimeController::new(initial_route_mode)); let api_me_pool = Arc::new(RwLock::new(None::>)); startup_tracker - .start_component(COMPONENT_API_BOOTSTRAP, Some("spawn API listener task".to_string())) + .start_component( 
+ COMPONENT_API_BOOTSTRAP, + Some("spawn API listener task".to_string()), + ) .await; if config.server.api.enabled { @@ -258,7 +296,7 @@ pub async fn run() -> std::result::Result<(), Box> { let route_runtime_api = route_runtime.clone(); let config_rx_api = api_config_rx.clone(); let admission_rx_api = admission_rx.clone(); - let config_path_api = std::path::PathBuf::from(&config_path); + let config_path_api = config_path.clone(); let startup_tracker_api = startup_tracker.clone(); let detected_ips_rx_api = detected_ips_rx.clone(); tokio::spawn(async move { @@ -318,7 +356,10 @@ pub async fn run() -> std::result::Result<(), Box> { .await; startup_tracker - .start_component(COMPONENT_NETWORK_PROBE, Some("probe network capabilities".to_string())) + .start_component( + COMPONENT_NETWORK_PROBE, + Some("probe network capabilities".to_string()), + ) .await; let probe = run_probe( &config.network, @@ -331,11 +372,8 @@ pub async fn run() -> std::result::Result<(), Box> { probe.detected_ipv4.map(IpAddr::V4), probe.detected_ipv6.map(IpAddr::V6), )); - let decision = decide_network_capabilities( - &config.network, - &probe, - config.general.middle_proxy_nat_ip, - ); + let decision = + decide_network_capabilities(&config.network, &probe, config.general.middle_proxy_nat_ip); log_probe_result(&probe, &decision); startup_tracker .complete_component( @@ -438,24 +476,16 @@ pub async fn run() -> std::result::Result<(), Box> { // If ME failed to initialize, force direct-only mode. if me_pool.is_some() { - startup_tracker - .set_transport_mode("middle_proxy") - .await; - startup_tracker - .set_degraded(false) - .await; + startup_tracker.set_transport_mode("middle_proxy").await; + startup_tracker.set_degraded(false).await; info!("Transport: Middle-End Proxy - all DC-over-RPC"); } else { let _ = use_middle_proxy; use_middle_proxy = false; // Make runtime config reflect direct-only mode for handlers. 
config.general.use_middle_proxy = false; - startup_tracker - .set_transport_mode("direct") - .await; - startup_tracker - .set_degraded(true) - .await; + startup_tracker.set_transport_mode("direct").await; + startup_tracker.set_degraded(true).await; if me2dc_fallback { startup_tracker .set_me_status(StartupMeStatus::Failed, "fallback_to_direct") diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index d9691a8..d553eb9 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -1,24 +1,27 @@ use std::net::IpAddr; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use tokio::sync::{mpsc, watch}; use tracing::{debug, warn}; -use tracing_subscriber::reload; use tracing_subscriber::EnvFilter; +use tracing_subscriber::reload; -use crate::config::{LogLevel, ProxyConfig}; use crate::config::hot_reload::spawn_config_watcher; +use crate::config::{LogLevel, ProxyConfig}; use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; use crate::metrics; use crate::network::probe::NetworkProbe; -use crate::startup::{COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY, StartupTracker}; +use crate::startup::{ + COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY, + StartupTracker, +}; use crate::stats::beobachten::BeobachtenStore; use crate::stats::telemetry::TelemetryPolicy; use crate::stats::{ReplayChecker, Stats}; -use crate::transport::middle_proxy::{MePool, MeReinitTrigger}; use crate::transport::UpstreamManager; +use crate::transport::middle_proxy::{MePool, MeReinitTrigger}; use super::helpers::write_beobachten_snapshot; @@ -32,7 +35,7 @@ pub(crate) struct RuntimeWatches { #[allow(clippy::too_many_arguments)] pub(crate) async fn spawn_runtime_tasks( config: &Arc, - config_path: &str, + config_path: &Path, probe: &NetworkProbe, prefer_ipv6: bool, decision_ipv4_dc: bool, @@ -79,15 +82,13 @@ pub(crate) async fn spawn_runtime_tasks( Some("spawn config 
hot-reload watcher".to_string()), ) .await; - let (config_rx, log_level_rx): ( - watch::Receiver>, - watch::Receiver, - ) = spawn_config_watcher( - PathBuf::from(config_path), - config.clone(), - detected_ip_v4, - detected_ip_v6, - ); + let (config_rx, log_level_rx): (watch::Receiver>, watch::Receiver) = + spawn_config_watcher( + config_path.to_path_buf(), + config.clone(), + detected_ip_v4, + detected_ip_v6, + ); startup_tracker .complete_component( COMPONENT_CONFIG_WATCHER_START, @@ -114,7 +115,8 @@ pub(crate) async fn spawn_runtime_tasks( break; } let cfg = config_rx_policy.borrow_and_update().clone(); - stats_policy.apply_telemetry_policy(TelemetryPolicy::from_config(&cfg.general.telemetry)); + stats_policy + .apply_telemetry_policy(TelemetryPolicy::from_config(&cfg.general.telemetry)); if let Some(pool) = &me_pool_for_policy { pool.update_runtime_transport_policy( cfg.general.me_socks_kdf_policy, @@ -130,7 +132,11 @@ pub(crate) async fn spawn_runtime_tasks( let ip_tracker_policy = ip_tracker.clone(); let mut config_rx_ip_limits = config_rx.clone(); tokio::spawn(async move { - let mut prev_limits = config_rx_ip_limits.borrow().access.user_max_unique_ips.clone(); + let mut prev_limits = config_rx_ip_limits + .borrow() + .access + .user_max_unique_ips + .clone(); let mut prev_global_each = config_rx_ip_limits .borrow() .access @@ -183,7 +189,9 @@ pub(crate) async fn spawn_runtime_tasks( let sleep_secs = cfg.general.beobachten_flush_secs.max(1); if cfg.general.beobachten { - let ttl = std::time::Duration::from_secs(cfg.general.beobachten_minutes.saturating_mul(60)); + let ttl = std::time::Duration::from_secs( + cfg.general.beobachten_minutes.saturating_mul(60), + ); let path = cfg.general.beobachten_file.clone(); let snapshot = beobachten_writer.snapshot_text(ttl); if let Err(e) = write_beobachten_snapshot(&path, &snapshot).await { @@ -227,8 +235,11 @@ pub(crate) async fn spawn_runtime_tasks( let config_rx_clone_rot = config_rx.clone(); let reinit_tx_rotation = 
reinit_tx.clone(); tokio::spawn(async move { - crate::transport::middle_proxy::me_rotation_task(config_rx_clone_rot, reinit_tx_rotation) - .await; + crate::transport::middle_proxy::me_rotation_task( + config_rx_clone_rot, + reinit_tx_rotation, + ) + .await; }); } diff --git a/src/maestro/shutdown.rs b/src/maestro/shutdown.rs index b73df30..243c772 100644 --- a/src/maestro/shutdown.rs +++ b/src/maestro/shutdown.rs @@ -16,8 +16,11 @@ pub(crate) async fn wait_for_shutdown(process_started_at: Instant, me_pool: Opti let uptime_secs = process_started_at.elapsed().as_secs(); info!("Uptime: {}", format_uptime(uptime_secs)); if let Some(pool) = &me_pool { - match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all()) - .await + match tokio::time::timeout( + Duration::from_secs(2), + pool.shutdown_send_close_conn_all(), + ) + .await { Ok(total) => { info!( diff --git a/src/maestro/tls_bootstrap.rs b/src/maestro/tls_bootstrap.rs index 73eec4c..342a2f9 100644 --- a/src/maestro/tls_bootstrap.rs +++ b/src/maestro/tls_bootstrap.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use std::time::Duration; -use rand::Rng; +use rand::RngExt; use tracing::warn; use crate::config::ProxyConfig; diff --git a/src/main.rs b/src/main.rs index 2cfbe28..c512e6b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,13 @@ mod crypto; mod error; mod ip_tracker; #[cfg(test)] +#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] +mod ip_tracker_hotpath_adversarial_tests; +#[cfg(test)] +#[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"] +mod ip_tracker_encapsulation_adversarial_tests; +#[cfg(test)] +#[path = "tests/ip_tracker_regression_tests.rs"] mod ip_tracker_regression_tests; mod maestro; mod metrics; diff --git a/src/metrics.rs b/src/metrics.rs index b7272b2..a821d4d 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1,5 +1,5 @@ -use std::convert::Infallible; use std::collections::{BTreeSet, HashMap}; +use std::convert::Infallible; use std::net::SocketAddr; 
use std::sync::Arc; use std::time::Duration; @@ -11,14 +11,12 @@ use hyper::service::service_fn; use hyper::{Request, Response, StatusCode}; use ipnetwork::IpNetwork; use tokio::net::TcpListener; -use tracing::{info, warn, debug}; +use tracing::{debug, info, warn}; use crate::config::ProxyConfig; use crate::ip_tracker::UserIpTracker; +use crate::stats::Stats; use crate::stats::beobachten::BeobachtenStore; -use crate::stats::{ - MeWriterCleanupSideEffectStep, MeWriterTeardownMode, MeWriterTeardownReason, Stats, -}; use crate::transport::{ListenOptions, create_listener}; pub async fn serve( @@ -64,7 +62,10 @@ pub async fn serve( let addr_v4 = SocketAddr::from(([0, 0, 0, 0], port)); match bind_metrics_listener(addr_v4, false) { Ok(listener) => { - info!("Metrics endpoint: http://{}/metrics and /beobachten", addr_v4); + info!( + "Metrics endpoint: http://{}/metrics and /beobachten", + addr_v4 + ); listener_v4 = Some(listener); } Err(e) => { @@ -75,7 +76,10 @@ pub async fn serve( let addr_v6 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], port)); match bind_metrics_listener(addr_v6, true) { Ok(listener) => { - info!("Metrics endpoint: http://[::]:{}/metrics and /beobachten", port); + info!( + "Metrics endpoint: http://[::]:{}/metrics and /beobachten", + port + ); listener_v6 = Some(listener); } Err(e) => { @@ -111,12 +115,7 @@ pub async fn serve( .await; }); serve_listener( - listener4, - stats, - beobachten, - ip_tracker, - config_rx, - whitelist, + listener4, stats, beobachten, ip_tracker, config_rx, whitelist, ) .await; } @@ -233,7 +232,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp let _ = writeln!(out, "# TYPE telemt_uptime_seconds gauge"); let _ = writeln!(out, "telemt_uptime_seconds {:.1}", stats.uptime_secs()); - let _ = writeln!(out, "# HELP telemt_telemetry_core_enabled Runtime core telemetry switch"); + let _ = writeln!( + out, + "# HELP telemt_telemetry_core_enabled Runtime core telemetry switch" + ); let _ = 
writeln!(out, "# TYPE telemt_telemetry_core_enabled gauge"); let _ = writeln!( out, @@ -241,7 +243,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp if core_enabled { 1 } else { 0 } ); - let _ = writeln!(out, "# HELP telemt_telemetry_user_enabled Runtime per-user telemetry switch"); + let _ = writeln!( + out, + "# HELP telemt_telemetry_user_enabled Runtime per-user telemetry switch" + ); let _ = writeln!(out, "# TYPE telemt_telemetry_user_enabled gauge"); let _ = writeln!( out, @@ -249,7 +254,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp if user_enabled { 1 } else { 0 } ); - let _ = writeln!(out, "# HELP telemt_telemetry_me_level Runtime ME telemetry level flag"); + let _ = writeln!( + out, + "# HELP telemt_telemetry_me_level Runtime ME telemetry level flag" + ); let _ = writeln!(out, "# TYPE telemt_telemetry_me_level gauge"); let _ = writeln!( out, @@ -279,126 +287,40 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_connections_total Total accepted connections"); + let _ = writeln!( + out, + "# HELP telemt_connections_total Total accepted connections" + ); let _ = writeln!(out, "# TYPE telemt_connections_total counter"); let _ = writeln!( out, "telemt_connections_total {}", - if core_enabled { stats.get_connects_all() } else { 0 } + if core_enabled { + stats.get_connects_all() + } else { + 0 + } ); - let _ = writeln!(out, "# HELP telemt_connections_bad_total Bad/rejected connections"); + let _ = writeln!( + out, + "# HELP telemt_connections_bad_total Bad/rejected connections" + ); let _ = writeln!(out, "# TYPE telemt_connections_bad_total counter"); let _ = writeln!( out, "telemt_connections_bad_total {}", - if core_enabled { stats.get_connects_bad() } else { 0 } - ); - let _ = writeln!(out, "# HELP telemt_connections_current Current active connections"); - let _ = writeln!(out, "# TYPE 
telemt_connections_current gauge"); - let _ = writeln!( - out, - "telemt_connections_current {}", if core_enabled { - stats.get_current_connections_total() - } else { - 0 - } - ); - let _ = writeln!(out, "# HELP telemt_connections_direct_current Current active direct connections"); - let _ = writeln!(out, "# TYPE telemt_connections_direct_current gauge"); - let _ = writeln!( - out, - "telemt_connections_direct_current {}", - if core_enabled { - stats.get_current_connections_direct() - } else { - 0 - } - ); - let _ = writeln!(out, "# HELP telemt_connections_me_current Current active middle-end connections"); - let _ = writeln!(out, "# TYPE telemt_connections_me_current gauge"); - let _ = writeln!( - out, - "telemt_connections_me_current {}", - if core_enabled { - stats.get_current_connections_me() - } else { - 0 - } - ); - let _ = writeln!( - out, - "# HELP telemt_relay_adaptive_promotions_total Adaptive relay tier promotions" - ); - let _ = writeln!(out, "# TYPE telemt_relay_adaptive_promotions_total counter"); - let _ = writeln!( - out, - "telemt_relay_adaptive_promotions_total {}", - if core_enabled { - stats.get_relay_adaptive_promotions_total() - } else { - 0 - } - ); - let _ = writeln!( - out, - "# HELP telemt_relay_adaptive_demotions_total Adaptive relay tier demotions" - ); - let _ = writeln!(out, "# TYPE telemt_relay_adaptive_demotions_total counter"); - let _ = writeln!( - out, - "telemt_relay_adaptive_demotions_total {}", - if core_enabled { - stats.get_relay_adaptive_demotions_total() - } else { - 0 - } - ); - let _ = writeln!( - out, - "# HELP telemt_relay_adaptive_hard_promotions_total Adaptive relay hard promotions triggered by write pressure" - ); - let _ = writeln!( - out, - "# TYPE telemt_relay_adaptive_hard_promotions_total counter" - ); - let _ = writeln!( - out, - "telemt_relay_adaptive_hard_promotions_total {}", - if core_enabled { - stats.get_relay_adaptive_hard_promotions_total() - } else { - 0 - } - ); - let _ = writeln!(out, "# HELP 
telemt_reconnect_evict_total Reconnect-driven session evictions"); - let _ = writeln!(out, "# TYPE telemt_reconnect_evict_total counter"); - let _ = writeln!( - out, - "telemt_reconnect_evict_total {}", - if core_enabled { - stats.get_reconnect_evict_total() - } else { - 0 - } - ); - let _ = writeln!( - out, - "# HELP telemt_reconnect_stale_close_total Sessions closed because they became stale after reconnect" - ); - let _ = writeln!(out, "# TYPE telemt_reconnect_stale_close_total counter"); - let _ = writeln!( - out, - "telemt_reconnect_stale_close_total {}", - if core_enabled { - stats.get_reconnect_stale_close_total() + stats.get_connects_bad() } else { 0 } ); - let _ = writeln!(out, "# HELP telemt_handshake_timeouts_total Handshake timeouts"); + let _ = writeln!( + out, + "# HELP telemt_handshake_timeouts_total Handshake timeouts" + ); let _ = writeln!(out, "# TYPE telemt_handshake_timeouts_total counter"); let _ = writeln!( out, @@ -477,7 +399,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_upstream_connect_attempts_per_request Histogram-like buckets for attempts per upstream connect request cycle" ); - let _ = writeln!(out, "# TYPE telemt_upstream_connect_attempts_per_request counter"); + let _ = writeln!( + out, + "# TYPE telemt_upstream_connect_attempts_per_request counter" + ); let _ = writeln!( out, "telemt_upstream_connect_attempts_per_request{{bucket=\"1\"}} {}", @@ -519,7 +444,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_upstream_connect_duration_success_total Histogram-like buckets of successful upstream connect cycle duration" ); - let _ = writeln!(out, "# TYPE telemt_upstream_connect_duration_success_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_upstream_connect_duration_success_total counter" + ); let _ = writeln!( out, "telemt_upstream_connect_duration_success_total{{bucket=\"le_100ms\"}} {}", @@ -561,7 +489,10 @@ 
async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_upstream_connect_duration_fail_total Histogram-like buckets of failed upstream connect cycle duration" ); - let _ = writeln!(out, "# TYPE telemt_upstream_connect_duration_fail_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_upstream_connect_duration_fail_total counter" + ); let _ = writeln!( out, "telemt_upstream_connect_duration_fail_total{{bucket=\"le_100ms\"}} {}", @@ -599,7 +530,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_keepalive_sent_total ME keepalive frames sent"); + let _ = writeln!( + out, + "# HELP telemt_me_keepalive_sent_total ME keepalive frames sent" + ); let _ = writeln!(out, "# TYPE telemt_me_keepalive_sent_total counter"); let _ = writeln!( out, @@ -611,7 +545,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_keepalive_failed_total ME keepalive send failures"); + let _ = writeln!( + out, + "# HELP telemt_me_keepalive_failed_total ME keepalive send failures" + ); let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter"); let _ = writeln!( out, @@ -623,7 +560,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies"); + let _ = writeln!( + out, + "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies" + ); let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter"); let _ = writeln!( out, @@ -635,7 +575,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts"); + let _ = writeln!( + out, + "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts" + ); let _ = 
writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter"); let _ = writeln!( out, @@ -651,7 +594,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_rpc_proxy_req_signal_sent_total Service RPC_PROXY_REQ activity signals sent" ); - let _ = writeln!(out, "# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter" + ); let _ = writeln!( out, "telemt_me_rpc_proxy_req_signal_sent_total {}", @@ -734,7 +680,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts"); + let _ = writeln!( + out, + "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts" + ); let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter"); let _ = writeln!( out, @@ -746,7 +695,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_reconnect_success_total ME reconnect successes"); + let _ = writeln!( + out, + "# HELP telemt_me_reconnect_success_total ME reconnect successes" + ); let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter"); let _ = writeln!( out, @@ -758,7 +710,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_handshake_reject_total ME handshake rejects from upstream"); + let _ = writeln!( + out, + "# HELP telemt_me_handshake_reject_total ME handshake rejects from upstream" + ); let _ = writeln!(out, "# TYPE telemt_me_handshake_reject_total counter"); let _ = writeln!( out, @@ -770,20 +725,25 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_handshake_error_code_total ME handshake reject errors by code"); + let _ = writeln!( + out, 
+ "# HELP telemt_me_handshake_error_code_total ME handshake reject errors by code" + ); let _ = writeln!(out, "# TYPE telemt_me_handshake_error_code_total counter"); if me_allows_normal { for (error_code, count) in stats.get_me_handshake_error_code_counts() { let _ = writeln!( out, "telemt_me_handshake_error_code_total{{error_code=\"{}\"}} {}", - error_code, - count + error_code, count ); } } - let _ = writeln!(out, "# HELP telemt_me_reader_eof_total ME reader EOF terminations"); + let _ = writeln!( + out, + "# HELP telemt_me_reader_eof_total ME reader EOF terminations" + ); let _ = writeln!(out, "# TYPE telemt_me_reader_eof_total counter"); let _ = writeln!( out, @@ -810,6 +770,69 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); + let _ = writeln!( + out, + "# HELP telemt_relay_idle_soft_mark_total Middle-relay sessions marked as soft-idle candidates" + ); + let _ = writeln!(out, "# TYPE telemt_relay_idle_soft_mark_total counter"); + let _ = writeln!( + out, + "telemt_relay_idle_soft_mark_total {}", + if me_allows_normal { + stats.get_relay_idle_soft_mark_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_relay_idle_hard_close_total Middle-relay sessions closed by hard-idle policy" + ); + let _ = writeln!(out, "# TYPE telemt_relay_idle_hard_close_total counter"); + let _ = writeln!( + out, + "telemt_relay_idle_hard_close_total {}", + if me_allows_normal { + stats.get_relay_idle_hard_close_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_relay_pressure_evict_total Middle-relay sessions evicted under resource pressure" + ); + let _ = writeln!(out, "# TYPE telemt_relay_pressure_evict_total counter"); + let _ = writeln!( + out, + "telemt_relay_pressure_evict_total {}", + if me_allows_normal { + stats.get_relay_pressure_evict_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_relay_protocol_desync_close_total Middle-relay sessions closed due to 
protocol desync" + ); + let _ = writeln!( + out, + "# TYPE telemt_relay_protocol_desync_close_total counter" + ); + let _ = writeln!( + out, + "telemt_relay_protocol_desync_close_total {}", + if me_allows_normal { + stats.get_relay_protocol_desync_close_total() + } else { + 0 + } + ); + let _ = writeln!(out, "# HELP telemt_me_crc_mismatch_total ME CRC mismatches"); let _ = writeln!(out, "# TYPE telemt_me_crc_mismatch_total counter"); let _ = writeln!( @@ -822,7 +845,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_seq_mismatch_total ME sequence mismatches"); + let _ = writeln!( + out, + "# HELP telemt_me_seq_mismatch_total ME sequence mismatches" + ); let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter"); let _ = writeln!( out, @@ -834,7 +860,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn"); + let _ = writeln!( + out, + "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn" + ); let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter"); let _ = writeln!( out, @@ -846,8 +875,14 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed"); - let _ = writeln!(out, "# TYPE telemt_me_route_drop_channel_closed_total counter"); + let _ = writeln!( + out, + "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_route_drop_channel_closed_total counter" + ); let _ = writeln!( out, "telemt_me_route_drop_channel_closed_total {}", @@ -858,7 +893,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_route_drop_queue_full_total ME 
route drops: queue full"); + let _ = writeln!( + out, + "# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full" + ); let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter"); let _ = writeln!( out, @@ -897,6 +935,462 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batches_total Total DC->Client flush batches" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_batches_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_batches_total {}", + if me_allows_normal { + stats.get_me_d2c_batches_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_frames_total Total DC->Client frames flushed in batches" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_frames_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_total {}", + if me_allows_normal { + stats.get_me_d2c_batch_frames_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_bytes_total Total DC->Client bytes flushed in batches" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_bytes_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_total {}", + if me_allows_normal { + stats.get_me_d2c_batch_bytes_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_flush_reason_total DC->Client flush reasons" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_flush_reason_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"queue_drain\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_queue_drain_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"batch_frames\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_batch_frames_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + 
"telemt_me_d2c_flush_reason_total{{reason=\"batch_bytes\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_batch_bytes_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"max_delay\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_max_delay_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"ack_immediate\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_ack_immediate_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"close\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_close_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_data_frames_total DC->Client data frames" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_data_frames_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_data_frames_total {}", + if me_allows_normal { + stats.get_me_d2c_data_frames_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_ack_frames_total DC->Client quick-ack frames" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_ack_frames_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_ack_frames_total {}", + if me_allows_normal { + stats.get_me_d2c_ack_frames_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_payload_bytes_total DC->Client payload bytes before transport framing" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_payload_bytes_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_payload_bytes_total {}", + if me_allows_normal { + stats.get_me_d2c_payload_bytes_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_write_mode_total DC->Client writer mode selection" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_write_mode_total counter"); + let _ = writeln!( + out, + 
"telemt_me_d2c_write_mode_total{{mode=\"coalesced\"}} {}", + if me_allows_normal { + stats.get_me_d2c_write_mode_coalesced_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_write_mode_total{{mode=\"split\"}} {}", + if me_allows_normal { + stats.get_me_d2c_write_mode_split_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_quota_reject_total DC->Client quota rejects" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_quota_reject_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_quota_reject_total{{stage=\"pre_write\"}} {}", + if me_allows_normal { + stats.get_me_d2c_quota_reject_pre_write_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_quota_reject_total{{stage=\"post_write\"}} {}", + if me_allows_normal { + stats.get_me_d2c_quota_reject_post_write_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_frame_buf_shrink_total DC->Client reusable frame buffer shrink events" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_frame_buf_shrink_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_frame_buf_shrink_total {}", + if me_allows_normal { + stats.get_me_d2c_frame_buf_shrink_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_frame_buf_shrink_bytes_total DC->Client reusable frame buffer bytes released" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_d2c_frame_buf_shrink_bytes_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_frame_buf_shrink_bytes_total {}", + if me_allows_normal { + stats.get_me_d2c_frame_buf_shrink_bytes_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_frames_bucket_total DC->Client batch frame count buckets" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_d2c_batch_frames_bucket_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"1\"}} {}", + if me_allows_debug { + 
stats.get_me_d2c_batch_frames_bucket_1() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"2_4\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_2_4() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"5_8\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_5_8() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"9_16\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_9_16() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"17_32\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_17_32() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"gt_32\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_gt_32() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_bytes_bucket_total DC->Client batch byte size buckets" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"0_1k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_0_1k() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"1k_4k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_1k_4k() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"4k_16k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_4k_16k() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"16k_64k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_16k_64k() + } else { + 0 + } + ); + let _ = writeln!( + out, + 
"telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"64k_128k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_64k_128k() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"gt_128k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_gt_128k() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_flush_duration_us_bucket_total DC->Client flush duration buckets" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_d2c_flush_duration_us_bucket_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"0_50\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_0_50() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"51_200\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_51_200() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"201_1000\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_201_1000() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"1001_5000\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_1001_5000() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"5001_20000\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_5001_20000() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"gt_20000\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_gt_20000() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_timeout_armed_total DC->Client max-delay timer armed events" + ); + let _ = writeln!( + out, + "# TYPE 
telemt_me_d2c_batch_timeout_armed_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_timeout_armed_total {}", + if me_allows_debug { + stats.get_me_d2c_batch_timeout_armed_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_timeout_fired_total DC->Client max-delay timer fired events" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_d2c_batch_timeout_fired_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_timeout_fired_total {}", + if me_allows_debug { + stats.get_me_d2c_batch_timeout_fired_total() + } else { + 0 + } + ); + let _ = writeln!( out, "# HELP telemt_me_writer_pick_total ME writer-pick outcomes by mode and result" @@ -1015,7 +1509,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_writer_pick_mode_switch_total Writer-pick mode switches via runtime updates" ); - let _ = writeln!(out, "# TYPE telemt_me_writer_pick_mode_switch_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_writer_pick_mode_switch_total counter" + ); let _ = writeln!( out, "telemt_me_writer_pick_mode_switch_total {}", @@ -1065,7 +1562,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_kdf_drift_total ME KDF input drift detections"); + let _ = writeln!( + out, + "# HELP telemt_me_kdf_drift_total ME KDF input drift detections" + ); let _ = writeln!(out, "# TYPE telemt_me_kdf_drift_total counter"); let _ = writeln!( out, @@ -1111,7 +1611,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_hardswap_pending_ttl_expired_total Pending hardswap generations reset by TTL expiration" ); - let _ = writeln!(out, "# TYPE telemt_me_hardswap_pending_ttl_expired_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_hardswap_pending_ttl_expired_total counter" + ); let _ = writeln!( out, 
"telemt_me_hardswap_pending_ttl_expired_total {}", @@ -1343,10 +1846,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_adaptive_floor_global_cap_raw Runtime raw global adaptive floor cap" ); - let _ = writeln!( - out, - "# TYPE telemt_me_adaptive_floor_global_cap_raw gauge" - ); + let _ = writeln!(out, "# TYPE telemt_me_adaptive_floor_global_cap_raw gauge"); let _ = writeln!( out, "telemt_me_adaptive_floor_global_cap_raw {}", @@ -1529,7 +2029,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths"); + let _ = writeln!( + out, + "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths" + ); let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter"); let _ = writeln!( out, @@ -1541,7 +2044,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_desync_total Total crypto-desync detections"); + let _ = writeln!( + out, + "# HELP telemt_desync_total Total crypto-desync detections" + ); let _ = writeln!(out, "# TYPE telemt_desync_total counter"); let _ = writeln!( out, @@ -1553,7 +2059,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_desync_full_logged_total Full forensic desync logs emitted"); + let _ = writeln!( + out, + "# HELP telemt_desync_full_logged_total Full forensic desync logs emitted" + ); let _ = writeln!(out, "# TYPE telemt_desync_full_logged_total counter"); let _ = writeln!( out, @@ -1565,7 +2074,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_desync_suppressed_total Suppressed desync forensic events"); + let _ = writeln!( + out, + "# HELP telemt_desync_suppressed_total Suppressed desync forensic 
events" + ); let _ = writeln!(out, "# TYPE telemt_desync_suppressed_total counter"); let _ = writeln!( out, @@ -1577,7 +2089,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_desync_frames_bucket_total Desync count by frames_ok bucket"); + let _ = writeln!( + out, + "# HELP telemt_desync_frames_bucket_total Desync count by frames_ok bucket" + ); let _ = writeln!(out, "# TYPE telemt_desync_frames_bucket_total counter"); let _ = writeln!( out, @@ -1616,7 +2131,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_pool_swap_total Successful ME pool swaps"); + let _ = writeln!( + out, + "# HELP telemt_pool_swap_total Successful ME pool swaps" + ); let _ = writeln!(out, "# TYPE telemt_pool_swap_total counter"); let _ = writeln!( out, @@ -1628,7 +2146,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_pool_drain_active Active draining ME writers"); + let _ = writeln!( + out, + "# HELP telemt_pool_drain_active Active draining ME writers" + ); let _ = writeln!(out, "# TYPE telemt_pool_drain_active gauge"); let _ = writeln!( out, @@ -1640,7 +2161,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_pool_force_close_total Forced close events for draining writers"); + let _ = writeln!( + out, + "# HELP telemt_pool_force_close_total Forced close events for draining writers" + ); let _ = writeln!(out, "# TYPE telemt_pool_force_close_total counter"); let _ = writeln!( out, @@ -1654,35 +2178,8 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp let _ = writeln!( out, - "# HELP telemt_pool_drain_soft_evict_total Soft-evicted client sessions on stuck draining writers" + "# HELP telemt_pool_stale_pick_total Stale writer fallback picks for 
new binds" ); - let _ = writeln!(out, "# TYPE telemt_pool_drain_soft_evict_total counter"); - let _ = writeln!( - out, - "telemt_pool_drain_soft_evict_total {}", - if me_allows_normal { - stats.get_pool_drain_soft_evict_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_pool_drain_soft_evict_writer_total Draining writers with at least one soft eviction" - ); - let _ = writeln!(out, "# TYPE telemt_pool_drain_soft_evict_writer_total counter"); - let _ = writeln!( - out, - "telemt_pool_drain_soft_evict_writer_total {}", - if me_allows_normal { - stats.get_pool_drain_soft_evict_writer_total() - } else { - 0 - } - ); - - let _ = writeln!(out, "# HELP telemt_pool_stale_pick_total Stale writer fallback picks for new binds"); let _ = writeln!(out, "# TYPE telemt_pool_stale_pick_total counter"); let _ = writeln!( out, @@ -1696,56 +2193,8 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp let _ = writeln!( out, - "# HELP telemt_me_writer_close_signal_drop_total Close-signal drops for already-removed ME writers" + "# HELP telemt_me_writer_removed_total Total ME writer removals" ); - let _ = writeln!(out, "# TYPE telemt_me_writer_close_signal_drop_total counter"); - let _ = writeln!( - out, - "telemt_me_writer_close_signal_drop_total {}", - if me_allows_normal { - stats.get_me_writer_close_signal_drop_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_me_writer_close_signal_channel_full_total Close-signal drops caused by full writer command channels" - ); - let _ = writeln!( - out, - "# TYPE telemt_me_writer_close_signal_channel_full_total counter" - ); - let _ = writeln!( - out, - "telemt_me_writer_close_signal_channel_full_total {}", - if me_allows_normal { - stats.get_me_writer_close_signal_channel_full_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_me_draining_writers_reap_progress_total Draining-writer removals processed by reap cleanup" - ); - let _ = 
writeln!( - out, - "# TYPE telemt_me_draining_writers_reap_progress_total counter" - ); - let _ = writeln!( - out, - "telemt_me_draining_writers_reap_progress_total {}", - if me_allows_normal { - stats.get_me_draining_writers_reap_progress_total() - } else { - 0 - } - ); - - let _ = writeln!(out, "# HELP telemt_me_writer_removed_total Total ME writer removals"); let _ = writeln!(out, "# TYPE telemt_me_writer_removed_total counter"); let _ = writeln!( out, @@ -1761,7 +2210,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_writer_removed_unexpected_total Unexpected ME writer removals that triggered refill" ); - let _ = writeln!(out, "# TYPE telemt_me_writer_removed_unexpected_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_writer_removed_unexpected_total counter" + ); let _ = writeln!( out, "telemt_me_writer_removed_unexpected_total {}", @@ -1774,168 +2226,8 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp let _ = writeln!( out, - "# HELP telemt_me_writer_teardown_attempt_total ME writer teardown attempts by reason and mode" + "# HELP telemt_me_refill_triggered_total Immediate ME refill runs started" ); - let _ = writeln!(out, "# TYPE telemt_me_writer_teardown_attempt_total counter"); - for reason in MeWriterTeardownReason::ALL { - for mode in MeWriterTeardownMode::ALL { - let _ = writeln!( - out, - "telemt_me_writer_teardown_attempt_total{{reason=\"{}\",mode=\"{}\"}} {}", - reason.as_str(), - mode.as_str(), - if me_allows_normal { - stats.get_me_writer_teardown_attempt_total(reason, mode) - } else { - 0 - } - ); - } - } - - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_success_total ME writer teardown successes by mode" - ); - let _ = writeln!(out, "# TYPE telemt_me_writer_teardown_success_total counter"); - for mode in MeWriterTeardownMode::ALL { - let _ = writeln!( - out, - "telemt_me_writer_teardown_success_total{{mode=\"{}\"}} {}", - 
mode.as_str(), - if me_allows_normal { - stats.get_me_writer_teardown_success_total(mode) - } else { - 0 - } - ); - } - - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_timeout_total Teardown operations that timed out" - ); - let _ = writeln!(out, "# TYPE telemt_me_writer_teardown_timeout_total counter"); - let _ = writeln!( - out, - "telemt_me_writer_teardown_timeout_total {}", - if me_allows_normal { - stats.get_me_writer_teardown_timeout_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_escalation_total Watchdog teardown escalations to hard detach" - ); - let _ = writeln!( - out, - "# TYPE telemt_me_writer_teardown_escalation_total counter" - ); - let _ = writeln!( - out, - "telemt_me_writer_teardown_escalation_total {}", - if me_allows_normal { - stats.get_me_writer_teardown_escalation_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_noop_total Teardown operations that became no-op" - ); - let _ = writeln!(out, "# TYPE telemt_me_writer_teardown_noop_total counter"); - let _ = writeln!( - out, - "telemt_me_writer_teardown_noop_total {}", - if me_allows_normal { - stats.get_me_writer_teardown_noop_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_duration_seconds ME writer teardown latency histogram by mode" - ); - let _ = writeln!( - out, - "# TYPE telemt_me_writer_teardown_duration_seconds histogram" - ); - let bucket_labels = Stats::me_writer_teardown_duration_bucket_labels(); - for mode in MeWriterTeardownMode::ALL { - for (bucket_idx, label) in bucket_labels.iter().enumerate() { - let _ = writeln!( - out, - "telemt_me_writer_teardown_duration_seconds_bucket{{mode=\"{}\",le=\"{}\"}} {}", - mode.as_str(), - label, - if me_allows_normal { - stats.get_me_writer_teardown_duration_bucket_total(mode, bucket_idx) - } else { - 0 - } - ); - } - let _ = writeln!( - out, - 
"telemt_me_writer_teardown_duration_seconds_bucket{{mode=\"{}\",le=\"+Inf\"}} {}", - mode.as_str(), - if me_allows_normal { - stats.get_me_writer_teardown_duration_count(mode) - } else { - 0 - } - ); - let _ = writeln!( - out, - "telemt_me_writer_teardown_duration_seconds_sum{{mode=\"{}\"}} {:.6}", - mode.as_str(), - if me_allows_normal { - stats.get_me_writer_teardown_duration_sum_seconds(mode) - } else { - 0.0 - } - ); - let _ = writeln!( - out, - "telemt_me_writer_teardown_duration_seconds_count{{mode=\"{}\"}} {}", - mode.as_str(), - if me_allows_normal { - stats.get_me_writer_teardown_duration_count(mode) - } else { - 0 - } - ); - } - - let _ = writeln!( - out, - "# HELP telemt_me_writer_cleanup_side_effect_failures_total Failed cleanup side effects by step" - ); - let _ = writeln!( - out, - "# TYPE telemt_me_writer_cleanup_side_effect_failures_total counter" - ); - for step in MeWriterCleanupSideEffectStep::ALL { - let _ = writeln!( - out, - "telemt_me_writer_cleanup_side_effect_failures_total{{step=\"{}\"}} {}", - step.as_str(), - if me_allows_normal { - stats.get_me_writer_cleanup_side_effect_failures_total(step) - } else { - 0 - } - ); - } - - let _ = writeln!(out, "# HELP telemt_me_refill_triggered_total Immediate ME refill runs started"); let _ = writeln!(out, "# TYPE telemt_me_refill_triggered_total counter"); let _ = writeln!( out, @@ -1951,7 +2243,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_refill_skipped_inflight_total Immediate ME refill skips due to inflight dedup" ); - let _ = writeln!(out, "# TYPE telemt_me_refill_skipped_inflight_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_refill_skipped_inflight_total counter" + ); let _ = writeln!( out, "telemt_me_refill_skipped_inflight_total {}", @@ -1962,7 +2257,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_refill_failed_total Immediate 
ME refill failures"); + let _ = writeln!( + out, + "# HELP telemt_me_refill_failed_total Immediate ME refill failures" + ); let _ = writeln!(out, "# TYPE telemt_me_refill_failed_total counter"); let _ = writeln!( out, @@ -1978,7 +2276,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_writer_restored_same_endpoint_total Refilled ME writer restored on the same endpoint" ); - let _ = writeln!(out, "# TYPE telemt_me_writer_restored_same_endpoint_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_writer_restored_same_endpoint_total counter" + ); let _ = writeln!( out, "telemt_me_writer_restored_same_endpoint_total {}", @@ -1993,7 +2294,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_writer_restored_fallback_total Refilled ME writer restored via fallback endpoint" ); - let _ = writeln!(out, "# TYPE telemt_me_writer_restored_fallback_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_writer_restored_fallback_total counter" + ); let _ = writeln!( out, "telemt_me_writer_restored_fallback_total {}", @@ -2071,17 +2375,35 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp unresolved_writer_losses ); - let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections"); + let _ = writeln!( + out, + "# HELP telemt_user_connections_total Per-user total connections" + ); let _ = writeln!(out, "# TYPE telemt_user_connections_total counter"); - let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections"); + let _ = writeln!( + out, + "# HELP telemt_user_connections_current Per-user active connections" + ); let _ = writeln!(out, "# TYPE telemt_user_connections_current gauge"); - let _ = writeln!(out, "# HELP telemt_user_octets_from_client Per-user bytes received"); + let _ = writeln!( + out, + "# HELP telemt_user_octets_from_client Per-user bytes 
received" + ); let _ = writeln!(out, "# TYPE telemt_user_octets_from_client counter"); - let _ = writeln!(out, "# HELP telemt_user_octets_to_client Per-user bytes sent"); + let _ = writeln!( + out, + "# HELP telemt_user_octets_to_client Per-user bytes sent" + ); let _ = writeln!(out, "# TYPE telemt_user_octets_to_client counter"); - let _ = writeln!(out, "# HELP telemt_user_msgs_from_client Per-user messages received"); + let _ = writeln!( + out, + "# HELP telemt_user_msgs_from_client Per-user messages received" + ); let _ = writeln!(out, "# TYPE telemt_user_msgs_from_client counter"); - let _ = writeln!(out, "# HELP telemt_user_msgs_to_client Per-user messages sent"); + let _ = writeln!( + out, + "# HELP telemt_user_msgs_to_client Per-user messages sent" + ); let _ = writeln!(out, "# TYPE telemt_user_msgs_to_client counter"); let _ = writeln!( out, @@ -2121,12 +2443,45 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp for entry in stats.iter_user_stats() { let user = entry.key(); let s = entry.value(); - let _ = writeln!(out, "telemt_user_connections_total{{user=\"{}\"}} {}", user, s.connects.load(std::sync::atomic::Ordering::Relaxed)); - let _ = writeln!(out, "telemt_user_connections_current{{user=\"{}\"}} {}", user, s.curr_connects.load(std::sync::atomic::Ordering::Relaxed)); - let _ = writeln!(out, "telemt_user_octets_from_client{{user=\"{}\"}} {}", user, s.octets_from_client.load(std::sync::atomic::Ordering::Relaxed)); - let _ = writeln!(out, "telemt_user_octets_to_client{{user=\"{}\"}} {}", user, s.octets_to_client.load(std::sync::atomic::Ordering::Relaxed)); - let _ = writeln!(out, "telemt_user_msgs_from_client{{user=\"{}\"}} {}", user, s.msgs_from_client.load(std::sync::atomic::Ordering::Relaxed)); - let _ = writeln!(out, "telemt_user_msgs_to_client{{user=\"{}\"}} {}", user, s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed)); + let _ = writeln!( + out, + "telemt_user_connections_total{{user=\"{}\"}} {}", + user, + 
s.connects.load(std::sync::atomic::Ordering::Relaxed) + ); + let _ = writeln!( + out, + "telemt_user_connections_current{{user=\"{}\"}} {}", + user, + s.curr_connects.load(std::sync::atomic::Ordering::Relaxed) + ); + let _ = writeln!( + out, + "telemt_user_octets_from_client{{user=\"{}\"}} {}", + user, + s.octets_from_client + .load(std::sync::atomic::Ordering::Relaxed) + ); + let _ = writeln!( + out, + "telemt_user_octets_to_client{{user=\"{}\"}} {}", + user, + s.octets_to_client + .load(std::sync::atomic::Ordering::Relaxed) + ); + let _ = writeln!( + out, + "telemt_user_msgs_from_client{{user=\"{}\"}} {}", + user, + s.msgs_from_client + .load(std::sync::atomic::Ordering::Relaxed) + ); + let _ = writeln!( + out, + "telemt_user_msgs_to_client{{user=\"{}\"}} {}", + user, + s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed) + ); } let ip_stats = ip_tracker.get_stats().await; @@ -2144,16 +2499,25 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp .get_recent_counts_for_users(&unique_users_vec) .await; - let _ = writeln!(out, "# HELP telemt_user_unique_ips_current Per-user current number of unique active IPs"); + let _ = writeln!( + out, + "# HELP telemt_user_unique_ips_current Per-user current number of unique active IPs" + ); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_current gauge"); let _ = writeln!( out, "# HELP telemt_user_unique_ips_recent_window Per-user unique IPs seen in configured observation window" ); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_recent_window gauge"); - let _ = writeln!(out, "# HELP telemt_user_unique_ips_limit Effective per-user unique IP limit (0 means unlimited)"); + let _ = writeln!( + out, + "# HELP telemt_user_unique_ips_limit Effective per-user unique IP limit (0 means unlimited)" + ); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_limit gauge"); - let _ = writeln!(out, "# HELP telemt_user_unique_ips_utilization Per-user unique IP usage ratio (0 for unlimited)"); + 
let _ = writeln!( + out, + "# HELP telemt_user_unique_ips_utilization Per-user unique IP usage ratio (0 for unlimited)" + ); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_utilization gauge"); for user in unique_users { @@ -2164,29 +2528,34 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp .get(&user) .copied() .filter(|limit| *limit > 0) - .or( - (config.access.user_max_unique_ips_global_each > 0) - .then_some(config.access.user_max_unique_ips_global_each), - ) + .or((config.access.user_max_unique_ips_global_each > 0) + .then_some(config.access.user_max_unique_ips_global_each)) .unwrap_or(0); let utilization = if limit > 0 { current as f64 / limit as f64 } else { 0.0 }; - let _ = writeln!(out, "telemt_user_unique_ips_current{{user=\"{}\"}} {}", user, current); + let _ = writeln!( + out, + "telemt_user_unique_ips_current{{user=\"{}\"}} {}", + user, current + ); let _ = writeln!( out, "telemt_user_unique_ips_recent_window{{user=\"{}\"}} {}", user, recent_counts.get(&user).copied().unwrap_or(0) ); - let _ = writeln!(out, "telemt_user_unique_ips_limit{{user=\"{}\"}} {}", user, limit); + let _ = writeln!( + out, + "telemt_user_unique_ips_limit{{user=\"{}\"}} {}", + user, limit + ); let _ = writeln!( out, "telemt_user_unique_ips_utilization{{user=\"{}\"}} {:.6}", - user, - utilization + user, utilization ); } } @@ -2197,8 +2566,8 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp #[cfg(test)] mod tests { use super::*; - use std::net::IpAddr; use http_body_util::BodyExt; + use std::net::IpAddr; #[tokio::test] async fn test_render_metrics_format() { @@ -2213,8 +2582,6 @@ mod tests { stats.increment_connects_all(); stats.increment_connects_all(); stats.increment_connects_bad(); - stats.increment_current_connections_direct(); - stats.increment_current_connections_me(); stats.increment_handshake_timeouts(); stats.increment_upstream_connect_attempt_total(); 
stats.increment_upstream_connect_attempt_total(); @@ -2230,6 +2597,20 @@ mod tests { stats.increment_me_rpc_proxy_req_signal_response_total(); stats.increment_me_rpc_proxy_req_signal_close_sent_total(); stats.increment_me_idle_close_by_peer_total(); + stats.increment_relay_idle_soft_mark_total(); + stats.increment_relay_idle_hard_close_total(); + stats.increment_relay_pressure_evict_total(); + stats.increment_relay_protocol_desync_close_total(); + stats.increment_me_d2c_batches_total(); + stats.add_me_d2c_batch_frames_total(3); + stats.add_me_d2c_batch_bytes_total(2048); + stats.increment_me_d2c_flush_reason(crate::stats::MeD2cFlushReason::AckImmediate); + stats.increment_me_d2c_data_frames_total(); + stats.increment_me_d2c_ack_frames_total(); + stats.add_me_d2c_payload_bytes_total(1800); + stats.increment_me_d2c_write_mode(crate::stats::MeD2cWriteMode::Coalesced); + stats.increment_me_d2c_quota_reject_total(crate::stats::MeD2cQuotaRejectStage::PostWrite); + stats.observe_me_d2c_frame_buf_shrink(4096); stats.increment_user_connects("alice"); stats.increment_user_curr_connects("alice"); stats.add_user_octets_from("alice", 1024); @@ -2246,21 +2627,15 @@ mod tests { assert!(output.contains("telemt_connections_total 2")); assert!(output.contains("telemt_connections_bad_total 1")); - assert!(output.contains("telemt_connections_current 2")); - assert!(output.contains("telemt_connections_direct_current 1")); - assert!(output.contains("telemt_connections_me_current 1")); assert!(output.contains("telemt_handshake_timeouts_total 1")); assert!(output.contains("telemt_upstream_connect_attempt_total 2")); assert!(output.contains("telemt_upstream_connect_success_total 1")); assert!(output.contains("telemt_upstream_connect_fail_total 1")); assert!(output.contains("telemt_upstream_connect_failfast_hard_error_total 1")); + assert!(output.contains("telemt_upstream_connect_attempts_per_request{bucket=\"2\"} 1")); assert!( - 
output.contains("telemt_upstream_connect_attempts_per_request{bucket=\"2\"} 1") - ); - assert!( - output.contains( - "telemt_upstream_connect_duration_success_total{bucket=\"101_500ms\"} 1" - ) + output + .contains("telemt_upstream_connect_duration_success_total{bucket=\"101_500ms\"} 1") ); assert!( output.contains("telemt_upstream_connect_duration_fail_total{bucket=\"gt_1000ms\"} 1") @@ -2271,6 +2646,21 @@ mod tests { assert!(output.contains("telemt_me_rpc_proxy_req_signal_response_total 1")); assert!(output.contains("telemt_me_rpc_proxy_req_signal_close_sent_total 1")); assert!(output.contains("telemt_me_idle_close_by_peer_total 1")); + assert!(output.contains("telemt_relay_idle_soft_mark_total 1")); + assert!(output.contains("telemt_relay_idle_hard_close_total 1")); + assert!(output.contains("telemt_relay_pressure_evict_total 1")); + assert!(output.contains("telemt_relay_protocol_desync_close_total 1")); + assert!(output.contains("telemt_me_d2c_batches_total 1")); + assert!(output.contains("telemt_me_d2c_batch_frames_total 3")); + assert!(output.contains("telemt_me_d2c_batch_bytes_total 2048")); + assert!(output.contains("telemt_me_d2c_flush_reason_total{reason=\"ack_immediate\"} 1")); + assert!(output.contains("telemt_me_d2c_data_frames_total 1")); + assert!(output.contains("telemt_me_d2c_ack_frames_total 1")); + assert!(output.contains("telemt_me_d2c_payload_bytes_total 1800")); + assert!(output.contains("telemt_me_d2c_write_mode_total{mode=\"coalesced\"} 1")); + assert!(output.contains("telemt_me_d2c_quota_reject_total{stage=\"post_write\"} 1")); + assert!(output.contains("telemt_me_d2c_frame_buf_shrink_total 1")); + assert!(output.contains("telemt_me_d2c_frame_buf_shrink_bytes_total 4096")); assert!(output.contains("telemt_user_connections_total{user=\"alice\"} 1")); assert!(output.contains("telemt_user_connections_current{user=\"alice\"} 1")); assert!(output.contains("telemt_user_octets_from_client{user=\"alice\"} 1024")); @@ -2291,9 +2681,6 @@ mod tests { 
let output = render_metrics(&stats, &config, &tracker).await; assert!(output.contains("telemt_connections_total 0")); assert!(output.contains("telemt_connections_bad_total 0")); - assert!(output.contains("telemt_connections_current 0")); - assert!(output.contains("telemt_connections_direct_current 0")); - assert!(output.contains("telemt_connections_me_current 0")); assert!(output.contains("telemt_handshake_timeouts_total 0")); assert!(output.contains("telemt_user_unique_ips_current{user=")); assert!(output.contains("telemt_user_unique_ips_recent_window{user=")); @@ -2327,42 +2714,24 @@ mod tests { assert!(output.contains("# TYPE telemt_uptime_seconds gauge")); assert!(output.contains("# TYPE telemt_connections_total counter")); assert!(output.contains("# TYPE telemt_connections_bad_total counter")); - assert!(output.contains("# TYPE telemt_connections_current gauge")); - assert!(output.contains("# TYPE telemt_connections_direct_current gauge")); - assert!(output.contains("# TYPE telemt_connections_me_current gauge")); - assert!(output.contains("# TYPE telemt_relay_adaptive_promotions_total counter")); - assert!(output.contains("# TYPE telemt_relay_adaptive_demotions_total counter")); - assert!(output.contains("# TYPE telemt_relay_adaptive_hard_promotions_total counter")); - assert!(output.contains("# TYPE telemt_reconnect_evict_total counter")); - assert!(output.contains("# TYPE telemt_reconnect_stale_close_total counter")); assert!(output.contains("# TYPE telemt_handshake_timeouts_total counter")); assert!(output.contains("# TYPE telemt_upstream_connect_attempt_total counter")); assert!(output.contains("# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter")); assert!(output.contains("# TYPE telemt_me_idle_close_by_peer_total counter")); + assert!(output.contains("# TYPE telemt_relay_idle_soft_mark_total counter")); + assert!(output.contains("# TYPE telemt_relay_idle_hard_close_total counter")); + assert!(output.contains("# TYPE 
telemt_relay_pressure_evict_total counter")); + assert!(output.contains("# TYPE telemt_relay_protocol_desync_close_total counter")); + assert!(output.contains("# TYPE telemt_me_d2c_batches_total counter")); + assert!(output.contains("# TYPE telemt_me_d2c_flush_reason_total counter")); + assert!(output.contains("# TYPE telemt_me_d2c_write_mode_total counter")); + assert!(output.contains("# TYPE telemt_me_d2c_batch_frames_bucket_total counter")); + assert!(output.contains("# TYPE telemt_me_d2c_flush_duration_us_bucket_total counter")); assert!(output.contains("# TYPE telemt_me_writer_removed_total counter")); - assert!(output.contains("# TYPE telemt_me_writer_teardown_attempt_total counter")); - assert!(output.contains("# TYPE telemt_me_writer_teardown_success_total counter")); - assert!(output.contains("# TYPE telemt_me_writer_teardown_timeout_total counter")); - assert!(output.contains("# TYPE telemt_me_writer_teardown_escalation_total counter")); - assert!(output.contains("# TYPE telemt_me_writer_teardown_noop_total counter")); - assert!(output.contains( - "# TYPE telemt_me_writer_teardown_duration_seconds histogram" - )); - assert!(output.contains( - "# TYPE telemt_me_writer_cleanup_side_effect_failures_total counter" - )); - assert!(output.contains("# TYPE telemt_me_writer_close_signal_drop_total counter")); - assert!(output.contains( - "# TYPE telemt_me_writer_close_signal_channel_full_total counter" - )); - assert!(output.contains( - "# TYPE telemt_me_draining_writers_reap_progress_total counter" - )); - assert!(output.contains("# TYPE telemt_pool_drain_soft_evict_total counter")); - assert!(output.contains("# TYPE telemt_pool_drain_soft_evict_writer_total counter")); - assert!(output.contains( - "# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge" - )); + assert!( + output + .contains("# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge") + ); assert!(output.contains("# TYPE telemt_user_unique_ips_current gauge")); 
assert!(output.contains("# TYPE telemt_user_unique_ips_recent_window gauge")); assert!(output.contains("# TYPE telemt_user_unique_ips_limit gauge")); @@ -2379,14 +2748,17 @@ mod tests { stats.increment_connects_all(); stats.increment_connects_all(); - let req = Request::builder() - .uri("/metrics") - .body(()) + let req = Request::builder().uri("/metrics").body(()).unwrap(); + let resp = handle(req, &stats, &beobachten, &tracker, &config) + .await .unwrap(); - let resp = handle(req, &stats, &beobachten, &tracker, &config).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let body = resp.into_body().collect().await.unwrap().to_bytes(); - assert!(std::str::from_utf8(body.as_ref()).unwrap().contains("telemt_connections_total 3")); + assert!( + std::str::from_utf8(body.as_ref()) + .unwrap() + .contains("telemt_connections_total 3") + ); config.general.beobachten = true; config.general.beobachten_minutes = 10; @@ -2395,10 +2767,7 @@ mod tests { "203.0.113.10".parse::().unwrap(), Duration::from_secs(600), ); - let req_beob = Request::builder() - .uri("/beobachten") - .body(()) - .unwrap(); + let req_beob = Request::builder().uri("/beobachten").body(()).unwrap(); let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config) .await .unwrap(); @@ -2408,10 +2777,7 @@ mod tests { assert!(beob_text.contains("[TLS-scanner]")); assert!(beob_text.contains("203.0.113.10-1")); - let req404 = Request::builder() - .uri("/other") - .body(()) - .unwrap(); + let req404 = Request::builder().uri("/other").body(()).unwrap(); let resp404 = handle(req404, &stats, &beobachten, &tracker, &config) .await .unwrap(); diff --git a/src/network/dns_overrides.rs b/src/network/dns_overrides.rs index 447863a..86fb325 100644 --- a/src/network/dns_overrides.rs +++ b/src/network/dns_overrides.rs @@ -26,9 +26,7 @@ fn parse_ip_spec(ip_spec: &str) -> Result { } let ip = ip_spec.parse::().map_err(|_| { - ProxyError::Config(format!( - "network.dns_overrides IP is invalid: '{ip_spec}'" - )) 
+ ProxyError::Config(format!("network.dns_overrides IP is invalid: '{ip_spec}'")) })?; if matches!(ip, IpAddr::V6(_)) { return Err(ProxyError::Config(format!( @@ -103,9 +101,9 @@ pub fn validate_entries(entries: &[String]) -> Result<()> { /// Replace runtime DNS overrides with a new validated snapshot. pub fn install_entries(entries: &[String]) -> Result<()> { let parsed = parse_entries(entries)?; - let mut guard = overrides_store() - .write() - .map_err(|_| ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string()))?; + let mut guard = overrides_store().write().map_err(|_| { + ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string()) + })?; *guard = parsed; Ok(()) } diff --git a/src/network/probe.rs b/src/network/probe.rs index a9e369d..1787b92 100644 --- a/src/network/probe.rs +++ b/src/network/probe.rs @@ -1,4 +1,5 @@ #![allow(dead_code)] +#![allow(clippy::items_after_test_module)] use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; @@ -10,7 +11,9 @@ use tracing::{debug, info, warn}; use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType}; use crate::error::Result; -use crate::network::stun::{stun_probe_family_with_bind, DualStunResult, IpFamily, StunProbeResult}; +use crate::network::stun::{ + DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind, +}; use crate::transport::UpstreamManager; #[derive(Debug, Clone, Default)] @@ -78,13 +81,8 @@ pub async fn run_probe( warn!("STUN probe is enabled but network.stun_servers is empty"); DualStunResult::default() } else { - probe_stun_servers_parallel( - &servers, - stun_nat_probe_concurrency.max(1), - None, - None, - ) - .await + probe_stun_servers_parallel(&servers, stun_nat_probe_concurrency.max(1), None, None) + .await } } else if nat_probe { info!("STUN probe is disabled by network.stun_use=false"); @@ -99,7 +97,8 @@ pub async fn run_probe( let UpstreamType::Direct { interface, bind_addresses, - } = 
&upstream.upstream_type else { + } = &upstream.upstream_type + else { continue; }; if let Some(addrs) = bind_addresses.as_ref().filter(|v| !v.is_empty()) { @@ -199,11 +198,10 @@ pub async fn run_probe( if nat_probe && probe.reflected_ipv4.is_none() && probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false) + && let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await { - if let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await { - probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0)); - info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback"); - } + probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0)); + info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback"); } probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) { @@ -217,12 +215,20 @@ pub async fn run_probe( probe.ipv4_usable = config.ipv4 && probe.detected_ipv4.is_some() - && (!probe.ipv4_is_bogon || probe.reflected_ipv4.map(|r| !is_bogon(r.ip())).unwrap_or(false)); + && (!probe.ipv4_is_bogon + || probe + .reflected_ipv4 + .map(|r| !is_bogon(r.ip())) + .unwrap_or(false)); let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()); probe.ipv6_usable = ipv6_enabled && probe.detected_ipv6.is_some() - && (!probe.ipv6_is_bogon || probe.reflected_ipv6.map(|r| !is_bogon(r.ip())).unwrap_or(false)); + && (!probe.ipv6_is_bogon + || probe + .reflected_ipv6 + .map(|r| !is_bogon(r.ip())) + .unwrap_or(false)); Ok(probe) } @@ -280,8 +286,6 @@ async fn probe_stun_servers_parallel( while next_idx < servers.len() && join_set.len() < concurrency { let stun_addr = servers[next_idx].clone(); next_idx += 1; - let bind_v4 = bind_v4; - let bind_v6 = bind_v6; join_set.spawn(async move { let res = timeout(STUN_BATCH_TIMEOUT, async { let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?; @@ -300,11 +304,15 @@ async fn 
probe_stun_servers_parallel( match task { Ok((stun_addr, Ok(Ok(result)))) => { if let Some(v4) = result.v4 { - let entry = best_v4_by_ip.entry(v4.reflected_addr.ip()).or_insert((0, v4)); + let entry = best_v4_by_ip + .entry(v4.reflected_addr.ip()) + .or_insert((0, v4)); entry.0 += 1; } if let Some(v6) = result.v6 { - let entry = best_v6_by_ip.entry(v6.reflected_addr.ip()).or_insert((0, v6)); + let entry = best_v6_by_ip + .entry(v6.reflected_addr.ip()) + .or_insert((0, v6)); entry.0 += 1; } if result.v4.is_some() || result.v6.is_some() { @@ -324,17 +332,11 @@ async fn probe_stun_servers_parallel( } let mut out = DualStunResult::default(); - if let Some((_, best)) = best_v4_by_ip - .into_values() - .max_by_key(|(count, _)| *count) - { + if let Some((_, best)) = best_v4_by_ip.into_values().max_by_key(|(count, _)| *count) { info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); out.v4 = Some(best); } - if let Some((_, best)) = best_v6_by_ip - .into_values() - .max_by_key(|(count, _)| *count) - { + if let Some((_, best)) = best_v6_by_ip.into_values().max_by_key(|(count, _)| *count) { info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); out.v6 = Some(best); } @@ -347,7 +349,8 @@ pub fn decide_network_capabilities( middle_proxy_nat_ip: Option, ) -> NetworkDecision { let ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some(); - let ipv6_dc = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some(); + let ipv6_dc = + config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some(); let nat_ip_v4 = matches!(middle_proxy_nat_ip, Some(IpAddr::V4(_))); let nat_ip_v6 = matches!(middle_proxy_nat_ip, Some(IpAddr::V6(_))); @@ -534,10 +537,26 @@ pub fn is_bogon_v6(ip: Ipv6Addr) -> bool { pub fn log_probe_result(probe: &NetworkProbe, decision: &NetworkDecision) { info!( - ipv4 = probe.detected_ipv4.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()), - ipv6 = probe.detected_ipv6.as_ref().map(|v| 
v.to_string()).unwrap_or_else(|| "-".into()), - reflected_v4 = probe.reflected_ipv4.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()), - reflected_v6 = probe.reflected_ipv6.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()), + ipv4 = probe + .detected_ipv4 + .as_ref() + .map(|v| v.to_string()) + .unwrap_or_else(|| "-".into()), + ipv6 = probe + .detected_ipv6 + .as_ref() + .map(|v| v.to_string()) + .unwrap_or_else(|| "-".into()), + reflected_v4 = probe + .reflected_ipv4 + .as_ref() + .map(|v| v.ip().to_string()) + .unwrap_or_else(|| "-".into()), + reflected_v6 = probe + .reflected_ipv6 + .as_ref() + .map(|v| v.ip().to_string()) + .unwrap_or_else(|| "-".into()), ipv4_bogon = probe.ipv4_is_bogon, ipv6_bogon = probe.ipv6_is_bogon, ipv4_me = decision.ipv4_me, diff --git a/src/network/stun.rs b/src/network/stun.rs index c3a235f..d1e088c 100644 --- a/src/network/stun.rs +++ b/src/network/stun.rs @@ -2,13 +2,20 @@ #![allow(dead_code)] use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::OnceLock; -use tokio::net::{lookup_host, UdpSocket}; -use tokio::time::{timeout, Duration, sleep}; +use tokio::net::{UdpSocket, lookup_host}; +use tokio::time::{Duration, sleep, timeout}; +use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::network::dns_overrides::{resolve, split_host_port}; +fn stun_rng() -> &'static SecureRandom { + static STUN_RNG: OnceLock = OnceLock::new(); + STUN_RNG.get_or_init(SecureRandom::new) +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum IpFamily { V4, @@ -34,13 +41,13 @@ pub async fn stun_probe_dual(stun_addr: &str) -> Result { stun_probe_family(stun_addr, IpFamily::V6), ); - Ok(DualStunResult { - v4: v4?, - v6: v6?, - }) + Ok(DualStunResult { v4: v4?, v6: v6? 
}) } -pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result> { +pub async fn stun_probe_family( + stun_addr: &str, + family: IpFamily, +) -> Result> { stun_probe_family_with_bind(stun_addr, family, None).await } @@ -49,8 +56,6 @@ pub async fn stun_probe_family_with_bind( family: IpFamily, bind_ip: Option, ) -> Result> { - use rand::RngCore; - let bind_addr = match (family, bind_ip) { (IpFamily::V4, Some(IpAddr::V4(ip))) => SocketAddr::new(IpAddr::V4(ip), 0), (IpFamily::V6, Some(IpAddr::V6(ip))) => SocketAddr::new(IpAddr::V6(ip), 0), @@ -71,13 +76,18 @@ pub async fn stun_probe_family_with_bind( if let Some(addr) = target_addr { match socket.connect(addr).await { Ok(()) => {} - Err(e) if family == IpFamily::V6 && matches!( - e.kind(), - std::io::ErrorKind::NetworkUnreachable - | std::io::ErrorKind::HostUnreachable - | std::io::ErrorKind::Unsupported - | std::io::ErrorKind::NetworkDown - ) => return Ok(None), + Err(e) + if family == IpFamily::V6 + && matches!( + e.kind(), + std::io::ErrorKind::NetworkUnreachable + | std::io::ErrorKind::HostUnreachable + | std::io::ErrorKind::Unsupported + | std::io::ErrorKind::NetworkDown + ) => + { + return Ok(None); + } Err(e) => return Err(ProxyError::Proxy(format!("STUN connect failed: {e}"))), } } else { @@ -88,7 +98,7 @@ pub async fn stun_probe_family_with_bind( req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie - rand::rng().fill_bytes(&mut req[8..20]); // transaction ID + stun_rng().fill(&mut req[8..20]); // transaction ID let mut buf = [0u8; 256]; let mut attempt = 0; @@ -120,16 +130,16 @@ pub async fn stun_probe_family_with_bind( let magic = 0x2112A442u32.to_be_bytes(); let txid = &req[8..20]; - let mut idx = 20; - while idx + 4 <= n { - let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap()); - let alen = u16::from_be_bytes(buf[idx + 
2..idx + 4].try_into().unwrap()) as usize; - idx += 4; - if idx + alen > n { - break; - } + let mut idx = 20; + while idx + 4 <= n { + let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap()); + let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize; + idx += 4; + if idx + alen > n { + break; + } - match atype { + match atype { 0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => { if alen < 8 { break; @@ -198,9 +208,8 @@ pub async fn stun_probe_family_with_bind( _ => {} } - idx += (alen + 3) & !3; - } - + idx += (alen + 3) & !3; + } } Ok(None) @@ -228,7 +237,11 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result> = LazyLock::new(|| { // ============= Middle Proxies (for advertising) ============= -pub static TG_MIDDLE_PROXIES_V4: LazyLock>> = +pub static TG_MIDDLE_PROXIES_V4: LazyLock>> = LazyLock::new(|| { let mut m = std::collections::HashMap::new(); - m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); - m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); - m.insert(2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]); - m.insert(-2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]); - m.insert(3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]); - m.insert(-3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]); + m.insert( + 1, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)], + ); + m.insert( + -1, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)], + ); + m.insert( + 2, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)], + ); + m.insert( + -2, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)], + ); + m.insert( + 3, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)], + ); + m.insert( + -3, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)], + ); m.insert(4, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888)]); - m.insert(-4, 
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)]); + m.insert( + -4, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)], + ); m.insert(5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]); - m.insert(-5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]); + m.insert( + -5, + vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)], + ); m }); -pub static TG_MIDDLE_PROXIES_V6: LazyLock>> = +pub static TG_MIDDLE_PROXIES_V6: LazyLock>> = LazyLock::new(|| { let mut m = std::collections::HashMap::new(); - m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); - m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); - m.insert(2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]); - m.insert(-2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]); - m.insert(3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]); - m.insert(-3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]); - m.insert(4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]); - m.insert(-4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]); - m.insert(5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]); - m.insert(-5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]); + m.insert( + 1, + vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)], + ); + m.insert( + -1, + vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)], + ); + m.insert( + 2, + vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)], + ); + m.insert( + -2, + vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)], + ); + m.insert( + 3, + vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)], + ); + m.insert( + -3, + vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)], + ); + m.insert( + 4, + 
vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)], + ); + m.insert( + -4, + vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)], + ); + m.insert( + 5, + vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)], + ); + m.insert( + -5, + vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)], + ); m }); @@ -89,12 +143,12 @@ impl ProtoTag { _ => None, } } - + /// Convert to 4 bytes (little-endian) pub fn to_bytes(self) -> [u8; 4] { (self as u32).to_le_bytes() } - + /// Get protocol tag as bytes slice pub fn as_bytes(&self) -> &'static [u8; 4] { match self { @@ -152,11 +206,29 @@ pub const TLS_RECORD_CHANGE_CIPHER: u8 = 0x14; pub const TLS_RECORD_APPLICATION: u8 = 0x17; /// TLS record type: Alert pub const TLS_RECORD_ALERT: u8 = 0x15; -/// Maximum TLS record size -pub const MAX_TLS_RECORD_SIZE: usize = 16384; -/// Maximum TLS chunk size (with overhead) -/// RFC 8446 §5.2 allows up to 16384 + 256 bytes of ciphertext -pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256; +/// Maximum TLS plaintext record payload size. +/// RFC 8446 §5.1: "The length MUST NOT exceed 2^14 bytes." +/// Use this for validating incoming unencrypted records +/// (ClientHello, ChangeCipherSpec, unprotected Handshake messages). +pub const MAX_TLS_PLAINTEXT_SIZE: usize = 16_384; + +/// Structural minimum for a valid TLS 1.3 ClientHello with SNI. +/// Derived from RFC 8446 §4.1.2 field layout + Appendix D.4 compat mode. +/// Deliberately conservative (below any real client) to avoid false +/// positives on legitimate connections with compact extension sets. +pub const MIN_TLS_CLIENT_HELLO_SIZE: usize = 100; + +/// Maximum TLS ciphertext record payload size. +/// RFC 8446 §5.2: "The length MUST NOT exceed 2^14 + 256 bytes." +/// The +256 accounts for maximum AEAD expansion overhead. +/// Use this for validating or sizing buffers for encrypted records. 
+pub const MAX_TLS_CIPHERTEXT_SIZE: usize = 16_384 + 256; + +#[deprecated(note = "use MAX_TLS_PLAINTEXT_SIZE")] +pub const MAX_TLS_RECORD_SIZE: usize = MAX_TLS_PLAINTEXT_SIZE; + +#[deprecated(note = "use MAX_TLS_CIPHERTEXT_SIZE")] +pub const MAX_TLS_CHUNK_SIZE: usize = MAX_TLS_CIPHERTEXT_SIZE; /// Secure Intermediate payload is expected to be 4-byte aligned. pub fn is_valid_secure_payload_len(data_len: usize) -> bool { @@ -204,9 +276,7 @@ pub const SMALL_BUFFER_SIZE: usize = 8192; // ============= Statistics ============= /// Duration buckets for histogram metrics -pub static DURATION_BUCKETS: &[f64] = &[ - 0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0, -]; +pub static DURATION_BUCKETS: &[f64] = &[0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0]; // ============= Reserved Nonce Patterns ============= @@ -217,29 +287,27 @@ pub static RESERVED_NONCE_FIRST_BYTES: &[u8] = &[0xef]; pub static RESERVED_NONCE_BEGINNINGS: &[[u8; 4]] = &[ [0x48, 0x45, 0x41, 0x44], // HEAD [0x50, 0x4F, 0x53, 0x54], // POST - [0x47, 0x45, 0x54, 0x20], // GET + [0x47, 0x45, 0x54, 0x20], // GET [0xee, 0xee, 0xee, 0xee], // Intermediate [0xdd, 0xdd, 0xdd, 0xdd], // Secure [0x16, 0x03, 0x01, 0x02], // TLS ]; /// Reserved continuation bytes (bytes 4-7) -pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[ - [0x00, 0x00, 0x00, 0x00], -]; +pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[[0x00, 0x00, 0x00, 0x00]]; // ============= RPC Constants (for Middle Proxy) ============= /// RPC Proxy Request /// RPC Flags (from Erlang mtp_rpc.erl) pub const RPC_FLAG_NOT_ENCRYPTED: u32 = 0x2; -pub const RPC_FLAG_HAS_AD_TAG: u32 = 0x8; -pub const RPC_FLAG_MAGIC: u32 = 0x1000; -pub const RPC_FLAG_EXTMODE2: u32 = 0x20000; -pub const RPC_FLAG_PAD: u32 = 0x8000000; -pub const RPC_FLAG_INTERMEDIATE: u32 = 0x20000000; -pub const RPC_FLAG_ABRIDGED: u32 = 0x40000000; -pub const RPC_FLAG_QUICKACK: u32 = 0x80000000; +pub const RPC_FLAG_HAS_AD_TAG: u32 = 0x8; +pub const RPC_FLAG_MAGIC: u32 = 
0x1000; +pub const RPC_FLAG_EXTMODE2: u32 = 0x20000; +pub const RPC_FLAG_PAD: u32 = 0x8000000; +pub const RPC_FLAG_INTERMEDIATE: u32 = 0x20000000; +pub const RPC_FLAG_ABRIDGED: u32 = 0x40000000; +pub const RPC_FLAG_QUICKACK: u32 = 0x80000000; pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36]; /// RPC Proxy Answer @@ -267,63 +335,66 @@ pub mod rpc_flags { pub const FLAG_QUICKACK: u32 = 0x80000000; } +// ============= Middle-End Proxy Servers ============= +pub const ME_PROXY_PORT: u16 = 8888; - // ============= Middle-End Proxy Servers ============= - pub const ME_PROXY_PORT: u16 = 8888; - - pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock> = LazyLock::new(|| { - vec![ - (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888), - (IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888), - (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888), - (IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888), - (IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888), - ] - }); - - // ============= RPC Constants (u32 native endian) ============= - // From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c - - pub const RPC_NONCE_U32: u32 = 0x7acb87aa; - pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5; - pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda; - pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121 - - // mtproto-common.h - pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee; - pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d; - pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d; - pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2; - pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b; - pub const RPC_PING_U32: u32 = 0x5730a2df; - pub const RPC_PONG_U32: u32 = 0x8430eaa7; - - pub const RPC_CRYPTO_NONE_U32: u32 = 0; - pub const RPC_CRYPTO_AES_U32: u32 = 1; - - pub mod proxy_flags { - pub const FLAG_HAS_AD_TAG: u32 = 1; - pub const FLAG_NOT_ENCRYPTED: u32 = 0x2; - pub const FLAG_HAS_AD_TAG2: u32 = 0x8; - pub const FLAG_MAGIC: u32 = 0x1000; - pub const FLAG_EXTMODE2: u32 = 
0x20000; - pub const FLAG_PAD: u32 = 0x8000000; - pub const FLAG_INTERMEDIATE: u32 = 0x20000000; - pub const FLAG_ABRIDGED: u32 = 0x40000000; - pub const FLAG_QUICKACK: u32 = 0x80000000; - } +pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock> = LazyLock::new(|| { + vec![ + (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888), + (IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888), + (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888), + (IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888), + (IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888), + ] +}); - pub mod rpc_crypto_flags { - pub const USE_CRC32C: u32 = 0x800; - } - - pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; - pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; - - #[cfg(test)] +// ============= RPC Constants (u32 native endian) ============= +// From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c + +pub const RPC_NONCE_U32: u32 = 0x7acb87aa; +pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5; +pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda; +pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121 + +// mtproto-common.h +pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee; +pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d; +pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d; +pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2; +pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b; +pub const RPC_PING_U32: u32 = 0x5730a2df; +pub const RPC_PONG_U32: u32 = 0x8430eaa7; + +pub const RPC_CRYPTO_NONE_U32: u32 = 0; +pub const RPC_CRYPTO_AES_U32: u32 = 1; + +pub mod proxy_flags { + pub const FLAG_HAS_AD_TAG: u32 = 1; + pub const FLAG_NOT_ENCRYPTED: u32 = 0x2; + pub const FLAG_HAS_AD_TAG2: u32 = 0x8; + pub const FLAG_MAGIC: u32 = 0x1000; + pub const FLAG_EXTMODE2: u32 = 0x20000; + pub const FLAG_PAD: u32 = 0x8000000; + pub const FLAG_INTERMEDIATE: u32 = 0x20000000; + pub const FLAG_ABRIDGED: u32 = 0x40000000; + pub const FLAG_QUICKACK: u32 = 0x80000000; +} + +pub mod rpc_crypto_flags { + pub const USE_CRC32C: u32 = 
0x800; +} + +pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; +pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; + +#[cfg(test)] +#[path = "tests/tls_size_constants_security_tests.rs"] +mod tls_size_constants_security_tests; + +#[cfg(test)] mod tests { use super::*; - + #[test] fn test_proto_tag_roundtrip() { for tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] { @@ -332,20 +403,20 @@ mod tests { assert_eq!(tag, parsed); } } - + #[test] fn test_proto_tag_values() { assert_eq!(ProtoTag::Abridged.to_bytes(), PROTO_TAG_ABRIDGED); assert_eq!(ProtoTag::Intermediate.to_bytes(), PROTO_TAG_INTERMEDIATE); assert_eq!(ProtoTag::Secure.to_bytes(), PROTO_TAG_SECURE); } - + #[test] fn test_invalid_proto_tag() { assert!(ProtoTag::from_bytes([0, 0, 0, 0]).is_none()); assert!(ProtoTag::from_bytes([0xff, 0xff, 0xff, 0xff]).is_none()); } - + #[test] fn test_datacenters_count() { assert_eq!(TG_DATACENTERS_V4.len(), 5); diff --git a/src/protocol/frame.rs b/src/protocol/frame.rs index dd59ba9..d8e3d4a 100644 --- a/src/protocol/frame.rs +++ b/src/protocol/frame.rs @@ -22,7 +22,7 @@ impl FrameExtra { pub fn new() -> Self { Self::default() } - + /// Create with quickack flag set pub fn with_quickack() -> Self { Self { @@ -30,7 +30,7 @@ impl FrameExtra { ..Default::default() } } - + /// Create with simple_ack flag set pub fn with_simple_ack() -> Self { Self { @@ -38,7 +38,7 @@ impl FrameExtra { ..Default::default() } } - + /// Check if any flags are set pub fn has_flags(&self) -> bool { self.quickack || self.simple_ack || self.skip_send @@ -76,22 +76,22 @@ impl FrameMode { FrameMode::Abridged => 4, FrameMode::Intermediate => 4, FrameMode::SecureIntermediate => 4 + 3, // length + padding - FrameMode::Full => 12 + 16, // header + max CBC padding + FrameMode::Full => 12 + 16, // header + max CBC padding } } } /// Validate message length for MTProto pub fn validate_message_length(len: usize) -> bool { - use super::constants::{MIN_MSG_LEN, MAX_MSG_LEN, PADDING_FILLER}; - + use 
super::constants::{MAX_MSG_LEN, MIN_MSG_LEN, PADDING_FILLER}; + (MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len) && len.is_multiple_of(PADDING_FILLER.len()) } #[cfg(test)] mod tests { use super::*; - + #[test] fn test_frame_extra_default() { let extra = FrameExtra::default(); @@ -100,18 +100,18 @@ mod tests { assert!(!extra.skip_send); assert!(!extra.has_flags()); } - + #[test] fn test_frame_extra_flags() { let extra = FrameExtra::with_quickack(); assert!(extra.quickack); assert!(extra.has_flags()); - + let extra = FrameExtra::with_simple_ack(); assert!(extra.simple_ack); assert!(extra.has_flags()); } - + #[test] fn test_validate_message_length() { assert!(validate_message_length(12)); // MIN_MSG_LEN @@ -119,4 +119,4 @@ mod tests { assert!(!validate_message_length(8)); // Too small assert!(!validate_message_length(13)); // Not aligned to 4 } -} \ No newline at end of file +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 5518df2..f0b3a1a 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -12,4 +12,4 @@ pub use frame::*; #[allow(unused_imports)] pub use obfuscation::*; #[allow(unused_imports)] -pub use tls::*; \ No newline at end of file +pub use tls::*; diff --git a/src/protocol/obfuscation.rs b/src/protocol/obfuscation.rs index d9d1c0a..7aff9f3 100644 --- a/src/protocol/obfuscation.rs +++ b/src/protocol/obfuscation.rs @@ -2,9 +2,9 @@ #![allow(dead_code)] -use zeroize::Zeroize; -use crate::crypto::{sha256, AesCtr}; use super::constants::*; +use crate::crypto::{AesCtr, sha256}; +use zeroize::Zeroize; /// Obfuscation parameters from handshake /// @@ -44,41 +44,40 @@ impl ObfuscationParams { let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; - + let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; - + for (username, 
secret) in secrets { let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(secret); let decrypt_key = sha256(&dec_key_input); - + let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); - + let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv); let decrypted = decryptor.decrypt(handshake); - + let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] .try_into() .unwrap(); - + let proto_tag = match ProtoTag::from_bytes(tag_bytes) { Some(tag) => tag, None => continue, }; - - let dc_idx = i16::from_le_bytes( - decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() - ); - + + let dc_idx = + i16::from_le_bytes(decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()); + let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(secret); let encrypt_key = sha256(&enc_key_input); let encrypt_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); - + return Some(( ObfuscationParams { decrypt_key, @@ -91,20 +90,20 @@ impl ObfuscationParams { username.clone(), )); } - + None } - + /// Create AES-CTR decryptor for client -> proxy direction pub fn create_decryptor(&self) -> AesCtr { AesCtr::new(&self.decrypt_key, self.decrypt_iv) } - + /// Create AES-CTR encryptor for proxy -> client direction pub fn create_encryptor(&self) -> AesCtr { AesCtr::new(&self.encrypt_key, self.encrypt_iv) } - + /// Get the combined encrypt key and IV for fast mode pub fn enc_key_iv(&self) -> Vec { let mut result = Vec::with_capacity(KEY_LEN + IV_LEN); @@ -120,7 +119,7 @@ pub fn generate_nonce Vec>(mut random_bytes: R) -> [u8; H let nonce_vec = random_bytes(HANDSHAKE_LEN); let mut nonce = [0u8; HANDSHAKE_LEN]; nonce.copy_from_slice(&nonce_vec); - + if is_valid_nonce(&nonce) { return nonce; } @@ -132,17 +131,17 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool { if 
RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { return false; } - + let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { return false; } - + let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); if RESERVED_NONCE_CONTINUES.contains(&continue_four) { return false; } - + true } @@ -153,7 +152,7 @@ pub fn prepare_tg_nonce( enc_key_iv: Option<&[u8]>, ) { nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); - + if let Some(key_iv) = enc_key_iv { let reversed: Vec = key_iv.iter().rev().copied().collect(); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed); @@ -171,39 +170,39 @@ pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key = sha256(key_iv); let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap()); - + let mut encryptor = AesCtr::new(&enc_key, enc_iv); - + let mut result = nonce.to_vec(); let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]); result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part); - + result } #[cfg(test)] mod tests { use super::*; - + #[test] fn test_is_valid_nonce() { let mut valid = [0x42u8; HANDSHAKE_LEN]; valid[4..8].copy_from_slice(&[1, 2, 3, 4]); assert!(is_valid_nonce(&valid)); - + let mut invalid = [0x00u8; HANDSHAKE_LEN]; invalid[0] = 0xef; assert!(!is_valid_nonce(&invalid)); - + let mut invalid = [0x00u8; HANDSHAKE_LEN]; invalid[..4].copy_from_slice(b"HEAD"); assert!(!is_valid_nonce(&invalid)); - + let mut invalid = [0x42u8; HANDSHAKE_LEN]; invalid[4..8].copy_from_slice(&[0, 0, 0, 0]); assert!(!is_valid_nonce(&invalid)); } - + #[test] fn test_generate_nonce() { let mut counter = 0u8; @@ -211,7 +210,7 @@ mod tests { counter = counter.wrapping_add(1); vec![counter; n] }); - + assert!(is_valid_nonce(&nonce)); assert_eq!(nonce.len(), HANDSHAKE_LEN); } diff --git a/src/protocol/tests/tls_adversarial_tests.rs 
b/src/protocol/tests/tls_adversarial_tests.rs new file mode 100644 index 0000000..0b36ba3 --- /dev/null +++ b/src/protocol/tests/tls_adversarial_tests.rs @@ -0,0 +1,358 @@ +use super::*; +use crate::crypto::sha256_hmac; +use std::time::Instant; + +/// Helper to create a byte vector of specific length. +fn make_garbage(len: usize) -> Vec { + vec![0x42u8; len] +} + +/// Helper to create a valid-looking HMAC digest for test. +fn make_digest(secret: &[u8], msg: &[u8], ts: u32) -> [u8; 32] { + let mut hmac = sha256_hmac(secret, msg); + let ts_bytes = ts.to_le_bytes(); + for i in 0..4 { + hmac[28 + i] ^= ts_bytes[i]; + } + hmac +} + +fn make_valid_tls_handshake_with_session_id( + secret: &[u8], + timestamp: u32, + session_id: &[u8], +) -> Vec { + let session_id_len = session_id.len(); + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id); + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let digest = make_digest(secret, &handshake, timestamp); + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + make_valid_tls_handshake_with_session_id(secret, timestamp, &[0x42; 32]) +} + +// ------------------------------------------------------------------ +// Truncated Packet Tests (OWASP ASVS 5.1.4, 5.1.5) +// ------------------------------------------------------------------ + +#[test] +fn validate_tls_handshake_truncated_10_bytes_rejected() { + let secrets = vec![("user".to_string(), b"secret".to_vec())]; + let truncated = make_garbage(10); + assert!(validate_tls_handshake(&truncated, &secrets, true).is_none()); +} + +#[test] +fn 
validate_tls_handshake_truncated_at_digest_start_rejected() { + let secrets = vec![("user".to_string(), b"secret".to_vec())]; + // TLS_DIGEST_POS = 11. 11 bytes should be rejected. + let truncated = make_garbage(TLS_DIGEST_POS); + assert!(validate_tls_handshake(&truncated, &secrets, true).is_none()); +} + +#[test] +fn validate_tls_handshake_truncated_inside_digest_rejected() { + let secrets = vec![("user".to_string(), b"secret".to_vec())]; + // TLS_DIGEST_POS + 16 (half digest) + let truncated = make_garbage(TLS_DIGEST_POS + 16); + assert!(validate_tls_handshake(&truncated, &secrets, true).is_none()); +} + +#[test] +fn extract_sni_truncated_at_record_header_rejected() { + let truncated = make_garbage(3); + assert!(extract_sni_from_client_hello(&truncated).is_none()); +} + +#[test] +fn extract_sni_truncated_at_handshake_header_rejected() { + let mut truncated = vec![TLS_RECORD_HANDSHAKE, 0x03, 0x03, 0x00, 0x05]; + truncated.extend_from_slice(&[0x01, 0x00]); // ClientHello type but truncated length + assert!(extract_sni_from_client_hello(&truncated).is_none()); +} + +// ------------------------------------------------------------------ +// Malformed Extension Parsing Tests +// ------------------------------------------------------------------ + +#[test] +fn extract_sni_with_overlapping_extension_lengths_rejected() { + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header + h.push(0x01); // Handshake type: ClientHello + h.extend_from_slice(&[0x00, 0x00, 0x5C]); // Length: 92 + h.extend_from_slice(&[0x03, 0x03]); // Version + h.extend_from_slice(&[0u8; 32]); // Random + h.push(0); // Session ID length: 0 + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites + h.extend_from_slice(&[0x01, 0x00]); // Compression + + // Extensions start + h.extend_from_slice(&[0x00, 0x20]); // Total Extensions length: 32 + + // Extension 1: SNI (type 0) + h.extend_from_slice(&[0x00, 0x00]); + h.extend_from_slice(&[0x00, 0x40]); // Claimed len: 64 (OVERFLOWS total 
extensions len 32) + h.extend_from_slice(&[0u8; 64]); + + assert!(extract_sni_from_client_hello(&h).is_none()); +} + +#[test] +fn extract_sni_with_infinite_loop_potential_extension_rejected() { + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header + h.push(0x01); // Handshake type: ClientHello + h.extend_from_slice(&[0x00, 0x00, 0x5C]); // Length: 92 + h.extend_from_slice(&[0x03, 0x03]); // Version + h.extend_from_slice(&[0u8; 32]); // Random + h.push(0); // Session ID length: 0 + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites + h.extend_from_slice(&[0x01, 0x00]); // Compression + + // Extensions start + h.extend_from_slice(&[0x00, 0x10]); // Total Extensions length: 16 + + // Extension: zero length but claims more? + // If our parser didn't advance, it might loop. + // Telemt uses `pos += 4 + elen;` so it always advances. + h.extend_from_slice(&[0x12, 0x34]); // Unknown type + h.extend_from_slice(&[0x00, 0x00]); // Length 0 + + // Fill the rest with garbage + h.extend_from_slice(&[0x42; 12]); + + // We expect it to finish without SNI found + assert!(extract_sni_from_client_hello(&h).is_none()); +} + +#[test] +fn extract_sni_with_invalid_hostname_rejected() { + let host = b"invalid_host!%^"; + let mut sni = Vec::new(); + sni.extend_from_slice(&((host.len() + 3) as u16).to_be_bytes()); + sni.push(0); + sni.extend_from_slice(&(host.len() as u16).to_be_bytes()); + sni.extend_from_slice(host); + + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header + h.push(0x01); // ClientHello + h.extend_from_slice(&[0x00, 0x00, 0x5C]); + h.extend_from_slice(&[0x03, 0x03]); + h.extend_from_slice(&[0u8; 32]); + h.push(0); + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); + h.extend_from_slice(&[0x01, 0x00]); + + let mut ext = Vec::new(); + ext.extend_from_slice(&0x0000u16.to_be_bytes()); + ext.extend_from_slice(&(sni.len() as u16).to_be_bytes()); + ext.extend_from_slice(&sni); + + h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); + 
h.extend_from_slice(&ext); + + assert!( + extract_sni_from_client_hello(&h).is_none(), + "Invalid SNI hostname must be rejected" + ); +} + +// ------------------------------------------------------------------ +// Timing Neutrality Tests (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[test] +fn validate_tls_handshake_timing_neutrality() { + let secret = b"timing_test_secret_32_bytes_long_"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let mut base = vec![0x42u8; 100]; + base[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; + + const ITER: usize = 600; + const ROUNDS: usize = 7; + + let mut per_round_avg_diff_ns = Vec::with_capacity(ROUNDS); + + for round in 0..ROUNDS { + let mut success_h = base.clone(); + let mut fail_h = base.clone(); + + let start_success = Instant::now(); + for _ in 0..ITER { + let digest = make_digest(secret, &success_h, 0); + success_h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + let _ = validate_tls_handshake_at_time(&success_h, &secrets, true, 0); + } + let success_elapsed = start_success.elapsed(); + + let start_fail = Instant::now(); + for i in 0..ITER { + let mut digest = make_digest(secret, &fail_h, 0); + let flip_idx = (i + round) % (TLS_DIGEST_LEN - 4); + digest[flip_idx] ^= 0xFF; + fail_h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + let _ = validate_tls_handshake_at_time(&fail_h, &secrets, true, 0); + } + let fail_elapsed = start_fail.elapsed(); + + let diff = if success_elapsed > fail_elapsed { + success_elapsed - fail_elapsed + } else { + fail_elapsed - success_elapsed + }; + per_round_avg_diff_ns.push(diff.as_nanos() as f64 / ITER as f64); + } + + per_round_avg_diff_ns.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let median_avg_diff_ns = per_round_avg_diff_ns[ROUNDS / 2]; + + // Keep this as a coarse side-channel guard only; noisy shared CI hosts can + // introduce microsecond-level jitter that should not fail 
deterministic suites. + assert!( + median_avg_diff_ns < 50_000.0, + "Median timing delta too large: {} ns/iter", + median_avg_diff_ns + ); +} + +// ------------------------------------------------------------------ +// Adversarial Fingerprinting / Active Probing Tests +// ------------------------------------------------------------------ + +#[test] +fn is_tls_handshake_robustness_against_probing() { + // Valid TLS 1.0 ClientHello + assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); + // Valid TLS 1.2/1.3 ClientHello (Legacy Record Layer) + assert!(is_tls_handshake(&[0x16, 0x03, 0x03])); + + // Invalid record type but matching version + assert!(!is_tls_handshake(&[0x17, 0x03, 0x03])); + // Plaintext HTTP request + assert!(!is_tls_handshake(b"GET / HTTP/1.1")); + // Short garbage + assert!(!is_tls_handshake(&[0x16, 0x03])); +} + +#[test] +fn validate_tls_handshake_at_time_strict_boundary() { + let secret = b"strict_boundary_secret_32_bytes_"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_000_000_000; + + // Boundary: exactly TIME_SKEW_MAX (120s past) + let ts_past = (now - TIME_SKEW_MAX) as u32; + let h = make_valid_tls_handshake_with_session_id(secret, ts_past, &[0x42; 32]); + assert!(validate_tls_handshake_at_time(&h, &secrets, false, now).is_some()); + + // Boundary + 1s: should be rejected + let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32; + let h2 = make_valid_tls_handshake_with_session_id(secret, ts_too_past, &[0x42; 32]); + assert!(validate_tls_handshake_at_time(&h2, &secrets, false, now).is_none()); +} + +#[test] +fn extract_sni_with_duplicate_extensions_rejected() { + // Construct a ClientHello with TWO SNI extensions + let host1 = b"first.com"; + let mut sni1 = Vec::new(); + sni1.extend_from_slice(&((host1.len() + 3) as u16).to_be_bytes()); + sni1.push(0); + sni1.extend_from_slice(&(host1.len() as u16).to_be_bytes()); + sni1.extend_from_slice(host1); + + let host2 = b"second.com"; + let mut sni2 = Vec::new(); + 
sni2.extend_from_slice(&((host2.len() + 3) as u16).to_be_bytes()); + sni2.push(0); + sni2.extend_from_slice(&(host2.len() as u16).to_be_bytes()); + sni2.extend_from_slice(host2); + + let mut ext = Vec::new(); + // Ext 1: SNI + ext.extend_from_slice(&0x0000u16.to_be_bytes()); + ext.extend_from_slice(&(sni1.len() as u16).to_be_bytes()); + ext.extend_from_slice(&sni1); + // Ext 2: SNI again + ext.extend_from_slice(&0x0000u16.to_be_bytes()); + ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes()); + ext.extend_from_slice(&sni2); + + let mut body = Vec::new(); + body.extend_from_slice(&[0x03, 0x03]); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); + body.extend_from_slice(&[0x01, 0x00]); + body.extend_from_slice(&(ext.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut h = Vec::new(); + h.push(0x16); + h.extend_from_slice(&[0x03, 0x03]); + h.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + h.extend_from_slice(&handshake); + + // Duplicate SNI extensions are ambiguous and must fail closed. 
+ assert!(extract_sni_from_client_hello(&h).is_none()); +} + +#[test] +fn extract_alpn_with_malformed_list_rejected() { + let mut alpn_payload = Vec::new(); + alpn_payload.extend_from_slice(&0x0005u16.to_be_bytes()); // Total len 5 + alpn_payload.push(10); // Labeled len 10 (OVERFLOWS total 5) + alpn_payload.extend_from_slice(b"h2"); + + let mut ext = Vec::new(); + ext.extend_from_slice(&0x0010u16.to_be_bytes()); // Type: ALPN (16) + ext.extend_from_slice(&(alpn_payload.len() as u16).to_be_bytes()); + ext.extend_from_slice(&alpn_payload); + + let mut h = vec![ + 0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03, + ]; + h.extend_from_slice(&[0u8; 32]); + h.push(0); + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); + h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); + h.extend_from_slice(&ext); + + let res = extract_alpn_from_client_hello(&h); + assert!( + res.is_empty(), + "Malformed ALPN list must return empty or fail" + ); +} + +#[test] +fn extract_sni_with_huge_extension_header_rejected() { + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x00]; // Record header + h.push(0x01); // ClientHello + h.extend_from_slice(&[0x00, 0xFF, 0xFF]); // Huge length (65535) - overflows record + h.extend_from_slice(&[0x03, 0x03]); + h.extend_from_slice(&[0u8; 32]); + h.push(0); + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); + + // Extensions start + h.extend_from_slice(&[0xFF, 0xFF]); // Total extensions: 65535 (OVERFLOWS everything) + + assert!(extract_sni_from_client_hello(&h).is_none()); +} diff --git a/src/protocol/tests/tls_fuzz_security_tests.rs b/src/protocol/tests/tls_fuzz_security_tests.rs new file mode 100644 index 0000000..903adb3 --- /dev/null +++ b/src/protocol/tests/tls_fuzz_security_tests.rs @@ -0,0 +1,210 @@ +use super::*; +use crate::crypto::sha256_hmac; +use std::panic::catch_unwind; + +fn make_valid_tls_handshake_with_session_id( + secret: &[u8], + timestamp: u32, + session_id: &[u8], +) -> Vec { + let session_id_len = 
session_id.len(); + assert!(session_id_len <= u8::MAX as usize); + + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id); + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let mut digest = sha256_hmac(secret, &handshake); + let ts = timestamp.to_le_bytes(); + for idx in 0..4 { + digest[28 + idx] ^= ts[idx]; + } + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + handshake +} + +fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + 
ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record +} + +#[test] +fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() { + let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]); + assert_eq!( + extract_sni_from_client_hello(&valid).as_deref(), + Some("example.com") + ); + assert_eq!( + extract_alpn_from_client_hello(&valid), + vec![b"h2".to_vec(), b"http/1.1".to_vec()] + ); + assert!( + extract_sni_from_client_hello(&make_valid_client_hello_record("127.0.0.1", &[])).is_none(), + "literal IP hostnames must be rejected" + ); + + let mut corpus = vec![ + Vec::new(), + vec![0x16, 0x03, 0x03], + valid[..9].to_vec(), + valid[..valid.len() - 1].to_vec(), + ]; + + let mut wrong_type = valid.clone(); + wrong_type[0] = 0x15; + corpus.push(wrong_type); + + let mut wrong_handshake = valid.clone(); + wrong_handshake[5] = 0x02; + corpus.push(wrong_handshake); + + let mut wrong_length = valid.clone(); + wrong_length[3] ^= 0x7f; + corpus.push(wrong_length); + + for (idx, input) in corpus.iter().enumerate() { + assert!(catch_unwind(|| extract_sni_from_client_hello(input)).is_ok()); + assert!(catch_unwind(|| extract_alpn_from_client_hello(input)).is_ok()); + + if idx == 0 { + continue; + } + + assert!( + extract_sni_from_client_hello(input).is_none(), + "corpus item {idx} must fail closed for SNI" + ); + assert!( + extract_alpn_from_client_hello(input).is_empty(), + "corpus item {idx} must fail closed for ALPN" + ); + } +} + 
+#[test] +fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() { + let secret = b"tls_fuzz_security_secret"; + let now: i64 = 1_700_000_000; + let base = make_valid_tls_handshake_with_session_id(secret, now as u32, &[0x42; 32]); + let secrets = vec![("fuzz-user".to_string(), secret.to_vec())]; + + assert!(validate_tls_handshake_at_time(&base, &secrets, false, now).is_some()); + + let mut corpus = Vec::new(); + + let mut truncated = base.clone(); + truncated.truncate(TLS_DIGEST_POS + 16); + corpus.push(truncated); + + let mut digest_flip = base.clone(); + digest_flip[TLS_DIGEST_POS + 7] ^= 0x80; + corpus.push(digest_flip); + + let mut session_id_len_overflow = base.clone(); + session_id_len_overflow[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 33; + corpus.push(session_id_len_overflow); + + let mut timestamp_far_past = base.clone(); + timestamp_far_past[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32] + .copy_from_slice(&((now - i64::from(TIME_SKEW_MAX) - 1) as u32).to_le_bytes()); + corpus.push(timestamp_far_past); + + let mut timestamp_far_future = base.clone(); + timestamp_far_future[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32] + .copy_from_slice(&((now - TIME_SKEW_MIN + 1) as u32).to_le_bytes()); + corpus.push(timestamp_far_future); + + let mut seed = 0xA5A5_5A5A_F00D_BAAD_u64; + for _ in 0..32 { + let mut mutated = base.clone(); + for _ in 0..2 { + seed = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN); + mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, handshake) in corpus.iter().enumerate() { + let result = + catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now)); + assert!(result.is_ok(), "corpus item {idx} must not panic"); + assert!( + result.unwrap().is_none(), + "corpus item {idx} must fail closed" + ); + } +} + +#[test] +fn tls_boot_time_acceptance_is_capped_by_replay_window() { + let secret = 
b"tls_boot_time_cap_secret"; + let secrets = vec![("boot-user".to_string(), secret.to_vec())]; + let boot_ts = 1u32; + let handshake = make_valid_tls_handshake_with_session_id(secret, boot_ts, &[0x42; 32]); + + assert!( + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 300).is_some(), + "boot-time timestamp should be accepted while replay window permits it" + ); + assert!( + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 0).is_none(), + "boot-time timestamp must be rejected when replay window disables the bypass" + ); +} diff --git a/src/protocol/tests/tls_length_cast_hardening_security_tests.rs b/src/protocol/tests/tls_length_cast_hardening_security_tests.rs new file mode 100644 index 0000000..31418e4 --- /dev/null +++ b/src/protocol/tests/tls_length_cast_hardening_security_tests.rs @@ -0,0 +1,37 @@ +use super::*; + +#[test] +fn extension_builder_fails_closed_on_u16_length_overflow() { + let builder = TlsExtensionBuilder { + extensions: vec![0u8; (u16::MAX as usize) + 1], + }; + + let built = builder.build(); + assert!( + built.is_empty(), + "oversized extension blob must fail closed instead of truncating length field" + ); +} + +#[test] +fn server_hello_builder_fails_closed_on_session_id_len_overflow() { + let builder = ServerHelloBuilder { + random: [0u8; 32], + session_id: vec![0xAB; (u8::MAX as usize) + 1], + cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256, + compression: 0, + extensions: TlsExtensionBuilder::new(), + }; + + let message = builder.build_message(); + let record = builder.build_record(); + + assert!( + message.is_empty(), + "session_id length overflow must fail closed in message builder" + ); + assert!( + record.is_empty(), + "session_id length overflow must fail closed in record builder" + ); +} diff --git a/src/protocol/tests/tls_security_tests.rs b/src/protocol/tests/tls_security_tests.rs new file mode 100644 index 0000000..3008e57 --- /dev/null +++ b/src/protocol/tests/tls_security_tests.rs @@ 
-0,0 +1,2429 @@ +use super::*; +use crate::crypto::sha256_hmac; +use crate::tls_front::emulator::build_emulated_server_hello; +use crate::tls_front::types::{ + CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource, +}; +use std::time::SystemTime; + +/// Build a TLS-handshake-like buffer that contains a valid HMAC digest +/// for the given `secret` and `timestamp`. +/// +/// Layout (bytes): +/// [0..TLS_DIGEST_POS] : fixed filler (0x42) +/// [TLS_DIGEST_POS..+32] : digest = HMAC XOR [0..0 || timestamp_le] +/// [TLS_DIGEST_POS+32] : session_id_len = 32 +/// [TLS_DIGEST_POS+33..+65] : session_id filler (0x42) +fn make_valid_tls_handshake_with_session_id( + secret: &[u8], + timestamp: u32, + session_id: &[u8], +) -> Vec { + let session_id_len = session_id.len(); + assert!(session_id_len <= u8::MAX as usize); + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id); + // Zero the digest slot before computing HMAC (mirrors what validate does). + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + + // digest = HMAC such that XOR with stored digest yields [0..0, timestamp_le]. 
+ // bytes 0-27 of digest == computed[0..28] -> xored[..28] == 0 + // bytes 28-31 of digest == computed[28..32] XOR timestamp_le + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + make_valid_tls_handshake_with_session_id(secret, timestamp, &[0x42; 32]) +} + +// ------------------------------------------------------------------ +// Happy-path sanity +// ------------------------------------------------------------------ + +#[test] +fn valid_handshake_with_correct_secret_accepted() { + let secret = b"correct_horse_battery_staple_32b"; + // timestamp = 0 triggers is_boot_time path, accepted without wall-clock check. + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("alice".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "Valid handshake must be accepted"); + assert_eq!(result.unwrap().user, "alice"); +} + +#[test] +fn deterministic_external_vector_validates_without_helper() { + // Deterministic vector generated by an external Python stdlib HMAC script, + // not by this test module helper. This catches mirrored helper mistakes. 
+ let secret = hex::decode("00112233445566778899aabbccddeeff").unwrap(); + let handshake = hex::decode( + "4242424242424242424242a93225d1d6b46260bc9ce0cc48c7487d2b1ca5afa7ae9fc6609d9e60a3ca842b204242424242424242424242424242424242424242424242424242424242424242", + ) + .unwrap(); + + let secrets = vec![("vector_user".to_string(), secret)]; + let result = validate_tls_handshake(&handshake, &secrets, true).unwrap(); + + assert_eq!(result.user, "vector_user"); + assert_eq!(result.timestamp, 0x01020304); +} + +#[test] +fn valid_handshake_timestamp_extracted_correctly() { + let secret = b"ts_extraction_test"; + let ts: u32 = 0xDEAD_BEEF; + let handshake = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some()); + assert_eq!(result.unwrap().timestamp, ts); +} + +// ------------------------------------------------------------------ +// HMAC bit-flip rejection - adversarial HMAC forgery attempts +// ------------------------------------------------------------------ + +/// Flip every single bit across the 28-byte HMAC check window one at a +/// time. Each flip must cause rejection. This is the primary guard +/// against a censor gradually narrowing down a valid HMAC via partial +/// matches (which would be exploitable with a variable-time comparison). +#[test] +fn hmac_single_bit_flip_anywhere_in_check_window_rejected() { + let secret = b"flip_test_secret"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // First ensure the unmodified handshake is accepted. 
+ assert!( + validate_tls_handshake(&base, &secrets, true).is_some(), + "Baseline handshake must be accepted before flip tests" + ); + + for byte_pos in 0..28usize { + for bit in 0u8..8 { + let mut h = base.clone(); + h[TLS_DIGEST_POS + byte_pos] ^= 1 << bit; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip of bit {bit} in HMAC byte {byte_pos} must be rejected" + ); + } + } +} + +/// XOR entire check window (bytes 0-27) with 0xFF - must still fail. +#[test] +fn hmac_full_window_corruption_rejected() { + let secret = b"full_window_test"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + for i in 0..28 { + h[TLS_DIGEST_POS + i] ^= 0xFF; + } + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +/// Byte 27 is the last byte in the checked window. A non-constant-time +/// `all(|b| b == 0)` that short-circuits on byte 0 would never even reach +/// byte 27, making this an effective "did the fix actually run to the end" +/// sentinel: if this passes but the earlier byte-0 test fails, the check +/// window is not being evaluated end-to-end. +#[test] +fn hmac_last_byte_of_check_window_enforced() { + let secret = b"last_byte_sentinel"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + // Corrupt only byte 27. + h[TLS_DIGEST_POS + 27] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Corruption at byte 27 (end of HMAC window) must cause rejection" + ); +} + +// ------------------------------------------------------------------ +// User enumeration / multi-user ordering +// ------------------------------------------------------------------ + +#[test] +fn wrong_user_secret_rejected_even_with_valid_structure() { + let secret_a = b"secret_alpha"; + let secret_b = b"secret_beta"; + let handshake = make_valid_tls_handshake(secret_b, 0); + // Only user_a is configured. 
+ let secrets = vec![("user_a".to_string(), secret_a.to_vec())]; + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "Handshake for user_b must fail when only user_a is configured" + ); +} + +#[test] +fn second_user_in_list_found_when_first_does_not_match() { + let secret_a = b"secret_alpha"; + let secret_b = b"secret_beta"; + let handshake = make_valid_tls_handshake(secret_b, 0); + let secrets = vec![ + ("user_a".to_string(), secret_a.to_vec()), + ("user_b".to_string(), secret_b.to_vec()), + ]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!( + result.is_some(), + "user_b must be found even though user_a comes first" + ); + assert_eq!(result.unwrap().user, "user_b"); +} + +#[test] +fn duplicate_secret_keeps_first_user_identity() { + // If multiple entries share the same secret, the selected identity must + // stay stable and deterministic (first entry wins). + let shared = b"same_secret_for_two_users"; + let handshake = make_valid_tls_handshake(shared, 0); + let secrets = vec![ + ("first_user".to_string(), shared.to_vec()), + ("second_user".to_string(), shared.to_vec()), + ]; + + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some()); + assert_eq!(result.unwrap().user, "first_user"); +} + +#[test] +fn no_user_matches_returns_none() { + let secret_a = b"aaa"; + let secret_b = b"bbb"; + let secret_c = b"ccc"; + let handshake = make_valid_tls_handshake(b"unknown_secret", 0); + let secrets = vec![ + ("a".to_string(), secret_a.to_vec()), + ("b".to_string(), secret_b.to_vec()), + ("c".to_string(), secret_c.to_vec()), + ]; + assert!(validate_tls_handshake(&handshake, &secrets, true).is_none()); +} + +#[test] +fn empty_secrets_list_rejects_everything() { + let secret = b"test"; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets: Vec<(String, Vec)> = Vec::new(); + assert!(validate_tls_handshake(&handshake, &secrets, true).is_none()); +} + +// 
------------------------------------------------------------------ +// Timestamp / time-skew boundary attacks +// ------------------------------------------------------------------ + +#[test] +fn timestamp_at_time_skew_boundaries_accepted() { + let secret = b"skew_boundary_test_secret"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // time_diff = now - ts = TIME_SKEW_MIN = -1200 + // -> ts = now - TIME_SKEW_MIN = now + 1200 (20 min in the future). + let ts_at_future_limit = (now - TIME_SKEW_MIN) as u32; + let h = make_valid_tls_handshake(secret, ts_at_future_limit); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_some(), + "Timestamp at max-allowed future (time_diff = TIME_SKEW_MIN) must be accepted" + ); + + // time_diff = now - ts = TIME_SKEW_MAX = 600 + // -> ts = now - TIME_SKEW_MAX = now - 600 (10 min in the past). + let ts_at_past_limit = (now - TIME_SKEW_MAX) as u32; + let h = make_valid_tls_handshake(secret, ts_at_past_limit); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_some(), + "Timestamp at max-allowed past (time_diff = TIME_SKEW_MAX) must be accepted" + ); +} + +#[test] +fn timestamp_one_second_outside_skew_window_rejected() { + let secret = b"skew_outside_test_secret"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // time_diff = TIME_SKEW_MAX + 1 = 601 (one second too far in the past) + // -> ts = now - (TIME_SKEW_MAX + 1) = now - 601 + let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32; + let h = make_valid_tls_handshake(secret, ts_too_past); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "Timestamp one second too far in the past must be rejected" + ); + + // time_diff = TIME_SKEW_MIN - 1 = -1201 (one second too far in the future) + // -> ts = now - (TIME_SKEW_MIN - 1) = now + 1201 + let ts_too_future = (now - TIME_SKEW_MIN + 1) as u32; + let h = 
make_valid_tls_handshake(secret, ts_too_future); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "Timestamp one second too far in the future must be rejected" + ); +} + +#[test] +fn ignore_time_skew_accepts_far_future_timestamp() { + let secret = b"ignore_skew_test"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // 1 hour in the future - outside TIME_SKEW_MAX but should pass with flag. + let future_ts = (now + 3600) as u32; + let h = make_valid_tls_handshake(secret, future_ts); + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, now).is_some(), + "ignore_time_skew=true must override window rejection" + ); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "ignore_time_skew=false must still reject far-future timestamp" + ); +} + +#[test] +fn boot_time_timestamp_accepted_without_ignore_flag() { + // Timestamps below the boot-time threshold are treated as client uptime, + // not real wall-clock time. The proxy allows them regardless of skew. + let secret = b"boot_time_test"; + // Keep this safely below compatibility cap to assert bypass behavior. + let boot_ts: u32 = BOOT_TIME_COMPAT_MAX_SECS.saturating_sub(1); + let handshake = make_valid_tls_handshake(secret, boot_ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + assert!( + validate_tls_handshake(&handshake, &secrets, false).is_some(), + "Boot-time timestamp must be accepted even with ignore_time_skew=false" + ); +} + +// ------------------------------------------------------------------ +// Structural / length boundary attacks +// ------------------------------------------------------------------ + +#[test] +fn too_short_handshake_rejected_without_panic() { + let secrets = vec![("u".to_string(), b"s".to_vec())]; + // Exactly one byte short of the minimum required length. 
+ let h = vec![0u8; TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); + + // Empty buffer. + assert!(validate_tls_handshake(&[], &secrets, true).is_none()); +} + +#[test] +fn all_prefix_lengths_below_minimum_rejected_without_panic() { + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + + for len in 0..min_len { + let h = vec![0u8; len]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "prefix length {len} below minimum must be rejected" + ); + } +} + +#[test] +fn claimed_session_id_overflows_buffer_rejected() { + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0u8; min_len]; + // Claim session_id is 33 bytes - one more than the buffer holds. + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = (session_id_len + 1) as u8; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +#[test] +fn max_session_id_len_255_does_not_panic() { + // session_id_len = 255 with a buffer that is far too small for it. 
+ let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + 32; + let mut h = vec![0u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 255; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +#[test] +fn one_byte_session_id_validates_and_is_preserved() { + let secret = b"sid_len_1_test"; + let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &[0xAB]); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&handshake, &secrets, true) + .expect("one-byte session_id handshake must validate"); + assert_eq!(result.session_id, vec![0xAB]); +} + +#[test] +fn max_session_id_len_255_with_valid_digest_is_rejected_by_rfc_cap() { + let secret = b"sid_len_255_test"; + let session_id = vec![0xCCu8; 255]; + let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &session_id); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "legacy_session_id length > 32 must be rejected even with valid digest" + ); +} + +// ------------------------------------------------------------------ +// Adversarial digest values +// ------------------------------------------------------------------ + +#[test] +fn all_zeros_digest_rejected() { + // An all-zeros digest would only pass if HMAC(secret, msg) happens to + // have its first 28 bytes all zero, which is computationally infeasible. 
+ let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + let secrets = vec![("u".to_string(), b"test_secret".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +#[test] +fn all_ones_digest_rejected() { + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0xFF); + let secrets = vec![("u".to_string(), b"test_secret".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +/// Simulate a censor that sends 200 crafted packets with random digests. +/// Every single one must be rejected; no random digest should accidentally +/// pass (probability 2^{-224} per attempt; negligible for 200 trials). +#[test] +fn censor_probe_random_digests_all_rejected() { + use crate::crypto::SecureRandom; + let secret = b"production_like_secret_value_xyz"; + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let rng = SecureRandom::new(); + + for attempt in 0..200 { + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let rand_digest = rng.bytes(TLS_DIGEST_LEN); + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&rand_digest); + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Random digest at attempt {attempt} must not match" + ); + } +} + +/// The check window is bytes 0-27 of the XOR result. Bytes 28-31 encode +/// the timestamp and must NOT affect whether the HMAC portion validates - +/// only the timestamp range check uses them. 
Build a valid handshake with
+/// timestamp = 0 (boot-time), flip each of bytes 28-31 with ignore_time_skew
+/// enabled, and verify the HMAC portion still passes (the timestamp changes
+/// but the proxy still accepts the connection under ignore_time_skew).
+#[test]
+fn timestamp_bytes_28_31_do_not_affect_hmac_window() {
+    let secret = b"window_boundary_test";
+    let base = make_valid_tls_handshake(secret, 0);
+    let secrets = vec![("u".to_string(), secret.to_vec())];
+
+    // Baseline must pass.
+    assert!(validate_tls_handshake(&base, &secrets, true).is_some());
+
+    // Flip each of the timestamp bytes; with ignore_time_skew the
+    // modified timestamps (small absolute values) still pass boot-time check.
+    for i in 28..32usize {
+        let mut h = base.clone();
+        h[TLS_DIGEST_POS + i] ^= 0xFF;
+        // The new timestamp is non-zero but potentially still < boot threshold;
+        // use ignore_time_skew=true so the test is HMAC-only.
+        assert!(
+            validate_tls_handshake(&h, &secrets, true).is_some(),
+            "Flipping byte {i} (timestamp region) must not invalidate HMAC window"
+        );
+    }
+}
+
+// ------------------------------------------------------------------
+// session_id preservation
+// ------------------------------------------------------------------
+
+#[test]
+fn session_id_is_preserved_verbatim_in_validation_result() {
+    // If session_id extraction is ever broken (wrong offset, wrong length,
+    // off-by-one), this test will catch it before it silently corrupts the
+    // ServerHello that echoes the session_id back to the client. 
+ let secret = b"session_id_preservation_test"; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true).unwrap(); + + let sid_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; + let sid_len = handshake[sid_len_pos] as usize; + let expected = &handshake[sid_len_pos + 1..sid_len_pos + 1 + sid_len]; + + assert_eq!( + result.session_id, expected, + "session_id in TlsValidation must be the verbatim bytes from the handshake" + ); +} + +// ------------------------------------------------------------------ +// Clock decoupling - ignore_time_skew must not consult the system clock +// ------------------------------------------------------------------ + +/// When `ignore_time_skew = true`, a valid HMAC must be accepted even if +/// `now = 0` (the sentinel used when the clock is not needed). A broken +/// system clock cannot silently deny service when the admin has explicitly +/// disabled timestamp checking. +#[test] +fn ignore_time_skew_accepts_valid_hmac_with_now_zero() { + let secret = b"clock_decoupling_test"; + // Use a realistic Unix timestamp that would be far outside the window + // if compared against now=0 (time_diff would be ~-1_700_000_000). + let realistic_ts: u32 = 1_700_000_000; + let h = make_valid_tls_handshake(secret, realistic_ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, 0).is_some(), + "ignore_time_skew=true must accept a valid HMAC regardless of `now`" + ); + + // Confirm that the same handshake IS rejected when the window is enforced + // and now=0 (time_diff very negative -> outside window). This distinguishes + // "clock decoupling" from "always accept". 
+ assert!( + validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(), + "ignore_time_skew=false with now=0 must still reject out-of-window timestamps" + ); +} + +/// An HMAC-invalid handshake must be rejected even when ignore_time_skew=true +/// and now=0. Verifies that the clock-decoupling fix did not weaken HMAC +/// enforcement in the ignore_time_skew path. +#[test] +fn ignore_time_skew_with_now_zero_still_rejects_bad_hmac() { + let secret = b"clock_no_backdoor_test"; + let mut h = make_valid_tls_handshake(secret, 1_700_000_000); + let secrets = vec![("u".to_string(), secret.to_vec())]; + // Corrupt the HMAC check window. + h[TLS_DIGEST_POS] ^= 0xFF; + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, 0).is_none(), + "Broken HMAC must be rejected even with ignore_time_skew=true and now=0" + ); +} + +#[test] +fn system_time_before_unix_epoch_is_rejected_without_panic() { + let before_epoch = UNIX_EPOCH + .checked_sub(std::time::Duration::from_secs(1)) + .expect("UNIX_EPOCH minus one second must be representable"); + assert!(system_time_to_unix_secs(before_epoch).is_none()); +} + +/// `i64::MAX` is 9_223_372_036_854_775_807 seconds (~292 billion years CE). +/// Any `SystemTime` whose duration since epoch exceeds `i64::MAX` seconds +/// must return `None` rather than silently wrapping to a large negative +/// timestamp that would corrupt every subsequent time-skew comparison. +#[test] +fn system_time_far_future_overflowing_i64_returns_none() { + // i64::MAX + 1 seconds past epoch overflows i64 when cast naively with `as`. 
+ let overflow_secs = u64::try_from(i64::MAX).unwrap() + 1; + if let Some(far_future) = UNIX_EPOCH.checked_add(std::time::Duration::from_secs(overflow_secs)) + { + assert!( + system_time_to_unix_secs(far_future).is_none(), + "Seconds > i64::MAX must return None, not a wrapped negative timestamp" + ); + } + // If the platform cannot represent this SystemTime, the test is vacuously + // satisfied: `checked_add` returning None means the platform already rejects it. +} + +// ------------------------------------------------------------------ +// Message canonicalization — HMAC covers every byte of the handshake +// ------------------------------------------------------------------ + +/// Every byte before TLS_DIGEST_POS is part of the HMAC input (because msg +/// = full handshake with only the digest slot zeroed). An attacker cannot +/// replay a valid handshake with a modified ClientHello header while keeping +/// the stored digest; each such modification produces a different HMAC. +#[test] +fn pre_digest_bytes_are_hmac_covered() { + // TLS_DIGEST_POS = 11, so 11 bytes precede the digest. + let secret = b"pre_digest_coverage_test"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + for byte_pos in 0..TLS_DIGEST_POS { + let mut h = base.clone(); + h[byte_pos] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip in pre-digest byte {byte_pos} must cause HMAC check failure" + ); + } +} + +/// session_id bytes follow the digest in the buffer and are also part of the +/// HMAC input. Flipping any of them invalidates the stored digest, preventing +/// a censor from capturing a valid session_id and replaying it with a different +/// one while keeping the rest of the packet intact. 
+#[test] +fn session_id_bytes_are_hmac_covered() { + let secret = b"session_id_coverage_test"; + let base = make_valid_tls_handshake(secret, 0); // session_id_len = 32 + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + for byte_pos in sid_start..base.len() { + let mut h = base.clone(); + h[byte_pos] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip in session_id byte at offset {byte_pos} must cause HMAC check failure" + ); + } +} + +/// Appending even one byte to a valid handshake changes the HMAC input (msg +/// includes all bytes) and therefore invalidates the stored digest. This +/// prevents a length-extension-style modification of the payload. +#[test] +fn appended_trailing_byte_causes_rejection() { + let secret = b"trailing_byte_test"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "baseline" + ); + + h.push(0x00); + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Appending a trailing byte to a valid handshake must invalidate the HMAC" + ); +} + +// ------------------------------------------------------------------ +// Zero-length session_id (structural edge case) +// ------------------------------------------------------------------ + +/// session_id_len = 0 is legal in the TLS spec. The validator must accept a +/// valid handshake with an empty session_id and return an empty session_id +/// slice without panicking or accessing out-of-bounds memory. 
+#[test] +fn zero_length_session_id_accepted() { + let secret = b"zero_sid_test"; + // Buffer: pre-digest | digest | session_id_len=0 (no session_id bytes follow) + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + let mut handshake = vec![0x42u8; len]; + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 0; // session_id_len = 0 + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + // timestamp = 0 → ts XOR bytes are all zero → digest = computed unchanged. + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&computed); + + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "zero-length session_id must be accepted"); + assert!( + result.unwrap().session_id.is_empty(), + "session_id field must be empty when session_id_len = 0" + ); +} + +// ------------------------------------------------------------------ +// Boot-time threshold — exact boundary precision +// ------------------------------------------------------------------ + +/// timestamp = BOOT_TIME_COMPAT_MAX_SECS - 1 is the last value inside +/// the runtime boot-time compatibility window. +/// is_boot_time = true → skew check is skipped entirely → accepted even +/// when `now` is far from the timestamp. +#[test] +fn timestamp_one_below_boot_threshold_bypasses_skew_check() { + let secret = b"boot_last_value_test"; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS - 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // now = 0 → time_diff would be -86_399_999, way outside [-1200, 600]. + // Boot-time bypass must prevent the skew check from running. 
+    assert!(
+        validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(),
+        "ts=BOOT_TIME_COMPAT_MAX_SECS-1 must bypass skew check regardless of now"
+    );
+}
+
+/// timestamp = BOOT_TIME_COMPAT_MAX_SECS is the first value outside the
+/// runtime boot-time compatibility window.
+/// is_boot_time = false → skew check IS applied. Two sub-cases confirm this:
+/// once with now chosen so the skew passes (accepted) and once where it fails.
+#[test]
+fn timestamp_at_boot_threshold_triggers_skew_check() {
+    let secret = b"boot_exact_value_test";
+    let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS;
+    let h = make_valid_tls_handshake(secret, ts);
+    let secrets = vec![("u".to_string(), secret.to_vec())];
+
+    // now = ts + 50 → time_diff = 50, within [-1200, 600] → accepted.
+    let now_valid: i64 = ts as i64 + 50;
+    assert!(
+        validate_tls_handshake_at_time_with_boot_cap(
+            &h,
+            &secrets,
+            false,
+            now_valid,
+            BOOT_TIME_COMPAT_MAX_SECS,
+        )
+        .is_some(),
+        "ts=BOOT_TIME_COMPAT_MAX_SECS within skew window must be accepted via skew check"
+    );
+
+    // now = -1 → time_diff = -1 - ts, far below TIME_SKEW_MIN and outside the
+    // inclusive [-1200, 600] window. If boot-time bypass were wrongly applied this
+    // would pass. 
+ assert!( + validate_tls_handshake_at_time_with_boot_cap( + &h, + &secrets, + false, + -1, + BOOT_TIME_COMPAT_MAX_SECS, + ) + .is_none(), + "ts=BOOT_TIME_COMPAT_MAX_SECS far from now must be rejected — no boot-time bypass" + ); +} + +#[test] +fn replay_window_cap_disables_boot_bypass_for_old_timestamps() { + let secret = b"boot_cap_disabled_test"; + let ts: u32 = 900; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300); + assert!( + result.is_none(), + "timestamp above replay-window cap must not use boot-time bypass" + ); +} + +#[test] +fn replay_window_cap_still_allows_small_boot_timestamp() { + let secret = b"boot_cap_enabled_test"; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS.saturating_sub(1); + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300); + assert!( + result.is_some(), + "timestamp below replay-window cap must retain boot-time compatibility" + ); +} + +#[test] +fn large_replay_window_is_hard_capped_for_boot_compatibility() { + let secret = b"boot_cap_hard_limit_test"; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS + 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, u64::MAX); + assert!( + result.is_none(), + "very large replay window must not expand boot-time bypass beyond hard compatibility cap" + ); +} + +#[test] +fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() { + let secret = b"ignore_skew_boot_cap_decouple_test"; + let ts: u32 = 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0); + let 
cap_nonzero = validate_tls_handshake_at_time_with_boot_cap( + &h, + &secrets, + true, + 0, + BOOT_TIME_COMPAT_MAX_SECS, + ); + + assert!( + cap_zero.is_some(), + "ignore_time_skew=true must accept valid HMAC" + ); + assert!( + cap_nonzero.is_some(), + "ignore_time_skew path must not depend on boot-time cap" + ); + + let a = cap_zero.unwrap(); + let b = cap_nonzero.unwrap(); + assert_eq!(a.user, b.user); + assert_eq!(a.timestamp, b.timestamp); +} + +#[test] +fn adversarial_small_boot_timestamp_matrix_rejected_when_boot_cap_forced_zero() { + let secret = b"boot_cap_zero_matrix_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for ts in 0u32..1024u32 { + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "boot cap=0 must reject timestamp {ts} when skew checks are active" + ); + } +} + +#[test] +fn light_fuzz_boot_cap_zero_rejects_small_timestamp_space() { + let secret = b"boot_cap_zero_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0x9E37_79B9_7F4A_7C15; + + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = (s as u32) % 2048; + + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "fuzzed boot-range timestamp {ts} must be rejected when cap=0" + ); + } +} + +#[test] +fn stress_boot_cap_zero_rejection_is_deterministic_under_high_iteration_count() { + let secret = b"boot_cap_zero_stress_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for i in 0u32..20_000u32 { + let ts = i % 4096; + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + 
result.is_none(), + "iteration {i}: timestamp {ts} must be rejected with cap=0" + ); + } +} + +#[test] +fn replay_window_one_allows_only_zero_timestamp_boot_bypass() { + let secret = b"replay_window_one_boot_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts0 = make_valid_tls_handshake(secret, 0); + let ts1 = make_valid_tls_handshake(secret, 1); + + assert!( + validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 1).is_some(), + "replay_window=1 must allow timestamp 0 via boot-time compatibility" + ); + assert!( + validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 1).is_none(), + "replay_window=1 must reject timestamp 1 on normal wall-clock systems" + ); +} + +#[test] +fn replay_window_two_allows_ts0_ts1_but_rejects_ts2() { + let secret = b"replay_window_two_boot_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts0 = make_valid_tls_handshake(secret, 0); + let ts1 = make_valid_tls_handshake(secret, 1); + let ts2 = make_valid_tls_handshake(secret, 2); + + assert!(validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 2).is_some()); + assert!(validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 2).is_some()); + assert!( + validate_tls_handshake_with_replay_window(&ts2, &secrets, false, 2).is_none(), + "timestamp equal to replay-window cap must not use boot-time bypass" + ); +} + +#[test] +fn adversarial_skew_boundary_matrix_accepts_only_inclusive_window_when_boot_disabled() { + let secret = b"skew_boundary_matrix_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for offset in -1500i64..=1500i64 { + let ts_i64 = now - offset; + let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix"); + let h = make_valid_tls_handshake(secret, ts); + let accepted = + validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some(); + let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset); + 
assert_eq!( + accepted, expected, + "offset {offset} must match inclusive skew window when boot bypass is disabled" + ); + } +} + +#[test] +fn light_fuzz_skew_window_rejects_outside_range_when_boot_disabled() { + let secret = b"skew_outside_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0x0123_4567_89AB_CDEF; + + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let magnitude = 1300i64 + ((s % 2000u64) as i64); + let sign = if (s & 1) == 0 { 1i64 } else { -1i64 }; + let offset = sign * magnitude; + let ts_i64 = now - offset; + let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test"); + + let h = make_valid_tls_handshake(secret, ts); + let accepted = + validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some(); + assert!( + !accepted, + "offset {offset} must be rejected outside strict skew window" + ); + } +} + +#[test] +fn stress_boot_disabled_validation_matches_time_diff_oracle() { + let secret = b"boot_disabled_oracle_stress_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0xBADC_0FFE_EE11_2233; + + for _ in 0..25_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = s as u32; + let h = make_valid_tls_handshake(secret, ts); + + let accepted = + validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some(); + let time_diff = now - i64::from(ts); + let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff); + assert_eq!( + accepted, expected, + "boot-disabled validation must match pure time-diff oracle" + ); + } +} + +#[test] +fn integration_large_user_list_with_boot_disabled_finds_only_matching_user() { + let now: i64 = 1_700_000_000; + let target_secret = b"target_user_secret"; + let target_ts = (now - 1) as u32; + let handshake = make_valid_tls_handshake(target_secret, target_ts); + + let mut secrets = Vec::new(); + for i 
in 0..512u32 { + secrets.push(( + format!("noise-{i}"), + format!("noise-secret-{i}").into_bytes(), + )); + } + secrets.push(("target-user".to_string(), target_secret.to_vec())); + + let result = validate_tls_handshake_at_time_with_boot_cap(&handshake, &secrets, false, now, 0) + .expect("matching user should validate within strict skew window"); + assert_eq!(result.user, "target-user"); +} + +#[test] +fn light_fuzz_ignore_time_skew_accepts_wide_timestamp_range_with_valid_hmac() { + let secret = b"ignore_skew_fuzz_accept_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let mut s: u64 = 0xC0FF_EE11_2233_4455; + + for _ in 0..2048 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = s as u32; + + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_with_replay_window(&h, &secrets, true, 60); + assert!( + result.is_some(), + "ignore_time_skew=true must accept valid HMAC for arbitrary timestamp" + ); + } +} + +#[test] +fn light_fuzz_small_replay_window_rejects_far_timestamps_when_skew_enabled() { + let secret = b"replay_window_reject_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + for ts in 300u32..=1323u32 { + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, 0, 300); + assert!( + result.is_none(), + "with skew checks enabled and boot cap=300, timestamp >=300 at now=0 must be rejected" + ); + } +} + +// ------------------------------------------------------------------ +// Extreme timestamp values +// ------------------------------------------------------------------ + +/// u32::MAX is a valid timestamp value. When ignore_time_skew=true the HMAC +/// is the only gate, and a correctly constructed handshake must be accepted. 
+#[test] +fn u32_max_timestamp_accepted_with_ignore_time_skew() { + let secret = b"u32_max_ts_accept_test"; + let h = make_valid_tls_handshake(secret, u32::MAX); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&h, &secrets, true); + assert!( + result.is_some(), + "u32::MAX timestamp must be accepted with ignore_time_skew=true" + ); + assert_eq!( + result.unwrap().timestamp, + u32::MAX, + "timestamp field must equal u32::MAX verbatim" + ); +} + +/// u32::MAX > BOOT_TIME_MAX_SECS so the skew check runs. With any realistic `now` +/// (~1.7 billion), time_diff = now - u32::MAX is deeply negative — far outside +/// [-1200, 600] — so the handshake must be rejected without overflow. +#[test] +fn u32_max_timestamp_rejected_by_skew_enforcement() { + let secret = b"u32_max_ts_reject_test"; + let h = make_valid_tls_handshake(secret, u32::MAX); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let now: i64 = 1_700_000_000; + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "u32::MAX timestamp must be rejected by skew check with realistic now" + ); +} + +// ------------------------------------------------------------------ +// Validation result field correctness +// ------------------------------------------------------------------ + +/// result.digest must be the verbatim bytes stored in the handshake buffer, +/// not the freshly recomputed HMAC. Callers use this field directly when +/// constructing the ServerHello response digest. 
+#[test] +fn result_digest_field_is_verbatim_stored_digest() { + let secret = b"digest_field_verbatim_test"; + let ts: u32 = 0xCAFE_BABE; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&h, &secrets, true).unwrap(); + + let stored: [u8; TLS_DIGEST_LEN] = h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .try_into() + .unwrap(); + assert_eq!( + result.digest, stored, + "result.digest must equal the stored bytes, not the computed HMAC" + ); +} + +// ------------------------------------------------------------------ +// Secret length edge cases +// ------------------------------------------------------------------ + +/// HMAC-SHA256 pads or hashes keys of any length; a single-byte key must work. +#[test] +fn single_byte_secret_works() { + let secret = b"x"; + let h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "single-byte secret must produce a valid and verifiable HMAC" + ); +} + +/// Keys longer than the HMAC block size (64 bytes for SHA-256) are hashed +/// before use. A 256-byte key must work without truncation or panic. +#[test] +fn very_long_secret_256_bytes_works() { + let secret = vec![0xABu8; 256]; + let h = make_valid_tls_handshake(&secret, 0); + let secrets = vec![("u".to_string(), secret.clone())]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "256-byte secret must be accepted without truncation" + ); +} + +// ------------------------------------------------------------------ +// Determinism — same input must always produce same result +// ------------------------------------------------------------------ + +/// Calling validate twice on the same input must return identical results. +/// Non-determinism (e.g. 
from an accidentally global mutable state or a
/// shared nonce) would be a critical security defect in a proxy that rejects
/// censors by relying on stable authentication outcomes.
#[test]
fn validation_is_deterministic() {
    let secret = b"determinism_test_key";
    let h = make_valid_tls_handshake(secret, 42);
    let secrets = vec![("u".to_string(), secret.to_vec())];

    let r1 = validate_tls_handshake(&h, &secrets, true).unwrap();
    let r2 = validate_tls_handshake(&h, &secrets, true).unwrap();

    // Every public field of the validation result must be bitwise stable.
    assert_eq!(r1.user, r2.user);
    assert_eq!(r1.session_id, r2.session_id);
    assert_eq!(r1.digest, r2.digest);
    assert_eq!(r1.timestamp, r2.timestamp);
}

// ------------------------------------------------------------------
// Multi-user: scan-all correctness guarantees
// ------------------------------------------------------------------

/// The matching logic must scan through the entire secrets list. A user
/// at position 99 of 100 must be found; an implementation that stops early
/// on the first non-match would fail this test.
#[test]
fn last_user_in_large_list_is_found() {
    let target_secret = b"needle_in_haystack";
    let h = make_valid_tls_handshake(target_secret, 0);

    // 99 decoys followed by the real secret at the tail of the list.
    // NOTE: the element type was garbled to `Vec<(String, Vec)>` by markup
    // stripping; restored to `Vec<(String, Vec<u8>)>` so this compiles.
    let mut secrets: Vec<(String, Vec<u8>)> = (0..99)
        .map(|i| (format!("decoy_{i}"), format!("wrong_{i}").into_bytes()))
        .collect();
    secrets.push(("needle".to_string(), target_secret.to_vec()));

    let result = validate_tls_handshake(&h, &secrets, true);
    assert!(result.is_some(), "100th user must be found");
    assert_eq!(result.unwrap().user, "needle");
}

/// When multiple users share the same secret the first occurrence must always
/// win. The scan-all loop must not replace first_match with a later one.
+#[test] +fn first_matching_user_wins_over_later_duplicate_secret() { + let shared = b"duplicated_secret_key"; + let h = make_valid_tls_handshake(shared, 0); + + let secrets = vec![ + ("decoy_1".to_string(), b"wrong_1".to_vec()), + ("winner".to_string(), shared.to_vec()), // first match + ("decoy_2".to_string(), b"wrong_2".to_vec()), + ("loser".to_string(), shared.to_vec()), // second match — must not win + ("decoy_3".to_string(), b"wrong_3".to_vec()), + ]; + + let result = validate_tls_handshake(&h, &secrets, true); + assert!(result.is_some()); + assert_eq!( + result.unwrap().user, + "winner", + "first matching user must be returned even when a later entry also matches" + ); +} + +// ------------------------------------------------------------------ +// Legacy tls.rs tests moved here +// ------------------------------------------------------------------ + +#[test] +fn test_is_tls_handshake() { + assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x03])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x03, 0x02, 0x00])); + assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); + assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); + assert!(!is_tls_handshake(&[0x16, 0x03])); +} + +#[test] +fn test_parse_tls_record_header() { + let header = [0x16, 0x03, 0x01, 0x02, 0x00]; + let result = parse_tls_record_header(&header).unwrap(); + assert_eq!(result.0, TLS_RECORD_HANDSHAKE); + assert_eq!(result.1, 512); + + let header = [0x17, 0x03, 0x03, 0x40, 0x00]; + let result = parse_tls_record_header(&header).unwrap(); + assert_eq!(result.0, TLS_RECORD_APPLICATION); + assert_eq!(usize::from(result.1), MAX_TLS_PLAINTEXT_SIZE); +} + +#[test] +fn parse_tls_record_header_rejects_invalid_versions() { + let invalid = [ + [0x16, 0x03, 0x00, 0x00, 0x10], + [0x16, 0x02, 0x00, 0x00, 0x10], + [0x16, 0x03, 0x02, 0x00, 0x10], + [0x16, 0x04, 0x00, 0x00, 0x10], + ]; + for header in invalid { + assert!( + 
parse_tls_record_header(&header).is_none(), + "invalid TLS record version {:?} must be rejected", + [header[1], header[2]] + ); + } +} + +#[test] +fn test_gen_fake_x25519_key() { + let rng = crate::crypto::SecureRandom::new(); + let key1 = gen_fake_x25519_key(&rng); + let key2 = gen_fake_x25519_key(&rng); + + assert_eq!(key1.len(), 32); + assert_eq!(key2.len(), 32); + assert_ne!(key1, key2); +} + +#[test] +fn test_fake_x25519_key_is_nonzero_and_varies() { + let rng = crate::crypto::SecureRandom::new(); + let mut unique = std::collections::HashSet::new(); + let mut saw_non_zero = false; + + for _ in 0..64 { + let key = gen_fake_x25519_key(&rng); + if key != [0u8; 32] { + saw_non_zero = true; + } + unique.insert(key); + } + + assert!( + saw_non_zero, + "generated X25519 public keys must not collapse to all-zero output" + ); + assert!( + unique.len() > 1, + "generated X25519 public keys must vary across invocations" + ); +} + +#[test] +fn validate_tls_handshake_rejects_session_id_longer_than_rfc_cap() { + let secret = b"session_id_cap_secret"; + let oversized_sid = vec![0x42u8; 33]; + let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &oversized_sid); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "legacy_session_id length > 32 must be rejected" + ); +} + +fn server_hello_extension_types(record: &[u8]) -> Vec { + if record.len() < 9 || record[0] != TLS_RECORD_HANDSHAKE || record[5] != 0x02 { + return Vec::new(); + } + + let record_len = u16::from_be_bytes([record[3], record[4]]) as usize; + if record.len() < 5 + record_len { + return Vec::new(); + } + + let hs_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; + let hs_start = 5; + let hs_end = hs_start + 4 + hs_len; + if hs_end > record.len() { + return Vec::new(); + } + + let mut pos = hs_start + 4 + 2 + 32; + if pos >= hs_end { + return Vec::new(); + } + let sid_len = record[pos] as 
usize;
    pos += 1 + sid_len;
    if pos + 2 + 1 + 2 > hs_end {
        return Vec::new();
    }

    // Skip cipher_suite (2 bytes) + compression (1 byte), then read the
    // 16-bit extensions-block length.
    pos += 2 + 1;
    let ext_len = u16::from_be_bytes([record[pos], record[pos + 1]]) as usize;
    pos += 2;
    let ext_end = pos + ext_len;
    if ext_end > hs_end {
        return Vec::new();
    }

    // Each extension: 2-byte type + 2-byte length + payload; stop on overrun.
    let mut out = Vec::new();
    while pos + 4 <= ext_end {
        let etype = u16::from_be_bytes([record[pos], record[pos + 1]]);
        let elen = u16::from_be_bytes([record[pos + 2], record[pos + 3]]) as usize;
        pos += 4;
        if pos + elen > ext_end {
            break;
        }
        out.push(etype);
        pos += elen;
    }
    out
}

// ALPN (extension 0x0010) belongs in EncryptedExtensions in TLS 1.3; leaking
// it in the plaintext ServerHello would be a fingerprintable deviation.
#[test]
fn build_server_hello_never_places_alpn_in_server_hello_extensions() {
    let secret = b"alpn_sh_forbidden";
    let client_digest = [0x11u8; 32];
    let session_id = vec![0xAA; 32];
    let rng = crate::crypto::SecureRandom::new();

    let response = build_server_hello(
        secret,
        &client_digest,
        &session_id,
        1024,
        &rng,
        Some(b"h2".to_vec()),
        0,
    );
    let exts = server_hello_extension_types(&response);
    assert!(
        !exts.contains(&0x0010),
        "ALPN extension must not appear in ServerHello"
    );
}

// Same invariant for the emulated (profile-driven) ServerHello path.
#[test]
fn emulated_server_hello_never_places_alpn_in_server_hello_extensions() {
    let secret = b"alpn_emulated_forbidden";
    let client_digest = [0x22u8; 32];
    let session_id = vec![0xAB; 32];
    let rng = crate::crypto::SecureRandom::new();
    // Minimal cached upstream profile: one 1024-byte app-data record, default
    // behavior profile, no certificate payload.
    let cached = CachedTlsData {
        server_hello_template: ParsedServerHello {
            version: TLS_VERSION,
            random: [0u8; 32],
            session_id: Vec::new(),
            cipher_suite: [0x13, 0x01],
            compression: 0,
            extensions: Vec::new(),
        },
        cert_info: None,
        cert_payload: None,
        app_data_records_sizes: vec![1024],
        total_app_data_len: 1024,
        behavior_profile: TlsBehaviorProfile {
            change_cipher_spec_count: 1,
            app_data_record_sizes: vec![1024],
            ticket_record_sizes: Vec::new(),
            source: TlsProfileSource::Default,
        },
        fetched_at: SystemTime::now(),
        domain: "example.com".to_string(),
    };

    let response = build_emulated_server_hello(
        secret,
        &client_digest,
        &session_id,
        &cached,
        false,
        &rng,
        Some(b"h2".to_vec()),
        0,
    );
    let exts = server_hello_extension_types(&response);
    assert!(
        !exts.contains(&0x0010),
        "ALPN extension must not appear in emulated ServerHello"
    );
}

// The builder's leading 2-byte length prefix must account for every byte
// that follows it.
#[test]
fn test_tls_extension_builder() {
    let key = [0x42u8; 32];

    let mut builder = TlsExtensionBuilder::new();
    builder.add_key_share(&key);
    builder.add_supported_versions(0x0304);

    let result = builder.build();
    let len = u16::from_be_bytes([result[0], result[1]]) as usize;

    assert_eq!(len, result.len() - 2);
    assert!(result.len() > 40);
}

#[test]
fn test_server_hello_builder() {
    let session_id = vec![0x01, 0x02, 0x03, 0x04];
    let key = [0x55u8; 32];

    let builder = ServerHelloBuilder::new(session_id.clone())
        .with_x25519_key(&key)
        .with_tls13_version();

    let record = builder.build_record();
    validate_server_hello_structure(&record).expect("Invalid ServerHello structure");

    // record[5] == 0x02 is the ServerHello handshake-message type.
    assert_eq!(record[0], TLS_RECORD_HANDSHAKE);
    assert_eq!(&record[1..3], &TLS_VERSION);
    assert_eq!(record[5], 0x02);
}

// Full response layout: ServerHello record, then ChangeCipherSpec, then an
// ApplicationData record.
#[test]
fn test_build_server_hello_structure() {
    let secret = b"test secret";
    let client_digest = [0x42u8; 32];
    let session_id = vec![0xAA; 32];

    let rng = crate::crypto::SecureRandom::new();
    let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng, None, 0);

    assert!(response.len() > 100);
    assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
    validate_server_hello_structure(&response).expect("Invalid ServerHello");

    let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_start = server_hello_len;
    assert!(response.len() > ccs_start + 6);
    assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);

    let ccs_len =
        5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
    let app_start = ccs_start + ccs_len;
    assert!(response.len() > app_start + 5);
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
}

/// The embedded digest must be non-zero and must differ between two builds
/// (it mixes in fresh server randomness each time).
#[test]
fn test_build_server_hello_digest() {
    let secret = b"test secret key here";
    let client_digest = [0x42u8; 32];
    let session_id = vec![0xAA; 32];

    let rng = crate::crypto::SecureRandom::new();
    let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0);
    let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0);

    let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
    assert!(!digest1.iter().all(|&b| b == 0));

    let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
    assert_ne!(digest1, digest2);
}

/// Declared extensions length must exactly cover the extension bytes up to
/// the end of the handshake message.
#[test]
fn test_server_hello_extensions_length() {
    let session_id = vec![0x01; 32];
    let key = [0x55u8; 32];

    let builder = ServerHelloBuilder::new(session_id)
        .with_x25519_key(&key)
        .with_tls13_version();

    let record = builder.build_record();
    let msg_start = 5;
    let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize;
    let session_id_pos = msg_start + 4 + 2 + 32;
    let session_id_len = record[session_id_pos] as usize;
    let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1;
    let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize;
    let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len];

    assert_eq!(
        ext_len,
        extensions_data.len(),
        "Extension length mismatch: declared {}, actual {}",
        ext_len,
        extensions_data.len()
    );
}

/// A buffer with a plausible digest/session-id layout but no valid HMAC must
/// be rejected.
#[test]
fn test_validate_tls_handshake_format() {
    let mut handshake = vec![0u8; 100];
    handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&[0x42; 32]);
    handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32;

    let secrets = vec![("test".to_string(), b"secret".to_vec())];
    let result = validate_tls_handshake(&handshake, &secrets, true);
    assert!(result.is_none());
}

/// Builds a minimal ClientHello record carrying the given extensions followed
/// by an SNI extension for `host`.
/// NOTE: both generic parameters were garbled by markup stripping
/// (`Vec<(u16, Vec)>` / `-> Vec`); restored to `Vec<(u16, Vec<u8>)>` and
/// `-> Vec<u8>` so the helper compiles.
fn build_client_hello_with_exts(exts: Vec<(u16, Vec<u8>)>, host: &str) -> Vec<u8> {
    // legacy_version + random + empty session id + one cipher suite +
    // one (null) compression method.
    let mut body = Vec::new();
    body.extend_from_slice(&TLS_VERSION);
    body.extend_from_slice(&[0u8; 32]);
    body.push(0);
    body.extend_from_slice(&2u16.to_be_bytes());
    body.extend_from_slice(&[0x13, 0x01]);
    body.push(1);
    body.push(0);

    // server_name extension payload: list length, name_type=0 (host_name),
    // name length, name bytes.
    let host_bytes = host.as_bytes();
    let mut sni_ext = Vec::new();
    sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes());
    sni_ext.push(0);
    sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
    sni_ext.extend_from_slice(host_bytes);

    let mut ext_blob = Vec::new();
    for (typ, data) in exts {
        ext_blob.extend_from_slice(&typ.to_be_bytes());
        ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes());
        ext_blob.extend_from_slice(&data);
    }
    ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
    ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes());
    ext_blob.extend_from_slice(&sni_ext);

    body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
    body.extend_from_slice(&ext_blob);

    // Handshake header: msg type 0x01 (ClientHello) + 24-bit length.
    let mut handshake = Vec::new();
    handshake.push(0x01);
    let len_bytes = (body.len() as u32).to_be_bytes();
    handshake.extend_from_slice(&len_bytes[1..4]);
    handshake.extend_from_slice(&body);

    let mut record = Vec::new();
    record.push(TLS_RECORD_HANDSHAKE);
    record.extend_from_slice(&[0x03, 0x01]);
    record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
    record.extend_from_slice(&handshake);
    record
}

/// Like `build_client_hello_with_exts` but splices `ext_blob` in verbatim,
/// so tests can craft malformed extension blocks.
/// NOTE: return type restored to `Vec<u8>` (garbled to `Vec` by markup stripping).
fn build_client_hello_with_raw_extensions(ext_blob: &[u8]) -> Vec<u8> {
    let mut body = Vec::new();
    body.extend_from_slice(&TLS_VERSION);
    body.extend_from_slice(&[0u8; 32]);
    body.push(0);
    body.extend_from_slice(&2u16.to_be_bytes());
    body.extend_from_slice(&[0x13, 0x01]);
    body.push(1);
    body.push(0);
    body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
    body.extend_from_slice(ext_blob);

    let mut handshake = Vec::new();
    handshake.push(0x01);
    let len_bytes = (body.len() as
u32).to_be_bytes();
    handshake.extend_from_slice(&len_bytes[1..4]);
    handshake.extend_from_slice(&body);

    let mut record = Vec::new();
    record.push(TLS_RECORD_HANDSHAKE);
    record.extend_from_slice(&[0x03, 0x01]);
    record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
    record.extend_from_slice(&handshake);
    record
}

// GREASE extensions (e.g. 0x0a0a) must be skipped, not treated as malformed.
#[test]
fn test_extract_sni_with_grease_extension() {
    let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com");
    let sni = extract_sni_from_client_hello(&ch);
    assert_eq!(sni.as_deref(), Some("example.com"));
}

#[test]
fn test_extract_sni_tolerates_empty_unknown_extension() {
    let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local");
    let sni = extract_sni_from_client_hello(&ch);
    assert_eq!(sni.as_deref(), Some("test.local"));
}

#[test]
fn test_extract_alpn_single() {
    // ALPN payload: 2-byte list length, then (1-byte len, protocol) entries.
    let mut alpn_data = Vec::new();
    alpn_data.extend_from_slice(&3u16.to_be_bytes());
    alpn_data.push(2);
    alpn_data.extend_from_slice(b"h2");
    let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test");
    let alpn = extract_alpn_from_client_hello(&ch);
    // NOTE: annotation was garbled to `Vec` by markup stripping; restored to
    // `Vec<String>`.
    let alpn_str: Vec<String> = alpn
        .iter()
        .map(|p| std::str::from_utf8(p).unwrap().to_string())
        .collect();
    assert_eq!(alpn_str, vec!["h2"]);
}

#[test]
fn test_extract_alpn_multiple() {
    let mut alpn_data = Vec::new();
    alpn_data.extend_from_slice(&11u16.to_be_bytes());
    alpn_data.push(2);
    alpn_data.extend_from_slice(b"h2");
    alpn_data.push(4);
    alpn_data.extend_from_slice(b"spdy");
    alpn_data.push(2);
    alpn_data.extend_from_slice(b"h3");
    let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test");
    let alpn = extract_alpn_from_client_hello(&ch);
    // NOTE: restored `Vec<String>` annotation (garbled by markup stripping).
    let alpn_str: Vec<String> = alpn
        .iter()
        .map(|p| std::str::from_utf8(p).unwrap().to_string())
        .collect();
    assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]);
}

#[test]
fn extract_sni_rejects_zero_length_host_name() {
    // server_name list with name_type=0 and a zero-length host_name.
    let mut sni_ext = Vec::new();
    sni_ext.extend_from_slice(&3u16.to_be_bytes());
    sni_ext.push(0);
    sni_ext.extend_from_slice(&0u16.to_be_bytes());

    let mut ext_blob = Vec::new();
    ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
    ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes());
    ext_blob.extend_from_slice(&sni_ext);

    let ch = build_client_hello_with_raw_extensions(&ext_blob);
    assert!(extract_sni_from_client_hello(&ch).is_none());
}

#[test]
fn extract_sni_rejects_raw_ipv4_literals() {
    let ch = build_client_hello_with_exts(Vec::new(), "203.0.113.10");
    assert!(extract_sni_from_client_hello(&ch).is_none());
}

#[test]
fn extract_sni_rejects_invalid_label_characters() {
    let ch = build_client_hello_with_exts(Vec::new(), "exa_mple.com");
    assert!(extract_sni_from_client_hello(&ch).is_none());
}

#[test]
fn extract_sni_rejects_oversized_label() {
    // DNS labels are capped at 63 octets; 64 must be rejected.
    let oversized = format!("{}.example.com", "a".repeat(64));
    let ch = build_client_hello_with_exts(Vec::new(), &oversized);
    assert!(extract_sni_from_client_hello(&ch).is_none());
}

#[test]
fn extract_sni_rejects_when_extension_block_is_truncated() {
    let mut ext_blob = Vec::new();
    ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
    ext_blob.extend_from_slice(&5u16.to_be_bytes());
    ext_blob.extend_from_slice(&[0, 3, 0]);

    let mut ch = build_client_hello_with_raw_extensions(&ext_blob);
    ch.pop();
    assert!(extract_sni_from_client_hello(&ch).is_none());
}

#[test]
fn extract_sni_rejects_session_id_len_overflow() {
    let mut ch = build_client_hello_with_exts(Vec::new(), "example.com");
    // Offset of the legacy_session_id length byte inside the record.
    let sid_len_pos = 5 + 4 + 2 + 32;
    ch[sid_len_pos] = 255;
    assert!(extract_sni_from_client_hello(&ch).is_none());
}

#[test]
fn extract_sni_rejects_cipher_suites_len_overflow() {
    let mut ch = build_client_hello_with_exts(Vec::new(), "example.com");
    let sid_len_pos = 5 + 4 + 2 + 32;
    let cipher_len_pos = sid_len_pos + 1 + ch[sid_len_pos] as usize;
    ch[cipher_len_pos] = 0xFF;
    ch[cipher_len_pos + 1] = 0xFF;
    assert!(extract_sni_from_client_hello(&ch).is_none());
}

#[test]
fn extract_sni_rejects_compression_methods_len_overflow() {
    let mut ch = build_client_hello_with_exts(Vec::new(), "example.com");
    let sid_len_pos = 5 + 4 + 2 + 32;
    let cipher_len_pos = sid_len_pos + 1 + ch[sid_len_pos] as usize;
    let cipher_len = u16::from_be_bytes([ch[cipher_len_pos], ch[cipher_len_pos + 1]]) as usize;
    let comp_len_pos = cipher_len_pos + 2 + cipher_len;
    ch[comp_len_pos] = 0xFF;
    assert!(extract_sni_from_client_hello(&ch).is_none());
}

#[test]
fn extract_alpn_returns_empty_on_session_id_len_overflow() {
    let mut alpn_data = Vec::new();
    alpn_data.extend_from_slice(&3u16.to_be_bytes());
    alpn_data.push(2);
    alpn_data.extend_from_slice(b"h2");
    let mut ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test");
    let sid_len_pos = 5 + 4 + 2 + 32;
    ch[sid_len_pos] = 255;
    assert!(extract_alpn_from_client_hello(&ch).is_empty());
}

#[test]
fn extract_alpn_rejects_when_extension_block_is_truncated() {
    let mut ext_blob = Vec::new();
    ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
    ext_blob.extend_from_slice(&5u16.to_be_bytes());
    ext_blob.extend_from_slice(&[0, 3, 2, b'h']);

    let ch = build_client_hello_with_raw_extensions(&ext_blob);
    assert!(extract_alpn_from_client_hello(&ch).is_empty());
}

#[test]
fn extract_alpn_rejects_nested_length_overflow() {
    // Inner protocol length (8) exceeds the bytes actually present.
    let mut alpn_data = Vec::new();
    alpn_data.extend_from_slice(&10u16.to_be_bytes());
    alpn_data.push(8);
    alpn_data.extend_from_slice(b"h2");

    let mut ext_blob = Vec::new();
    ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
    ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
    ext_blob.extend_from_slice(&alpn_data);

    let ch = build_client_hello_with_raw_extensions(&ext_blob);
    assert!(extract_alpn_from_client_hello(&ch).is_empty());
}

// ------------------------------------------------------------------
+// Additional adversarial checks +// ------------------------------------------------------------------ + +#[test] +fn empty_secret_hmac_is_supported() { + let secret: &[u8] = b""; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("empty".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!( + result.is_some(), + "Empty HMAC key must not panic and must validate when correct" + ); +} + +#[test] +fn server_hello_digest_verifies_against_full_response() { + let secret = b"fronting_digest_verify_key"; + let client_digest = [0x42u8; TLS_DIGEST_LEN]; + let session_id = vec![0xAA; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 1); + let mut zeroed = response.clone(); + zeroed[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + zeroed.len()); + hmac_input.extend_from_slice(&client_digest); + hmac_input.extend_from_slice(&zeroed); + let expected = sha256_hmac(secret, &hmac_input); + + assert_eq!( + &response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN], + &expected, + "ServerHello digest must be verifiable by a client that recomputes HMAC over full response" + ); +} + +#[test] +fn server_hello_digest_fails_after_single_byte_tamper() { + let secret = b"fronting_tamper_detect_key"; + let client_digest = [0x24u8; TLS_DIGEST_LEN]; + let session_id = vec![0xBB; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let mut response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + response[TLS_DIGEST_POS + TLS_DIGEST_LEN + 1] ^= 0x01; + + let mut zeroed = response.clone(); + zeroed[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + zeroed.len()); + hmac_input.extend_from_slice(&client_digest); + 
hmac_input.extend_from_slice(&zeroed); + let expected = sha256_hmac(secret, &hmac_input); + + assert_ne!( + &response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN], + &expected, + "Tampering any response byte must invalidate the embedded digest" + ); +} + +#[test] +fn server_hello_application_data_payload_varies_across_runs() { + use std::collections::HashSet; + + let secret = b"fronting_payload_variability_key"; + let client_digest = [0x13u8; TLS_DIGEST_LEN]; + let session_id = vec![0x44; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let mut unique_payloads: HashSet> = HashSet::new(); + for _ in 0..16 { + let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + + assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + let payload = response[app_pos + 5..app_pos + 5 + app_len].to_vec(); + + assert!( + payload.iter().any(|&b| b != 0), + "Payload must not be all-zero deterministic filler" + ); + unique_payloads.insert(payload); + } + + assert!( + unique_payloads.len() >= 4, + "ApplicationData payload should vary across runs to reduce fingerprintability" + ); +} + +#[test] +fn replay_window_zero_disables_boot_bypass_for_any_nonzero_timestamp() { + let secret = b"window_zero_boot_bypass_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts1 = make_valid_tls_handshake(secret, 1); + assert!( + validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 0).is_none(), + "replay_window_secs=0 must reject nonzero timestamps even in boot-time range" + ); + + let ts0 = make_valid_tls_handshake(secret, 0); + assert!( + validate_tls_handshake_with_replay_window(&ts0, &secrets, 
false, 0).is_none(),
        "replay_window_secs=0 enforces strict skew check and rejects timestamp=0 on normal wall-clock systems"
    );
}

// A generous replay window must not be usable to widen the ±skew bounds.
#[test]
fn large_replay_window_does_not_expand_time_skew_acceptance() {
    let secret = b"large_replay_window_skew_bound_test";
    let secrets = vec![("u".to_string(), secret.to_vec())];
    let now: i64 = 1_700_000_000;

    let ts_far_past = (now - 600) as u32;
    let valid = make_valid_tls_handshake(secret, ts_far_past);
    assert!(
        validate_tls_handshake_with_replay_window(&valid, &secrets, false, 86_400).is_none(),
        "large replay window must not relax strict skew check once boot-time bypass is not in play"
    );
}

#[test]
fn parse_tls_record_header_accepts_tls_version_constant() {
    let header = [
        TLS_RECORD_HANDSHAKE,
        TLS_VERSION[0],
        TLS_VERSION[1],
        0x00,
        0x2A,
    ];
    let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted");
    assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE);
    assert_eq!(parsed.1, 42);
}

// Requested fake-cert length of 1 byte must be raised to the 64-byte floor.
#[test]
fn server_hello_clamps_fake_cert_len_lower_bound() {
    let secret = b"fake_cert_lower_bound_test";
    let client_digest = [0x11u8; TLS_DIGEST_LEN];
    let session_id = vec![0x77; 32];
    let rng = crate::crypto::SecureRandom::new();

    let response = build_server_hello(secret, &client_digest, &session_id, 1, &rng, None, 0);

    // Walk ServerHello -> ChangeCipherSpec -> first ApplicationData record.
    let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_pos = 5 + sh_len;
    let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
    let app_pos = ccs_pos + 5 + ccs_len;
    let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;

    assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
    assert_eq!(
        app_len, 64,
        "fake cert payload must be clamped to minimum 64 bytes"
    );
}

// Requested 65_535 bytes must be clamped down to MAX_TLS_CIPHERTEXT_SIZE.
#[test]
fn server_hello_clamps_fake_cert_len_upper_bound() {
    let secret = b"fake_cert_upper_bound_test";
    let client_digest = [0x22u8; TLS_DIGEST_LEN];
    let session_id = vec![0x66; 32];
    let rng = crate::crypto::SecureRandom::new();

    let response = build_server_hello(secret, &client_digest, &session_id, 65_535, &rng, None, 0);

    let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_pos = 5 + sh_len;
    let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
    let app_pos = ccs_pos + 5 + ccs_len;
    let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;

    assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
    assert_eq!(
        app_len, MAX_TLS_CIPHERTEXT_SIZE,
        "fake cert payload must be clamped to TLS record max bound"
    );
}

#[test]
fn server_hello_new_session_ticket_count_matches_configuration() {
    let secret = b"ticket_count_surface_test";
    let client_digest = [0x33u8; TLS_DIGEST_LEN];
    let session_id = vec![0x55; 32];
    let rng = crate::crypto::SecureRandom::new();

    let tickets: u8 = 3;
    let response = build_server_hello(
        secret,
        &client_digest,
        &session_id,
        1024,
        &rng,
        None,
        tickets,
    );

    // Count every ApplicationData record in the full response stream.
    let mut pos = 0usize;
    let mut app_records = 0usize;
    while pos + 5 <= response.len() {
        let rtype = response[pos];
        let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
        let next = pos + 5 + rlen;
        assert!(
            next <= response.len(),
            "TLS record must stay inside response bounds"
        );
        if rtype == TLS_RECORD_APPLICATION {
            app_records += 1;
        }
        pos = next;
    }

    assert_eq!(
        app_records,
        1 + tickets as usize,
        "response must contain one main application record plus configured ticket-like tail records"
    );
}

// Requesting u8::MAX tickets must not produce 255 tail records.
#[test]
fn server_hello_new_session_ticket_count_is_safely_capped() {
    let secret = b"ticket_count_cap_test";
    let client_digest = [0x44u8; TLS_DIGEST_LEN];
    let session_id = vec![0x54; 32];
    let rng = crate::crypto::SecureRandom::new();

    let response = build_server_hello(
        secret,
        &client_digest,
        &session_id,
        1024,
        &rng,
        None,
        u8::MAX,
    );

    let mut pos = 0usize;
    let mut app_records = 0usize;
    while pos + 5 <= response.len() {
        let rtype = response[pos];
        let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
        let next = pos + 5 + rlen;
        assert!(
            next <= response.len(),
            "TLS record must stay inside response bounds"
        );
        if rtype == TLS_RECORD_APPLICATION {
            app_records += 1;
        }
        pos = next;
    }

    assert_eq!(
        app_records, 5,
        "response must cap ticket-like tail records to four plus one main application record"
    );
}

#[test]
fn boot_time_handshake_replay_remains_blocked_after_cache_window_expires() {
    let secret = b"gap_t01_boot_replay";
    let secrets = vec![("user".to_string(), secret.to_vec())];
    let handshake = make_valid_tls_handshake(secret, 1);

    let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
        .expect("boot-time handshake must validate on first use");

    // Replay checker with a deliberately tiny 40ms retention window.
    let checker = crate::stats::ReplayChecker::new(128, std::time::Duration::from_millis(40));
    let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN];

    assert!(
        !checker.check_and_add_tls_digest(digest_half),
        "first use must not be treated as replay"
    );
    assert!(
        checker.check_and_add_tls_digest(digest_half),
        "immediate second use must be detected as replay"
    );

    std::thread::sleep(std::time::Duration::from_millis(70));

    let validation_after_expiry =
        validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
            .expect("boot-time handshake must still cryptographically validate after cache expiry");
    let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN];
    assert_eq!(
        digest_half, digest_half_after_expiry,
        "replay key must be stable for same handshake"
    );

    assert!(
        checker.check_and_add_tls_digest(digest_half_after_expiry),
        "after cache window expiry, the same boot-time handshake must still be treated as replay"
    );
}

#[test]
fn
adversarial_boot_time_handshake_should_not_be_replayable_after_cache_expiry() {
    let secret = b"gap_t01_boot_replay_adversarial";
    let secrets = vec![("user".to_string(), secret.to_vec())];
    let handshake = make_valid_tls_handshake(secret, 1);

    let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
        .expect("boot-time handshake must validate on first use");

    let checker = crate::stats::ReplayChecker::new(128, std::time::Duration::from_millis(40));
    let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN];

    assert!(
        !checker.check_and_add_tls_digest(digest_half),
        "first use must not be treated as replay"
    );
    assert!(
        checker.check_and_add_tls_digest(digest_half),
        "immediate reuse must be rejected as replay"
    );

    // Let the 40ms cache window lapse before the adversarial re-submission.
    std::thread::sleep(std::time::Duration::from_millis(70));

    let validation_after_expiry =
        validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
            .expect("boot-time handshake still validates cryptographically after cache expiry");
    let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN];

    assert_eq!(
        digest_half, digest_half_after_expiry,
        "replay key must remain stable for the same captured handshake"
    );

    assert!(
        checker.check_and_add_tls_digest(digest_half_after_expiry),
        "security expectation: a boot-time handshake should remain replay-protected even after cache expiry"
    );
}

// Re-submits the same captured handshake across many short-window cache
// churn cycles; replay detection must never reopen.
#[test]
fn stress_short_replay_window_boot_timestamp_replay_cycles_remain_fail_closed_in_window() {
    let secret = b"gap_t01_boot_replay_stress";
    let secrets = vec![("user".to_string(), secret.to_vec())];
    let handshake = make_valid_tls_handshake(secret, 1);

    let checker = crate::stats::ReplayChecker::new(256, std::time::Duration::from_millis(25));

    for cycle in 0..64 {
        let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
            .expect("boot-time handshake must validate");
        let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN];

        if cycle == 0 {
            assert!(
                !checker.check_and_add_tls_digest(digest_half),
                "cycle 0: first use must be fresh"
            );
            assert!(
                checker.check_and_add_tls_digest(digest_half),
                "cycle 0: second use must be replay"
            );
        } else {
            assert!(
                checker.check_and_add_tls_digest(digest_half),
                "cycle {cycle}: digest must remain replay-protected across short-window churn"
            );
        }

        std::thread::sleep(std::time::Duration::from_millis(30));
    }
}

// xorshift-style PRNG drives a small timestamp matrix; only ts < cap (2) may
// use the boot-time bypass.
#[test]
fn light_fuzz_boot_time_timestamp_matrix_with_short_replay_window_obeys_boot_cap() {
    let secret = b"gap_t01_boot_replay_fuzz";
    let secrets = vec![("user".to_string(), secret.to_vec())];

    let mut s: u64 = 0xA1B2_C3D4_55AA_7733;
    for _ in 0..2048 {
        s ^= s << 7;
        s ^= s >> 9;
        s ^= s << 8;
        let ts = (s as u32) % 8;

        let handshake = make_valid_tls_handshake(secret, ts);
        let accepted =
            validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2).is_some();

        if ts < 2 {
            assert!(
                accepted,
                "timestamp {ts} must remain boot-time compatible under 2s cap"
            );
        } else {
            assert!(
                !accepted,
                "timestamp {ts} must be rejected when outside replay-window boot cap"
            );
        }
    }
}

#[test]
fn server_hello_application_data_contains_alpn_marker_when_selected() {
    let secret = b"alpn_marker_test";
    let client_digest = [0x55u8; TLS_DIGEST_LEN];
    let session_id = vec![0xAB; 32];
    let rng = crate::crypto::SecureRandom::new();

    let response = build_server_hello(
        secret,
        &client_digest,
        &session_id,
        512,
        &rng,
        Some(b"h2".to_vec()),
        0,
    );

    // Walk ServerHello -> ChangeCipherSpec -> first ApplicationData payload.
    let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_pos = 5 + sh_len;
    let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
    let app_pos = ccs_pos + 5 + ccs_len;
    let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
    let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];

    // ALPN extension marker: type 0x0010, wrapping the "h2" protocol entry.
    let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
    assert!(
        app_payload
            .windows(expected.len())
            .any(|window| window == expected),
        "first application payload must carry ALPN marker for selected protocol"
    );
}

#[test]
fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() {
    let secret = b"alpn_oversize_ignore_test";
    let client_digest = [0x56u8; TLS_DIGEST_LEN];
    let session_id = vec![0xCD; 32];
    let rng = crate::crypto::SecureRandom::new();
    let oversized_alpn = vec![b'x'; u8::MAX as usize + 1];

    let response = build_server_hello(
        secret,
        &client_digest,
        &session_id,
        512,
        &rng,
        Some(oversized_alpn),
        u8::MAX,
    );

    let mut pos = 0usize;
    let mut app_records = 0usize;
    let mut first_app_payload: Option<&[u8]> = None;
    while pos + 5 <= response.len() {
        let rtype = response[pos];
        let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
        let next = pos + 5 + rlen;
        assert!(
            next <= response.len(),
            "TLS record must stay inside response bounds"
        );
        if rtype == TLS_RECORD_APPLICATION {
            app_records += 1;
            if first_app_payload.is_none() {
                first_app_payload = Some(&response[pos + 5..next]);
            }
        }
        pos = next;
    }
    let marker = [
        0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x',
    ];

    assert_eq!(
        app_records, 5,
        "oversized ALPN must not change the four-ticket cap on tail records"
    );
    assert!(
        !first_app_payload
            .expect("response must contain an application record")
            .windows(marker.len())
            .any(|window| window == marker),
        "oversized ALPN must be ignored rather than embedded into the first application payload"
    );
}

#[test]
fn server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
    let secret = b"alpn_too_large_to_fit_test";
    let client_digest = [0x57u8; TLS_DIGEST_LEN];
    let session_id = vec![0xEF; 32];
    let rng = crate::crypto::SecureRandom::new();
    let oversized_alpn = vec![0xAB; u8::MAX as usize];

    let response = build_server_hello(
        secret,
        &client_digest,
        &session_id,
        64,
        &rng,
        Some(oversized_alpn),
        0,
    );

    let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_pos = 5 + sh_len;
    let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
    let app_pos = ccs_pos + 5 + ccs_len;
    let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
    let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];

    let mut marker_prefix = Vec::new();
    marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
    marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes());
    marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes());
    marker_prefix.push(0xff);
    marker_prefix.extend_from_slice(&[0xab; 8]);
    assert!(
        !app_payload.starts_with(&marker_prefix),
        "oversized ALPN must not be partially embedded into the ServerHello application record"
    );
}

#[test]
fn server_hello_embeds_full_alpn_marker_when_it_exactly_fits_fake_cert_len() {
    let secret = b"alpn_exact_fit_test";
    let client_digest = [0x58u8; TLS_DIGEST_LEN];
    let session_id = vec![0xA5; 32];
    let rng = crate::crypto::SecureRandom::new();
    let proto = vec![b'z'; 57];

    // marker_len = 4 + (2 + (1 + proto_len)) = 7 + proto_len = 64
    let response = build_server_hello(
        secret,
        &client_digest,
        &session_id,
        64,
        &rng,
        Some(proto.clone()),
        0,
    );

    let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_pos = 5 + sh_len;
    let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
    let app_pos = ccs_pos + 5 + ccs_len;
    let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
    let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];

    let mut expected_marker = Vec::new();
    expected_marker.extend_from_slice(&0x0010u16.to_be_bytes());
expected_marker.extend_from_slice(&0x003Cu16.to_be_bytes()); + expected_marker.extend_from_slice(&0x003Au16.to_be_bytes()); + expected_marker.push(57u8); + expected_marker.extend_from_slice(&proto); + + assert_eq!(app_payload.len(), expected_marker.len()); + assert_eq!(app_payload, expected_marker.as_slice()); +} + +#[test] +fn server_hello_does_not_embed_partial_alpn_marker_when_one_byte_short() { + let secret = b"alpn_one_byte_short_test"; + let client_digest = [0x59u8; TLS_DIGEST_LEN]; + let session_id = vec![0xA6; 32]; + let rng = crate::crypto::SecureRandom::new(); + let proto = vec![0xAB; 58]; + + // marker_len = 65, fake_cert_len = 64 => marker must be fully skipped. + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 64, + &rng, + Some(proto), + 0, + ); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + let app_payload = &response[app_pos + 5..app_pos + 5 + app_len]; + + let mut marker_prefix = Vec::new(); + marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes()); + marker_prefix.extend_from_slice(&0x003Du16.to_be_bytes()); + marker_prefix.extend_from_slice(&0x003Bu16.to_be_bytes()); + marker_prefix.push(58u8); + marker_prefix.extend_from_slice(&[0xAB; 8]); + + assert!( + !app_payload.starts_with(&marker_prefix), + "one-byte-short ALPN marker must be skipped entirely, not partially embedded" + ); +} + +#[test] +fn exhaustive_tls_minor_version_classification_matches_policy() { + for minor in 0u8..=u8::MAX { + let first = [TLS_RECORD_HANDSHAKE, 0x03, minor]; + let expected = minor == 0x01 || minor == 0x03; + assert_eq!( + is_tls_handshake(&first), + expected, + "minor version {minor:#04x} classification mismatch" + ); + } +} + +#[test] +fn 
light_fuzz_tls_header_classifier_and_parser_policy_consistency() { + // Deterministic xorshift state keeps this fuzz test reproducible. + let mut s: u64 = 0x9E37_79B9_AA95_5A5D; + + for _ in 0..10_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let header = [ + (s & 0xff) as u8, + ((s >> 8) & 0xff) as u8, + ((s >> 16) & 0xff) as u8, + ((s >> 24) & 0xff) as u8, + ((s >> 32) & 0xff) as u8, + ]; + + let classified = is_tls_handshake(&header[..3]); + let expected_classified = header[0] == TLS_RECORD_HANDSHAKE + && header[1] == 0x03 + && (header[2] == 0x01 || header[2] == 0x03); + assert_eq!( + classified, expected_classified, + "classifier policy mismatch for header {header:02x?}" + ); + + let parsed = parse_tls_record_header(&header); + let expected_parsed = + header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]); + assert_eq!( + parsed.is_some(), + expected_parsed, + "parser policy mismatch for header {header:02x?}" + ); + } +} + +#[test] +fn stress_random_noise_handshakes_never_authenticate() { + let secret = b"stress_noise_secret"; + let secrets = vec![("noise-user".to_string(), secret.to_vec())]; + + // Deterministic xorshift state keeps this stress test reproducible. 
+ let mut s: u64 = 0xD1B5_4A32_9C6E_77F1; + + for _ in 0..5_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let len = 1 + ((s as usize) % 196); + let mut buf = vec![0u8; len]; + for b in &mut buf { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + *b = (s & 0xff) as u8; + } + + assert!( + validate_tls_handshake(&buf, &secrets, true).is_none(), + "random noise must never authenticate" + ); + } +} diff --git a/src/protocol/tests/tls_size_constants_security_tests.rs b/src/protocol/tests/tls_size_constants_security_tests.rs new file mode 100644 index 0000000..20e24c7 --- /dev/null +++ b/src/protocol/tests/tls_size_constants_security_tests.rs @@ -0,0 +1,11 @@ +use super::{MAX_TLS_CIPHERTEXT_SIZE, MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; + +#[test] +fn tls_size_constants_match_rfc_8446() { + assert_eq!(MAX_TLS_PLAINTEXT_SIZE, 16_384); + assert_eq!(MAX_TLS_CIPHERTEXT_SIZE, 16_640); + + assert!(MIN_TLS_CLIENT_HELLO_SIZE < 512); + assert!(MIN_TLS_CLIENT_HELLO_SIZE > 64); + assert!(MAX_TLS_CIPHERTEXT_SIZE > MAX_TLS_PLAINTEXT_SIZE); +} diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index fbe7ad5..613106e 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -5,14 +5,69 @@ //! actually carries MTProto authentication data. 
#![allow(dead_code)] +#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))] +#![cfg_attr( + not(test), + deny( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::todo, + clippy::unimplemented, + clippy::correctness, + clippy::option_if_let_else, + clippy::or_fun_call, + clippy::branches_sharing_code, + clippy::single_option_map, + clippy::useless_let_if_seq, + clippy::redundant_locals, + clippy::cloned_ref_to_slice_refs, + unsafe_code, + clippy::await_holding_lock, + clippy::await_holding_refcell_ref, + clippy::debug_assert_with_mut_call, + clippy::macro_use_imports, + clippy::cast_ptr_alignment, + clippy::cast_lossless, + clippy::ptr_as_ptr, + clippy::large_stack_arrays, + clippy::same_functions_in_if_condition, + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + rust_2018_idioms + ) +)] +#![cfg_attr( + not(test), + allow( + clippy::use_self, + clippy::redundant_closure, + clippy::too_many_arguments, + clippy::doc_markdown, + clippy::missing_const_for_fn, + clippy::unnecessary_operation, + clippy::redundant_pub_crate, + clippy::derive_partial_eq_without_eq, + clippy::type_complexity, + clippy::new_ret_no_self, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::significant_drop_tightening, + clippy::significant_drop_in_scrutinee, + clippy::float_cmp, + clippy::nursery + ) +)] -use crate::crypto::{sha256_hmac, SecureRandom}; +use super::constants::*; +use crate::crypto::{SecureRandom, sha256_hmac}; #[cfg(test)] use crate::error::ProxyError; -use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; -use num_bigint::BigUint; -use num_traits::One; +use subtle::ConstantTimeEq; +use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; // ============= Public Constants ============= @@ -26,8 +81,17 @@ pub const TLS_DIGEST_POS: usize = 11; pub const TLS_DIGEST_HALF_LEN: usize = 16; /// Time skew limits for anti-replay (in seconds) -pub const TIME_SKEW_MIN: i64 = -20 * 
60; // 20 minutes before -pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after +/// +/// The default window is intentionally narrow to reduce replay acceptance. +/// Operators with known clock-drifted clients should tune deployment config +/// (for example replay-window policy) to match their environment. +pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before +pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after +/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced. +pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60; +/// Hard cap for boot-time compatibility bypass to avoid oversized acceptance +/// windows when replay TTL is configured very large. +pub const BOOT_TIME_COMPAT_MAX_SECS: u32 = 2 * 60; // ============= Private Constants ============= @@ -77,80 +141,63 @@ impl TlsExtensionBuilder { extensions: Vec::with_capacity(128), } } - + /// Add Key Share extension with X25519 key fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self { // Extension type: key_share (0x0033) - self.extensions.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes()); - + self.extensions + .extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes()); + // Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes // Extension data length let entry_len: u16 = 2 + 2 + 32; // curve + length + key self.extensions.extend_from_slice(&entry_len.to_be_bytes()); - + // Named curve: x25519 - self.extensions.extend_from_slice(&named_curve::X25519.to_be_bytes()); - + self.extensions + .extend_from_slice(&named_curve::X25519.to_be_bytes()); + // Key length self.extensions.extend_from_slice(&(32u16).to_be_bytes()); - + // Key data self.extensions.extend_from_slice(public_key); - + self } - + /// Add Supported Versions extension fn add_supported_versions(&mut self, version: u16) -> &mut Self { // Extension type: supported_versions (0x002b) - self.extensions.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes()); - + 
self.extensions + .extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes()); + // Extension data: length (2) + version (2) self.extensions.extend_from_slice(&(2u16).to_be_bytes()); - + // Selected version self.extensions.extend_from_slice(&version.to_be_bytes()); - + self } - /// Add ALPN extension with a single selected protocol. - fn add_alpn(&mut self, proto: &[u8]) -> &mut Self { - // Extension type: ALPN (0x0010) - self.extensions.extend_from_slice(&extension_type::ALPN.to_be_bytes()); - - // ALPN extension format: - // extension_data length (2 bytes) - // protocols length (2 bytes) - // protocol name length (1 byte) - // protocol name bytes - let proto_len = proto.len() as u8; - let list_len: u16 = 1 + proto_len as u16; - let ext_len: u16 = 2 + list_len; - - self.extensions.extend_from_slice(&ext_len.to_be_bytes()); - self.extensions.extend_from_slice(&list_len.to_be_bytes()); - self.extensions.push(proto_len); - self.extensions.extend_from_slice(proto); - self - } - /// Build final extensions with length prefix fn build(self) -> Vec<u8> { + let Ok(len) = u16::try_from(self.extensions.len()) else { + return Vec::new(); + }; let mut result = Vec::with_capacity(2 + self.extensions.len()); - + // Extensions length (2 bytes) - let len = self.extensions.len() as u16; result.extend_from_slice(&len.to_be_bytes()); - + // Extensions data result.extend_from_slice(&self.extensions); - + result } - + /// Get current extensions without length prefix (for calculation) - #[allow(dead_code)] fn as_bytes(&self) -> &[u8] { &self.extensions } @@ -170,8 +217,6 @@ struct ServerHelloBuilder { compression: u8, /// Extensions extensions: TlsExtensionBuilder, - /// Selected ALPN protocol (if any) - alpn: Option<Vec<u8>>, } impl ServerHelloBuilder { @@ -182,35 +227,30 @@ impl ServerHelloBuilder { cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256, compression: 0x00, extensions: TlsExtensionBuilder::new(), - alpn: None, } } - + fn with_x25519_key(mut self, key: &[u8; 32]) -> Self { 
self.extensions.add_key_share(key); self } - + fn with_tls13_version(mut self) -> Self { // TLS 1.3 = 0x0304 self.extensions.add_supported_versions(0x0304); self } - fn with_alpn(mut self, proto: Option<Vec<u8>>) -> Self { - self.alpn = proto; - self - } - /// Build ServerHello message (without record header) fn build_message(&self) -> Vec<u8> { - let mut ext_builder = self.extensions.clone(); - if let Some(ref alpn) = self.alpn { - ext_builder.add_alpn(alpn); - } - let extensions = ext_builder.extensions.clone(); - let extensions_len = extensions.len() as u16; - + let Ok(session_id_len) = u8::try_from(self.session_id.len()) else { + return Vec::new(); + }; + let extensions = self.extensions.extensions.clone(); + let Ok(extensions_len) = u16::try_from(extensions.len()) else { + return Vec::new(); + }; + // Calculate total length let body_len = 2 + // version 32 + // random @@ -218,160 +258,250 @@ impl ServerHelloBuilder { 2 + // cipher suite 1 + // compression 2 + extensions.len(); // extensions length + data - + if body_len > 0x00ff_ffff { + return Vec::new(); + } + let mut message = Vec::with_capacity(4 + body_len); - + // Handshake header message.push(0x02); // ServerHello message type - + // 3-byte length - let len_bytes = (body_len as u32).to_be_bytes(); + let Ok(body_len_u32) = u32::try_from(body_len) else { + return Vec::new(); + }; + let len_bytes = body_len_u32.to_be_bytes(); message.extend_from_slice(&len_bytes[1..4]); - + // Server version (TLS 1.2 in header, actual version in extension) message.extend_from_slice(&TLS_VERSION); - + // Random (32 bytes) - placeholder, will be replaced with digest message.extend_from_slice(&self.random); - + // Session ID - message.push(self.session_id.len() as u8); + message.push(session_id_len); message.extend_from_slice(&self.session_id); - + // Cipher suite message.extend_from_slice(&self.cipher_suite); - + // Compression method message.push(self.compression); - + // Extensions length 
message.extend_from_slice(&extensions_len.to_be_bytes()); - + // Extensions data message.extend_from_slice(&extensions); - + message } - + /// Build complete ServerHello TLS record fn build_record(&self) -> Vec<u8> { let message = self.build_message(); - + if message.is_empty() { + return Vec::new(); + } + let Ok(message_len) = u16::try_from(message.len()) else { + return Vec::new(); + }; + let mut record = Vec::with_capacity(5 + message.len()); - + // TLS record header record.push(TLS_RECORD_HANDSHAKE); record.extend_from_slice(&TLS_VERSION); - record.extend_from_slice(&(message.len() as u16).to_be_bytes()); - + record.extend_from_slice(&message_len.to_be_bytes()); + // Message record.extend_from_slice(&message); - + record } } // ============= Public Functions ============= -/// Validate TLS ClientHello against user secrets +/// Validate TLS ClientHello against user secrets. /// /// Returns validation result if a matching user is found. +/// The result **must** be used — ignoring it silently bypasses authentication. +#[must_use] pub fn validate_tls_handshake( handshake: &[u8], secrets: &[(String, Vec<u8>)], ignore_time_skew: bool, +) -> Option<TlsValidation> { + validate_tls_handshake_with_replay_window( + handshake, + secrets, + ignore_time_skew, + u64::from(BOOT_TIME_MAX_SECS), + ) +} + +/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL. +/// +/// A boot-time timestamp is only accepted when it falls below all three +/// bounds: `BOOT_TIME_MAX_SECS`, configured replay window, and +/// `BOOT_TIME_COMPAT_MAX_SECS`, preventing oversized compatibility windows. +#[must_use] +pub fn validate_tls_handshake_with_replay_window( + handshake: &[u8], + secrets: &[(String, Vec<u8>)], + ignore_time_skew: bool, + replay_window_secs: u64, +) -> Option<TlsValidation> { + // Only pay the clock syscall when we will actually compare against it. + // If `ignore_time_skew` is set, a broken or unavailable system clock + // must not block legitimate clients — that would be a DoS via clock failure. 
+ let now = if !ignore_time_skew { + system_time_to_unix_secs(SystemTime::now())? + } else { + 0_i64 + }; + + let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX); + // Boot-time bypass and ignore_time_skew serve different compatibility paths. + // When skew checks are disabled, force boot-time cap to zero to prevent + // accidental future coupling of boot-time logic into the ignore-skew path. + let boot_time_cap_secs = if ignore_time_skew { + 0 + } else { + BOOT_TIME_MAX_SECS + .min(replay_window_u32) + .min(BOOT_TIME_COMPAT_MAX_SECS) + }; + + validate_tls_handshake_at_time_with_boot_cap( + handshake, + secrets, + ignore_time_skew, + now, + boot_time_cap_secs, + ) +} + +fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> { + // `try_from` rejects values that overflow i64 (> ~292 billion years CE), + // whereas `as i64` would silently wrap to a negative timestamp and corrupt + // every subsequent time-skew comparison. + let d = now.duration_since(UNIX_EPOCH).ok()?; + i64::try_from(d.as_secs()).ok() +} + +fn validate_tls_handshake_at_time( + handshake: &[u8], + secrets: &[(String, Vec<u8>)], + ignore_time_skew: bool, + now: i64, +) -> Option<TlsValidation> { + validate_tls_handshake_at_time_with_boot_cap( + handshake, + secrets, + ignore_time_skew, + now, + BOOT_TIME_MAX_SECS, + ) +} + +fn validate_tls_handshake_at_time_with_boot_cap( + handshake: &[u8], + secrets: &[(String, Vec<u8>)], + ignore_time_skew: bool, + now: i64, + boot_time_cap_secs: u32, ) -> Option<TlsValidation> { if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 { return None; } - + // Extract digest let digest: [u8; TLS_DIGEST_LEN] = handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] .try_into() .ok()?; - + // Extract session ID let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; let session_id_len = handshake.get(session_id_len_pos).copied()? 
as usize; + if session_id_len > 32 { + return None; + } let session_id_start = session_id_len_pos + 1; - + if handshake.len() < session_id_start + session_id_len { return None; } - + let session_id = handshake[session_id_start..session_id_start + session_id_len].to_vec(); - + // Build message for HMAC (with zeroed digest) let mut msg = handshake.to_vec(); msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); - - // Get current time - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - + + let mut first_match: Option<(&String, u32)> = None; + for (user, secret) in secrets { let computed = sha256_hmac(secret, &msg); - - // XOR digests - let xored: Vec<u8> = digest.iter() - .zip(computed.iter()) - .map(|(a, b)| a ^ b) - .collect(); - - // Check that first 28 bytes are zeros (timestamp in last 4) - if !xored[..28].iter().all(|&b| b == 0) { + + // Constant-time equality check on the 28-byte HMAC window. + // A variable-time short-circuit here lets an active censor measure how many + // bytes matched, enabling secret brute-force via timing side-channels. + // Direct comparison on the original arrays avoids a heap allocation and + // removes the `try_into().unwrap()` that the intermediate Vec would require. + if !bool::from(digest[..28].ct_eq(&computed[..28])) { continue; } - - // Extract timestamp - let timestamp = u32::from_le_bytes(xored[28..32].try_into().unwrap()); - let time_diff = now - timestamp as i64; - - // Check time skew + + // The last 4 bytes encode the timestamp as XOR(digest[28..32], computed[28..32]). + // Inline array construction is infallible: both slices are [u8; 32] by construction. + let timestamp = u32::from_le_bytes([ + digest[28] ^ computed[28], + digest[29] ^ computed[29], + digest[30] ^ computed[30], + digest[31] ^ computed[31], + ]); + + // time_diff is only meaningful (and `now` is only valid) when we are + // actually checking the window. 
Keep both inside the guard to make + // the dead-code path explicit and prevent accidental future use of + // a sentinel `now` value outside its intended scope. if !ignore_time_skew { // Allow very small timestamps (boot time instead of unix time) // This is a quirk in some clients that use uptime instead of real time - let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds - - if !is_boot_time && !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { - continue; + let is_boot_time = boot_time_cap_secs > 0 && timestamp < boot_time_cap_secs; + if !is_boot_time { + let time_diff = now - i64::from(timestamp); + if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { + continue; + } } } - - return Some(TlsValidation { - user: user.clone(), - session_id, - digest, - timestamp, - }); - } - - None -} -fn curve25519_prime() -> BigUint { - (BigUint::one() << 255) - BigUint::from(19u32) + if first_match.is_none() { + first_match = Some((user, timestamp)); + } + } + + first_match.map(|(user, timestamp)| TlsValidation { + user: user.clone(), + session_id, + digest, + timestamp, + }) } /// Generate a fake X25519 public key for TLS /// -/// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p, -/// which matches Python/C behavior and avoids DPI fingerprinting. +/// Uses RFC 7748 X25519 scalar multiplication over the canonical basepoint, +/// yielding distribution-consistent public keys for anti-fingerprinting. 
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] { - let mut n_bytes = [0u8; 32]; - n_bytes.copy_from_slice(&rng.bytes(32)); - - let n = BigUint::from_bytes_le(&n_bytes); - let p = curve25519_prime(); - let pk = (&n * &n) % &p; - - let mut out = pk.to_bytes_le(); - out.resize(32, 0); - let mut result = [0u8; 32]; - result.copy_from_slice(&out[..32]); - result + let mut scalar = [0u8; 32]; + scalar.copy_from_slice(&rng.bytes(32)); + x25519(scalar, X25519_BASEPOINT_BYTES) } /// Build TLS ServerHello response @@ -392,27 +522,50 @@ pub fn build_server_hello( new_session_tickets: u8, ) -> Vec<u8> { const MIN_APP_DATA: usize = 64; - const MAX_APP_DATA: usize = 16640; // RFC 8446 §5.2 upper bound + const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE; let fake_cert_len = fake_cert_len.clamp(MIN_APP_DATA, MAX_APP_DATA); let x25519_key = gen_fake_x25519_key(rng); - + // Build ServerHello let server_hello = ServerHelloBuilder::new(session_id.to_vec()) .with_x25519_key(&x25519_key) .with_tls13_version() - .with_alpn(alpn) .build_record(); - + // Build Change Cipher Spec record let change_cipher_spec = [ TLS_RECORD_CHANGE_CIPHER, - TLS_VERSION[0], TLS_VERSION[1], - 0x00, 0x01, // length = 1 - 0x01, // CCS byte + TLS_VERSION[0], + TLS_VERSION[1], + 0x00, + 0x01, // length = 1 + 0x01, // CCS byte ]; - - // Build fake certificate (Application Data record) - let fake_cert = rng.bytes(fake_cert_len); + + // Build first encrypted flight mimic as opaque ApplicationData bytes. + // Embed a compact EncryptedExtensions-like ALPN block when selected. 
+ let mut fake_cert = Vec::with_capacity(fake_cert_len); + if let Some(proto) = alpn + .as_ref() + .filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) + { + let proto_list_len = 1usize + proto.len(); + let ext_data_len = 2usize + proto_list_len; + let marker_len = 4usize + ext_data_len; + if marker_len <= fake_cert_len { + fake_cert.extend_from_slice(&0x0010u16.to_be_bytes()); + fake_cert.extend_from_slice(&(ext_data_len as u16).to_be_bytes()); + fake_cert.extend_from_slice(&(proto_list_len as u16).to_be_bytes()); + fake_cert.push(proto.len() as u8); + fake_cert.extend_from_slice(proto); + } + } + if fake_cert.len() < fake_cert_len { + fake_cert.extend_from_slice(&rng.bytes(fake_cert_len - fake_cert.len())); + } else if fake_cert.len() > fake_cert_len { + fake_cert.truncate(fake_cert_len); + } + let mut app_data_record = Vec::with_capacity(5 + fake_cert_len); app_data_record.push(TLS_RECORD_APPLICATION); app_data_record.extend_from_slice(&TLS_VERSION); @@ -420,12 +573,13 @@ pub fn build_server_hello( // Fill ApplicationData with fully random bytes of desired length to avoid // deterministic DPI fingerprints (fixed inner content type markers). app_data_record.extend_from_slice(&fake_cert); - + // Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted; // here we mimic with opaque ApplicationData records of plausible size). 
let mut tickets = Vec::new(); - if new_session_tickets > 0 { - for _ in 0..new_session_tickets { + let ticket_count = new_session_tickets.min(4); + if ticket_count > 0 { + for _ in 0..ticket_count { let ticket_len: usize = rng.range(48) + 48; // 48-95 bytes let mut record = Vec::with_capacity(5 + ticket_len); record.push(TLS_RECORD_APPLICATION); @@ -438,7 +592,10 @@ pub fn build_server_hello( // Combine all records let mut response = Vec::with_capacity( - server_hello.len() + change_cipher_spec.len() + app_data_record.len() + tickets.iter().map(|r| r.len()).sum::() + server_hello.len() + + change_cipher_spec.len() + + app_data_record.len() + + tickets.iter().map(|r| r.len()).sum::(), ); response.extend_from_slice(&server_hello); response.extend_from_slice(&change_cipher_spec); @@ -446,18 +603,17 @@ pub fn build_server_hello( for t in &tickets { response.extend_from_slice(t); } - + // Compute HMAC for the response let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len()); hmac_input.extend_from_slice(client_digest); hmac_input.extend_from_slice(&response); let response_digest = sha256_hmac(secret, &hmac_input); - + // Insert computed digest into ServerHello // Position: record header (5) + message type (1) + length (3) + version (2) = 11 - response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] - .copy_from_slice(&response_digest); - + response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&response_digest); + response } @@ -467,6 +623,11 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { return None; } + let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize; + if handshake.len() < 5 + record_len { + return None; + } + let mut pos = 5; // after record header if handshake.get(pos).copied()? 
!= 0x01 { return None; // not ClientHello @@ -506,6 +667,9 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { return None; } + let mut saw_sni_extension = false; + let mut extracted_sni = None; + while pos + 4 <= ext_end { let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]); let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize; @@ -513,6 +677,12 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { if pos + elen > ext_end { break; } + if etype == 0x0000 { + if saw_sni_extension { + return None; + } + saw_sni_extension = true; + } if etype == 0x0000 && elen >= 5 { // server_name extension let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; @@ -520,15 +690,19 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { let sn_end = std::cmp::min(sn_pos + list_len, pos + elen); while sn_pos + 3 <= sn_end { let name_type = handshake[sn_pos]; - let name_len = u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize; + let name_len = + u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize; sn_pos += 3; if sn_pos + name_len > sn_end { break; } - if name_type == 0 && name_len > 0 + if name_type == 0 + && name_len > 0 && let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) + && is_valid_sni_hostname(host) { - return Some(host.to_string()); + extracted_sni = Some(host.to_string()); + break; } sn_pos += name_len; } @@ -536,46 +710,98 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { pos += elen; } - None + extracted_sni +} + +fn is_valid_sni_hostname(host: &str) -> bool { + if host.is_empty() || host.len() > 253 { + return false; + } + if host.starts_with('.') || host.ends_with('.') { + return false; + } + if host.parse::().is_ok() { + return false; + } + + for label in host.split('.') { + if label.is_empty() || label.len() > 63 { + return false; + } + if label.starts_with('-') 
|| label.ends_with('-') { + return false; + } + if !label + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-') + { + return false; + } + } + + true } /// Extract ALPN protocol list from ClientHello, return in offered order. pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec> { + if handshake.len() < 5 || handshake[0] != TLS_RECORD_HANDSHAKE { + return Vec::new(); + } + + let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize; + if handshake.len() < 5 + record_len { + return Vec::new(); + } + let mut pos = 5; // after record header if handshake.get(pos) != Some(&0x01) { return Vec::new(); } pos += 4; // type + len pos += 2 + 32; // version + random - if pos >= handshake.len() { return Vec::new(); } + if pos >= handshake.len() { + return Vec::new(); + } let session_id_len = *handshake.get(pos).unwrap_or(&0) as usize; pos += 1 + session_id_len; - if pos + 2 > handshake.len() { return Vec::new(); } - let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize; + if pos + 2 > handshake.len() { + return Vec::new(); + } + let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; pos += 2 + cipher_len; - if pos >= handshake.len() { return Vec::new(); } + if pos >= handshake.len() { + return Vec::new(); + } let comp_len = *handshake.get(pos).unwrap_or(&0) as usize; pos += 1 + comp_len; - if pos + 2 > handshake.len() { return Vec::new(); } - let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize; + if pos + 2 > handshake.len() { + return Vec::new(); + } + let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; pos += 2; let ext_end = pos + ext_len; - if ext_end > handshake.len() { return Vec::new(); } + if ext_end > handshake.len() { + return Vec::new(); + } let mut out = Vec::new(); while pos + 4 <= ext_end { - let etype = u16::from_be_bytes([handshake[pos], handshake[pos+1]]); - let elen = u16::from_be_bytes([handshake[pos+2], 
handshake[pos+3]]) as usize; + let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]); + let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize; pos += 4; - if pos + elen > ext_end { break; } + if pos + elen > ext_end { + break; + } if etype == extension_type::ALPN && elen >= 3 { - let list_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize; + let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; let mut lp = pos + 2; let list_end = (pos + 2).saturating_add(list_len).min(pos + elen); while lp < list_end { let plen = handshake[lp] as usize; lp += 1; - if lp + plen > list_end { break; } - out.push(handshake[lp..lp+plen].to_vec()); + if lp + plen > list_end { + break; + } + out.push(handshake[lp..lp + plen].to_vec()); lp += plen; } break; @@ -585,29 +811,28 @@ pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec> { out } - /// Check if bytes look like a TLS ClientHello pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { if first_bytes.len() < 3 { return false; } - - // TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0) - first_bytes[0] == TLS_RECORD_HANDSHAKE - && first_bytes[1] == 0x03 - && first_bytes[2] == 0x01 + + // TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303. 
+ first_bytes[0] == TLS_RECORD_HANDSHAKE + && first_bytes[1] == 0x03 + && (first_bytes[2] == 0x01 || first_bytes[2] == 0x03) } /// Parse TLS record header, returns (record_type, length) pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> { let record_type = header[0]; let version = [header[1], header[2]]; - + // We accept both TLS 1.0 header (for ClientHello) and TLS 1.2/1.3 if version != [0x03, 0x01] && version != TLS_VERSION { return None; } - + let length = u16::from_be_bytes([header[3], header[4]]); Some((record_type, length)) } @@ -623,7 +848,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> { version: [0, 0], }); } - + // Check record header if data[0] != TLS_RECORD_HANDSHAKE { return Err(ProxyError::InvalidTlsRecord { @@ -631,7 +856,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> { version: [data[1], data[2]], }); } - + // Check version if data[1..3] != TLS_VERSION { return Err(ProxyError::InvalidTlsRecord { @@ -639,319 +864,72 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> { version: [data[1], data[2]], }); } - + // Check record length let record_len = u16::from_be_bytes([data[3], data[4]]) as usize; if data.len() < 5 + record_len { - return Err(ProxyError::InvalidHandshake( - format!("ServerHello record truncated: expected {}, got {}", - 5 + record_len, data.len()) - )); + return Err(ProxyError::InvalidHandshake(format!( + "ServerHello record truncated: expected {}, got {}", + 5 + record_len, + data.len() + ))); } - + // Check message type if data[5] != 0x02 { - return Err(ProxyError::InvalidHandshake( - format!("Expected ServerHello (0x02), got 0x{:02x}", data[5]) - )); + return Err(ProxyError::InvalidHandshake(format!( + "Expected ServerHello (0x02), got 0x{:02x}", + data[5] + ))); } - + // Parse message length let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize; if msg_len + 4 != record_len { - return 
Err(ProxyError::InvalidHandshake( - format!("Message length mismatch: {} + 4 != {}", msg_len, record_len) - )); + return Err(ProxyError::InvalidHandshake(format!( + "Message length mismatch: {} + 4 != {}", + msg_len, record_len + ))); } - + Ok(()) } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_is_tls_handshake() { - assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); - assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); - assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); // Application data - assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); // Wrong version - assert!(!is_tls_handshake(&[0x16, 0x03])); // Too short - } - - #[test] - fn test_parse_tls_record_header() { - let header = [0x16, 0x03, 0x01, 0x02, 0x00]; - let result = parse_tls_record_header(&header).unwrap(); - assert_eq!(result.0, TLS_RECORD_HANDSHAKE); - assert_eq!(result.1, 512); - - let header = [0x17, 0x03, 0x03, 0x40, 0x00]; - let result = parse_tls_record_header(&header).unwrap(); - assert_eq!(result.0, TLS_RECORD_APPLICATION); - assert_eq!(result.1, 16384); - } - - #[test] - fn test_gen_fake_x25519_key() { - let rng = SecureRandom::new(); - let key1 = gen_fake_x25519_key(&rng); - let key2 = gen_fake_x25519_key(&rng); - - assert_eq!(key1.len(), 32); - assert_eq!(key2.len(), 32); - assert_ne!(key1, key2); // Should be random - } +// ============= Compile-time Security Invariants ============= - #[test] - fn test_fake_x25519_key_is_quadratic_residue() { - let rng = SecureRandom::new(); - let key = gen_fake_x25519_key(&rng); - let p = curve25519_prime(); - let k_num = BigUint::from_bytes_le(&key); - let exponent = (&p - BigUint::one()) >> 1; - let legendre = k_num.modpow(&exponent, &p); - assert_eq!(legendre, BigUint::one()); - } - - #[test] - fn test_tls_extension_builder() { - let key = [0x42u8; 32]; - - let mut builder = TlsExtensionBuilder::new(); - builder.add_key_share(&key); - builder.add_supported_versions(0x0304); - - let result = builder.build(); - - // Check length 
prefix - let len = u16::from_be_bytes([result[0], result[1]]) as usize; - assert_eq!(len, result.len() - 2); - - // Check key_share extension is present - assert!(result.len() > 40); // At least key share - } - - #[test] - fn test_server_hello_builder() { - let session_id = vec![0x01, 0x02, 0x03, 0x04]; - let key = [0x55u8; 32]; - - let builder = ServerHelloBuilder::new(session_id.clone()) - .with_x25519_key(&key) - .with_tls13_version(); - - let record = builder.build_record(); - - // Validate structure - validate_server_hello_structure(&record).expect("Invalid ServerHello structure"); - - // Check record type - assert_eq!(record[0], TLS_RECORD_HANDSHAKE); - - // Check version - assert_eq!(&record[1..3], &TLS_VERSION); - - // Check message type (ServerHello = 0x02) - assert_eq!(record[5], 0x02); - } - - #[test] - fn test_build_server_hello_structure() { - let secret = b"test secret"; - let client_digest = [0x42u8; 32]; - let session_id = vec![0xAA; 32]; - - let rng = SecureRandom::new(); - let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng, None, 0); - - // Should have at least 3 records - assert!(response.len() > 100); - - // First record should be ServerHello - assert_eq!(response[0], TLS_RECORD_HANDSHAKE); - - // Validate ServerHello structure - validate_server_hello_structure(&response).expect("Invalid ServerHello"); - - // Find Change Cipher Spec - let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize; - let ccs_start = server_hello_len; - - assert!(response.len() > ccs_start + 6); - assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); - - // Find Application Data - let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; - let app_start = ccs_start + ccs_len; - - assert!(response.len() > app_start + 5); - assert_eq!(response[app_start], TLS_RECORD_APPLICATION); - } - - #[test] - fn test_build_server_hello_digest() { - let secret = b"test secret key 
here"; - let client_digest = [0x42u8; 32]; - let session_id = vec![0xAA; 32]; - - let rng = SecureRandom::new(); - let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); - let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); - - // Digest position should have non-zero data - let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; - assert!(!digest1.iter().all(|&b| b == 0)); - - // Different calls should have different digests (due to random cert) - let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; - assert_ne!(digest1, digest2); - } - - #[test] - fn test_server_hello_extensions_length() { - let session_id = vec![0x01; 32]; - let key = [0x55u8; 32]; - - let builder = ServerHelloBuilder::new(session_id) - .with_x25519_key(&key) - .with_tls13_version(); - - let record = builder.build_record(); - - // Parse to find extensions - let msg_start = 5; // After record header - let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; - - // Skip to session ID - let session_id_pos = msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32) - let session_id_len = record[session_id_pos] as usize; - - // Skip to extensions - let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1) - let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize; - - // Verify extensions length matches actual data - let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len]; - assert_eq!(ext_len, extensions_data.len(), - "Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len()); - } - - #[test] - fn test_validate_tls_handshake_format() { - // Build a minimal ClientHello-like structure - let mut handshake = vec![0u8; 100]; - - // Put a valid-looking digest at position 11 - handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] - 
.copy_from_slice(&[0x42; 32]); - - // Session ID length - handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; - - // This won't validate (wrong HMAC) but shouldn't panic - let secrets = vec![("test".to_string(), b"secret".to_vec())]; - let result = validate_tls_handshake(&handshake, &secrets, true); - - // Should return None (no match) but not panic - assert!(result.is_none()); - } +/// Compile-time checks that enforce invariants the rest of the code relies on. +/// Using `static_assertions` ensures these can never silently break across +/// refactors without a compile error. +mod compile_time_security_checks { + use super::{TLS_DIGEST_HALF_LEN, TLS_DIGEST_LEN}; + use static_assertions::const_assert; - fn build_client_hello_with_exts(exts: Vec<(u16, Vec)>, host: &str) -> Vec { - let mut body = Vec::new(); - body.extend_from_slice(&TLS_VERSION); // legacy version - body.extend_from_slice(&[0u8; 32]); // random - body.push(0); // session id len - body.extend_from_slice(&2u16.to_be_bytes()); // cipher suites len - body.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256 - body.push(1); // compression len - body.push(0); // null compression + // The digest must be exactly one SHA-256 output. + const_assert!(TLS_DIGEST_LEN == 32); - // Build SNI extension - let host_bytes = host.as_bytes(); - let mut sni_ext = Vec::new(); - sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes()); - sni_ext.push(0); - sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); - sni_ext.extend_from_slice(host_bytes); + // Replay-dedup stores the first half; verify it is literally half. 
+ const_assert!(TLS_DIGEST_HALF_LEN * 2 == TLS_DIGEST_LEN); - let mut ext_blob = Vec::new(); - for (typ, data) in exts { - ext_blob.extend_from_slice(&typ.to_be_bytes()); - ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes()); - ext_blob.extend_from_slice(&data); - } - // SNI last - ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); - ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); - ext_blob.extend_from_slice(&sni_ext); - - body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); - body.extend_from_slice(&ext_blob); - - let mut handshake = Vec::new(); - handshake.push(0x01); // ClientHello - let len_bytes = (body.len() as u32).to_be_bytes(); - handshake.extend_from_slice(&len_bytes[1..4]); - handshake.extend_from_slice(&body); - - let mut record = Vec::new(); - record.push(TLS_RECORD_HANDSHAKE); - record.extend_from_slice(&[0x03, 0x01]); - record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); - record.extend_from_slice(&handshake); - record - } - - #[test] - fn test_extract_sni_with_grease_extension() { - // GREASE type 0x0a0a with zero length before SNI - let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com"); - let sni = extract_sni_from_client_hello(&ch); - assert_eq!(sni.as_deref(), Some("example.com")); - } - - #[test] - fn test_extract_sni_tolerates_empty_unknown_extension() { - let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local"); - let sni = extract_sni_from_client_hello(&ch); - assert_eq!(sni.as_deref(), Some("test.local")); - } - - #[test] - fn test_extract_alpn_single() { - let mut alpn_data = Vec::new(); - // list length = 3 (1 length byte + "h2") - alpn_data.extend_from_slice(&3u16.to_be_bytes()); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h2"); - let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); - let alpn = extract_alpn_from_client_hello(&ch); - let alpn_str: Vec = alpn - .iter() - .map(|p| 
std::str::from_utf8(p).unwrap().to_string()) - .collect(); - assert_eq!(alpn_str, vec!["h2"]); - } - - #[test] - fn test_extract_alpn_multiple() { - let mut alpn_data = Vec::new(); - // list length = 11 (sum of per-proto lengths including length bytes) - alpn_data.extend_from_slice(&11u16.to_be_bytes()); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h2"); - alpn_data.push(4); - alpn_data.extend_from_slice(b"spdy"); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h3"); - let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); - let alpn = extract_alpn_from_client_hello(&ch); - let alpn_str: Vec = alpn - .iter() - .map(|p| std::str::from_utf8(p).unwrap().to_string()) - .collect(); - assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]); - } + // The HMAC check window (28 bytes) plus the embedded timestamp (4 bytes) + // must exactly fill the digest. If TLS_DIGEST_LEN ever changes, these + // assertions will catch the mismatch before any timing-oracle fix is broke. + const_assert!(28 + 4 == TLS_DIGEST_LEN); } + +// ============= Security-focused regression tests ============= + +#[cfg(test)] +#[path = "tests/tls_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/tls_adversarial_tests.rs"] +mod adversarial_tests; + +#[cfg(test)] +#[path = "tests/tls_fuzz_security_tests.rs"] +mod fuzz_security_tests; + +#[cfg(test)] +#[path = "tests/tls_length_cast_hardening_security_tests.rs"] +mod length_cast_hardening_security_tests; diff --git a/src/proxy/adaptive_buffers.rs b/src/proxy/adaptive_buffers.rs index 3b1bce9..0c210dd 100644 --- a/src/proxy/adaptive_buffers.rs +++ b/src/proxy/adaptive_buffers.rs @@ -1,3 +1,8 @@ +#![allow(dead_code)] + +// Adaptive buffer policy is staged and retained for deterministic rollout. +// Keep definitions compiled for compatibility and security test scaffolding. 
+ use dashmap::DashMap; use std::cmp::max; use std::sync::OnceLock; @@ -170,7 +175,8 @@ impl SessionAdaptiveController { return self.promote(TierTransitionReason::SoftConfirmed, 0); } - let demote_candidate = self.throughput_ema_bps < THROUGHPUT_DOWN_BPS && !tier2_now && !hard_now; + let demote_candidate = + self.throughput_ema_bps < THROUGHPUT_DOWN_BPS && !tier2_now && !hard_now; if demote_candidate { self.quiet_ticks = self.quiet_ticks.saturating_add(1); if self.quiet_ticks >= QUIET_DEMOTE_TICKS { @@ -253,10 +259,7 @@ pub fn record_user_tier(user: &str, tier: AdaptiveTier) { }; return; } - profiles().insert( - user.to_string(), - UserAdaptiveProfile { tier, seen_at: now }, - ); + profiles().insert(user.to_string(), UserAdaptiveProfile { tier, seen_at: now }); } pub fn direct_copy_buffers_for_tier( @@ -339,10 +342,7 @@ mod tests { sample( 300_000, // ~9.6 Mbps 320_000, // incoming > outgoing to confirm tier2 - 250_000, - 10, - 0, - 0, + 250_000, 10, 0, 0, ), tick_secs, ); @@ -358,10 +358,7 @@ mod tests { fn test_hard_promotion_on_pending_pressure() { let mut ctrl = SessionAdaptiveController::new(AdaptiveTier::Base); let transition = ctrl - .observe( - sample(10_000, 20_000, 10_000, 4, 1, 3), - 0.25, - ) + .observe(sample(10_000, 20_000, 10_000, 4, 1, 3), 0.25) .expect("expected hard promotion"); assert_eq!(transition.reason, TierTransitionReason::HardPressure); assert_eq!(transition.to, AdaptiveTier::Tier1); diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 25e6cf9..4b7f57e 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1,9 +1,13 @@ //! 
Client Handler +use ipnetwork::IpNetwork; +use rand::RngExt; use std::future::Future; use std::net::{IpAddr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; +use std::sync::OnceLock; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::net::TcpStream; @@ -17,33 +21,197 @@ type PostHandshakeFuture = Pin> + Send>>; enum HandshakeOutcome { /// Handshake succeeded, relay work to do (outside timeout) NeedsRelay(PostHandshakeFuture), - /// Already fully handled (bad client masking, etc.) - Handled, + /// Handshake failed and masking must run outside handshake timeout budget + NeedsMasking(PostHandshakeFuture), +} + +#[must_use = "UserConnectionReservation must be kept alive to retain user/IP reservation until release or drop"] +struct UserConnectionReservation { + stats: Arc, + ip_tracker: Arc, + user: String, + ip: IpAddr, + active: bool, +} + +impl UserConnectionReservation { + fn new(stats: Arc, ip_tracker: Arc, user: String, ip: IpAddr) -> Self { + Self { + stats, + ip_tracker, + user, + ip, + active: true, + } + } + + async fn release(mut self) { + if !self.active { + return; + } + self.ip_tracker.remove_ip(&self.user, self.ip).await; + self.active = false; + self.stats.decrement_user_curr_connects(&self.user); + } +} + +impl Drop for UserConnectionReservation { + fn drop(&mut self) { + if !self.active { + return; + } + self.active = false; + self.stats.decrement_user_curr_connects(&self.user); + self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip); + } } use crate::config::ProxyConfig; use crate::crypto::SecureRandom; -use crate::error::{HandshakeResult, ProxyError, Result}; +use crate::error::{HandshakeResult, ProxyError, Result, StreamError}; use crate::ip_tracker::UserIpTracker; use crate::protocol::constants::*; use crate::protocol::tls; use crate::stats::beobachten::BeobachtenStore; use crate::stats::{ReplayChecker, Stats}; use crate::stream::{BufferPool, 
CryptoReader, CryptoWriter}; -use crate::transport::middle_proxy::MePool; -use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol}; -use crate::transport::socket::normalize_ip; use crate::tls_front::TlsFrontCache; +use crate::transport::middle_proxy::MePool; +use crate::transport::socket::normalize_ip; +use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol}; use crate::proxy::direct_relay::handle_via_direct; use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; use crate::proxy::masking::handle_bad_client; use crate::proxy::middle_relay::handle_via_middle_proxy; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; -use crate::proxy::session_eviction::register_session; fn beobachten_ttl(config: &ProxyConfig) -> Duration { - Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)) + const BEOBACHTEN_TTL_MAX_MINUTES: u64 = 24 * 60; + let minutes = config.general.beobachten_minutes; + if minutes == 0 { + static BEOBACHTEN_ZERO_MINUTES_WARNED: OnceLock = OnceLock::new(); + let warned = BEOBACHTEN_ZERO_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + "general.beobachten_minutes=0 is insecure because entries expire immediately; forcing minimum TTL to 1 minute" + ); + } + return Duration::from_secs(60); + } + + if minutes > BEOBACHTEN_TTL_MAX_MINUTES { + static BEOBACHTEN_OVERSIZED_MINUTES_WARNED: OnceLock = OnceLock::new(); + let warned = BEOBACHTEN_OVERSIZED_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + configured_minutes = minutes, + max_minutes = BEOBACHTEN_TTL_MAX_MINUTES, + "general.beobachten_minutes is too large; clamping to secure maximum" + ); + } + } + + Duration::from_secs(minutes.min(BEOBACHTEN_TTL_MAX_MINUTES).saturating_mul(60)) +} + +fn wrap_tls_application_record(payload: &[u8]) -> Vec { + 
let chunks = payload.len().div_ceil(u16::MAX as usize).max(1); + let mut record = Vec::with_capacity(payload.len() + 5 * chunks); + + if payload.is_empty() { + record.push(TLS_RECORD_APPLICATION); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&0u16.to_be_bytes()); + return record; + } + + for chunk in payload.chunks(u16::MAX as usize) { + record.push(TLS_RECORD_APPLICATION); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(chunk.len() as u16).to_be_bytes()); + record.extend_from_slice(chunk); + } + + record +} + +fn tls_clienthello_len_in_bounds(tls_len: usize) -> bool { + (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) +} + +async fn read_with_progress( + reader: &mut R, + mut buf: &mut [u8], +) -> std::io::Result { + let mut total = 0usize; + while !buf.is_empty() { + match reader.read(buf).await { + Ok(0) => return Ok(total), + Ok(n) => { + total += n; + let (_, rest) = buf.split_at_mut(n); + buf = rest; + } + Err(e) => return Err(e), + } + } + Ok(total) +} + +async fn maybe_apply_mask_reject_delay(config: &ProxyConfig) { + let min = config.censorship.server_hello_delay_min_ms; + let max = config.censorship.server_hello_delay_max_ms; + if max == 0 { + return; + } + + let delay_ms = if min >= max { + max + } else { + rand::rng().random_range(min..=max) + }; + + if delay_ms > 0 { + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + } +} + +fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration { + let base = Duration::from_secs(config.timeouts.client_handshake); + if config.censorship.mask { + base.saturating_add(Duration::from_millis(750)) + } else { + base + } +} + +fn masking_outcome( + reader: R, + writer: W, + initial_data: Vec, + peer: SocketAddr, + local_addr: SocketAddr, + config: Arc, + beobachten: Arc, +) -> HandshakeOutcome +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + 
HandshakeOutcome::NeedsMasking(Box::pin(async move { + handle_bad_client( + reader, + writer, + &initial_data, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + Ok(()) + })) } fn record_beobachten_class( @@ -64,14 +232,34 @@ fn record_handshake_failure_class( peer_ip: IpAddr, error: &ProxyError, ) { - let class = if error.to_string().contains("expected 64 bytes, got 0") { - "expected_64_got_0" - } else { - "other" + let class = match error { + ProxyError::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + "expected_64_got_0" + } + ProxyError::Stream(StreamError::UnexpectedEof) => "expected_64_got_0", + _ => "other", }; record_beobachten_class(beobachten, config, peer_ip, class); } +fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { + if trusted.is_empty() { + static EMPTY_PROXY_TRUST_WARNED: OnceLock = OnceLock::new(); + let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default" + ); + } + return false; + } + trusted.iter().any(|cidr| cidr.contains(peer_ip)) +} + +fn synthetic_local_addr(port: u16) -> SocketAddr { + SocketAddr::from(([0, 0, 0, 0], port)) +} + pub async fn handle_client_stream( mut stream: S, peer: SocketAddr, @@ -95,16 +283,29 @@ where let mut real_peer = normalize_ip(peer); // For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst - let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) - .parse() - .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); + let mut local_addr = synthetic_local_addr(config.server.port); if proxy_protocol_enabled { - let proxy_header_timeout = Duration::from_millis( - config.server.proxy_protocol_header_timeout_ms.max(1), - ); - match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await { + let 
proxy_header_timeout = + Duration::from_millis(config.server.proxy_protocol_header_timeout_ms.max(1)); + match timeout( + proxy_header_timeout, + parse_proxy_protocol(&mut stream, peer), + ) + .await + { Ok(Ok(info)) => { + if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs) + { + stats.increment_connects_bad(); + warn!( + peer = %peer, + trusted = ?config.server.proxy_protocol_trusted_cidrs, + "Rejecting PROXY protocol header from untrusted source" + ); + record_beobachten_class(&beobachten, &config, peer.ip(), "other"); + return Err(ProxyError::InvalidProxyProtocol); + } debug!( peer = %peer, client = %info.src_addr, @@ -133,7 +334,7 @@ where debug!(peer = %real_peer, "New connection (generic stream)"); - let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake); + let handshake_timeout = handshake_timeout_with_mask_grace(&config); let stats_for_timeout = stats.clone(); let config_for_timeout = config.clone(); let beobachten_for_timeout = beobachten.clone(); @@ -150,26 +351,68 @@ where if is_tls { let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; - if tls_len < 512 { - debug!(peer = %real_peer, tls_len = tls_len, "TLS handshake too short"); + // RFC 8446 §5.1: TLS record payload MUST NOT exceed 2^14 (16_384) bytes. + // Lower bound is a structural minimum for a valid TLS 1.3 ClientHello + // (record header + handshake header + random + session_id + cipher_suites + // + compression + at least one extension with SNI). The previous value of + // 512 was implicitly coupled to TLS_REQUEST_LENGTH=517 from the official + // Telegram MTProxy reference server, leaving only a 5-byte margin and + // incorrectly rejecting compact but spec-compliant ClientHellos from + // third-party clients or future Telegram versions. 
+ if !tls_clienthello_len_in_bounds(tls_len) { + debug!(peer = %real_peer, tls_len = tls_len, max_tls_len = MAX_TLS_PLAINTEXT_SIZE, "TLS handshake length out of bounds"); stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&config).await; let (reader, writer) = tokio::io::split(stream); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &first_bytes, + first_bytes.to_vec(), real_peer, local_addr, - &config, - &beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + beobachten.clone(), + )); } let mut handshake = vec![0u8; 5 + tls_len]; handshake[..5].copy_from_slice(&first_bytes); - stream.read_exact(&mut handshake[5..]).await?; + let body_read = match read_with_progress(&mut stream, &mut handshake[5..]).await { + Ok(n) => n, + Err(e) => { + debug!(peer = %real_peer, error = %e, tls_len = tls_len, "TLS ClientHello body read failed; engaging masking fallback"); + stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&config).await; + let initial_len = 5; + let (reader, writer) = tokio::io::split(stream); + return Ok(masking_outcome( + reader, + writer, + handshake[..initial_len].to_vec(), + real_peer, + local_addr, + config.clone(), + beobachten.clone(), + )); + } + }; + + if body_read < tls_len { + debug!(peer = %real_peer, got = body_read, expected = tls_len, "Truncated in-range TLS ClientHello; engaging masking fallback"); + stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&config).await; + let initial_len = 5 + body_read; + let (reader, writer) = tokio::io::split(stream); + return Ok(masking_outcome( + reader, + writer, + handshake[..initial_len].to_vec(), + real_peer, + local_addr, + config.clone(), + beobachten.clone(), + )); + } let (read_half, write_half) = tokio::io::split(stream); @@ -180,17 +423,15 @@ where HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); - handle_bad_client( + return 
Ok(masking_outcome( reader, writer, - &handshake, + handshake.clone(), real_peer, local_addr, - &config, - &beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -205,10 +446,32 @@ where &config, &replay_checker, true, Some(tls_user.as_str()), ).await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader: _, writer: _ } => { + HandshakeResult::BadClient { reader, writer } => { + // MTProto failed after TLS ServerHello was already sent. + // Switch fallback relay back to raw transport so the mask + // backend receives valid TLS records (not unwrapped payload). + let (reader, pending_plaintext) = reader.into_inner_with_pending_plaintext(); + let writer = writer.into_inner(); + let pending_record = if pending_plaintext.is_empty() { + Vec::new() + } else { + wrap_tls_application_record(&pending_plaintext) + }; + let reader = tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader); stats.increment_connects_bad(); - debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(HandshakeOutcome::Handled); + debug!( + peer = %peer, + "Authenticated TLS session failed MTProto validation; engaging masking fallback" + ); + return Ok(masking_outcome( + reader, + writer, + Vec::new(), + real_peer, + local_addr, + config.clone(), + beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -225,18 +488,17 @@ where if !config.general.modes.classic && !config.general.modes.secure { debug!(peer = %real_peer, "Non-TLS modes disabled"); stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&config).await; let (reader, writer) = tokio::io::split(stream); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &first_bytes, + first_bytes.to_vec(), real_peer, local_addr, - &config, - &beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + 
beobachten.clone(), + )); } let mut handshake = [0u8; HANDSHAKE_LEN]; @@ -252,17 +514,15 @@ where HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &handshake, + handshake.to_vec(), real_peer, local_addr, - &config, - &beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -312,8 +572,7 @@ where // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) match outcome { - HandshakeOutcome::NeedsRelay(fut) => fut.await, - HandshakeOutcome::Handled => Ok(()), + HandshakeOutcome::NeedsRelay(fut) | HandshakeOutcome::NeedsMasking(fut) => fut.await, } } @@ -382,7 +641,6 @@ impl RunningClientHandler { pub async fn run(self) -> Result<()> { self.stats.increment_connects_all(); let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, "New connection"); if let Err(e) = configure_client_socket( @@ -393,7 +651,7 @@ impl RunningClientHandler { debug!(peer = %peer, error = %e, "Failed to configure client socket"); } - let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); + let handshake_timeout = handshake_timeout_with_mask_grace(&self.config); let stats = self.stats.clone(); let config_for_timeout = self.config.clone(); let beobachten_for_timeout = self.beobachten.clone(); @@ -427,8 +685,7 @@ impl RunningClientHandler { // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) match outcome { - HandshakeOutcome::NeedsRelay(fut) => fut.await, - HandshakeOutcome::Handled => Ok(()), + HandshakeOutcome::NeedsRelay(fut) | HandshakeOutcome::NeedsMasking(fut) => fut.await, } } @@ -436,9 +693,8 @@ impl RunningClientHandler { let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; if self.proxy_protocol_enabled { 
- let proxy_header_timeout = Duration::from_millis( - self.config.server.proxy_protocol_header_timeout_ms.max(1), - ); + let proxy_header_timeout = + Duration::from_millis(self.config.server.proxy_protocol_header_timeout_ms.max(1)); match timeout( proxy_header_timeout, parse_proxy_protocol(&mut self.stream, self.peer), @@ -446,6 +702,24 @@ impl RunningClientHandler { .await { Ok(Ok(info)) => { + if !is_trusted_proxy_source( + self.peer.ip(), + &self.config.server.proxy_protocol_trusted_cidrs, + ) { + self.stats.increment_connects_bad(); + warn!( + peer = %self.peer, + trusted = ?self.config.server.proxy_protocol_trusted_cidrs, + "Rejecting PROXY protocol header from untrusted source" + ); + record_beobachten_class( + &self.beobachten, + &self.config, + self.peer.ip(), + "other", + ); + return Err(ProxyError::InvalidProxyProtocol); + } debug!( peer = %self.peer, client = %info.src_addr, @@ -495,7 +769,6 @@ impl RunningClientHandler { let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); @@ -506,34 +779,78 @@ impl RunningClientHandler { } } - async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { + async fn handle_tls_client( + mut self, + first_bytes: [u8; 5], + local_addr: SocketAddr, + ) -> Result { let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); - if tls_len < 512 { - debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + // RFC 8446 §5.1: TLS record payload MUST NOT exceed 2^14 (16_384) bytes. + // Lower bound is a structural minimum for a valid TLS 1.3 ClientHello + // (record header + handshake header + random + session_id + cipher_suites + // + compression + at least one extension with SNI). 
The previous value of + // 512 was implicitly coupled to TLS_REQUEST_LENGTH=517 from the official + // Telegram MTProxy reference server, leaving only a 5-byte margin and + // incorrectly rejecting compact but spec-compliant ClientHellos from + // third-party clients or future Telegram versions. + if !tls_clienthello_len_in_bounds(tls_len) { + debug!(peer = %peer, tls_len = tls_len, max_tls_len = MAX_TLS_PLAINTEXT_SIZE, "TLS handshake length out of bounds"); self.stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&self.config).await; let (reader, writer) = self.stream.into_split(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &first_bytes, + first_bytes.to_vec(), peer, local_addr, - &self.config, - &self.beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + self.config.clone(), + self.beobachten.clone(), + )); } let mut handshake = vec![0u8; 5 + tls_len]; handshake[..5].copy_from_slice(&first_bytes); - self.stream.read_exact(&mut handshake[5..]).await?; + let body_read = match read_with_progress(&mut self.stream, &mut handshake[5..]).await { + Ok(n) => n, + Err(e) => { + debug!(peer = %peer, error = %e, tls_len = tls_len, "TLS ClientHello body read failed; engaging masking fallback"); + self.stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&self.config).await; + let (reader, writer) = self.stream.into_split(); + return Ok(masking_outcome( + reader, + writer, + handshake[..5].to_vec(), + peer, + local_addr, + self.config.clone(), + self.beobachten.clone(), + )); + } + }; + + if body_read < tls_len { + debug!(peer = %peer, got = body_read, expected = tls_len, "Truncated in-range TLS ClientHello; engaging masking fallback"); + self.stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&self.config).await; + let initial_len = 5 + body_read; + let (reader, writer) = self.stream.into_split(); + return Ok(masking_outcome( + reader, + writer, + handshake[..initial_len].to_vec(), + peer, + 
local_addr, + self.config.clone(), + self.beobachten.clone(), + )); + } let config = self.config.clone(); let replay_checker = self.replay_checker.clone(); @@ -557,17 +874,15 @@ impl RunningClientHandler { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &handshake, + handshake.clone(), peer, local_addr, - &config, - &self.beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + self.beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -591,13 +906,33 @@ impl RunningClientHandler { .await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { - reader: _, - writer: _, - } => { + HandshakeResult::BadClient { reader, writer } => { + // MTProto failed after TLS ServerHello was already sent. + // Switch fallback relay back to raw transport so the mask + // backend receives valid TLS records (not unwrapped payload). 
+ let (reader, pending_plaintext) = reader.into_inner_with_pending_plaintext(); + let writer = writer.into_inner(); + let pending_record = if pending_plaintext.is_empty() { + Vec::new() + } else { + wrap_tls_application_record(&pending_plaintext) + }; + let reader = + tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader); stats.increment_connects_bad(); - debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(HandshakeOutcome::Handled); + debug!( + peer = %peer, + "Authenticated TLS session failed MTProto validation; engaging masking fallback" + ); + return Ok(masking_outcome( + reader, + writer, + Vec::new(), + peer, + local_addr, + config.clone(), + self.beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -621,25 +956,27 @@ impl RunningClientHandler { ))) } - async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { + async fn handle_direct_client( + mut self, + first_bytes: [u8; 5], + local_addr: SocketAddr, + ) -> Result { let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); if !self.config.general.modes.classic && !self.config.general.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); self.stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&self.config).await; let (reader, writer) = self.stream.into_split(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &first_bytes, + first_bytes.to_vec(), peer, local_addr, - &self.config, - &self.beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + self.config.clone(), + self.beobachten.clone(), + )); } let mut handshake = [0u8; HANDSHAKE_LEN]; @@ -668,17 +1005,15 @@ impl RunningClientHandler { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &handshake, + handshake.to_vec(), peer, local_addr, - 
&config, - &self.beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + self.beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -727,21 +1062,21 @@ impl RunningClientHandler { { let user = success.user.clone(); - if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await { - warn!(user = %user, error = %e, "User limit exceeded"); - return Err(e); - } - - let registration = register_session(&user, success.dc_idx); - if registration.replaced_existing { - stats.increment_reconnect_evict_total(); - warn!( - user = %user, - dc = success.dc_idx, - "Reconnect detected: replacing active session for user+dc" - ); - } - let session_lease = registration.lease; + let user_limit_reservation = match Self::acquire_user_connection_reservation_static( + &user, + &config, + stats.clone(), + peer_addr, + ip_tracker, + ) + .await + { + Ok(reservation) => reservation, + Err(e) => { + warn!(user = %user, error = %e, "User admission check failed"); + return Err(e); + } + }; let route_snapshot = route_runtime.snapshot(); let session_id = rng.u64(); @@ -754,7 +1089,7 @@ impl RunningClientHandler { client_writer, success, pool.clone(), - stats, + stats.clone(), config, buffer_pool, local_addr, @@ -762,7 +1097,6 @@ impl RunningClientHandler { route_runtime.subscribe(), route_snapshot, session_id, - session_lease.clone(), ) .await } else { @@ -772,14 +1106,13 @@ impl RunningClientHandler { client_writer, success, upstream_manager, - stats, + stats.clone(), config, buffer_pool, rng, route_runtime.subscribe(), route_snapshot, session_id, - session_lease.clone(), ) .await } @@ -790,25 +1123,82 @@ impl RunningClientHandler { client_writer, success, upstream_manager, - stats, + stats.clone(), config, buffer_pool, rng, route_runtime.subscribe(), route_snapshot, session_id, - session_lease.clone(), ) .await }; - - ip_tracker.remove_ip(&user, peer_addr.ip()).await; + 
user_limit_reservation.release().await; relay_result } + async fn acquire_user_connection_reservation_static( + user: &str, + config: &ProxyConfig, + stats: Arc, + peer_addr: SocketAddr, + ip_tracker: Arc, + ) -> Result { + if let Some(expiration) = config.access.user_expirations.get(user) + && chrono::Utc::now() > *expiration + { + return Err(ProxyError::UserExpired { + user: user.to_string(), + }); + } + + if let Some(quota) = config.access.user_data_quota.get(user) + && stats.get_user_total_octets(user) >= *quota + { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + + let limit = config + .access + .user_max_tcp_conns + .get(user) + .map(|v| *v as u64); + if !stats.try_acquire_user_curr_connects(user, limit) { + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => {} + Err(reason) => { + stats.decrement_user_curr_connects(user); + warn!( + user = %user, + ip = %peer_addr.ip(), + reason = %reason, + "IP limit exceeded" + ); + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + } + + Ok(UserConnectionReservation::new( + stats, + ip_tracker, + user.to_string(), + peer_addr.ip(), + )) + } + + #[cfg(test)] async fn check_user_limits_static( - user: &str, - config: &ProxyConfig, + user: &str, + config: &ProxyConfig, stats: &Stats, peer_addr: SocketAddr, ip_tracker: &UserIpTracker, @@ -821,9 +1211,32 @@ impl RunningClientHandler { }); } - let ip_reserved = match ip_tracker.check_and_add(user, peer_addr.ip()).await { - Ok(()) => true, + if let Some(quota) = config.access.user_data_quota.get(user) + && stats.get_user_total_octets(user) >= *quota + { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + + let limit = config + .access + .user_max_tcp_conns + .get(user) + .map(|v| *v as u64); + if !stats.try_acquire_user_curr_connects(user, limit) { + return 
Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => { + ip_tracker.remove_ip(user, peer_addr.ip()).await; + stats.decrement_user_curr_connects(user); + } Err(reason) => { + stats.decrement_user_curr_connects(user); warn!( user = %user, ip = %peer_addr.ip(), @@ -834,33 +1247,84 @@ impl RunningClientHandler { user: user.to_string(), }); } - }; - // IP limit check - - if let Some(limit) = config.access.user_max_tcp_conns.get(user) - && stats.get_user_curr_connects(user) >= *limit as u64 - { - if ip_reserved { - ip_tracker.remove_ip(user, peer_addr.ip()).await; - stats.increment_ip_reservation_rollback_tcp_limit_total(); - } - return Err(ProxyError::ConnectionLimitExceeded { - user: user.to_string(), - }); - } - - if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota - { - if ip_reserved { - ip_tracker.remove_ip(user, peer_addr.ip()).await; - stats.increment_ip_reservation_rollback_quota_limit_total(); - } - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); } Ok(()) } } + +#[cfg(test)] +#[path = "tests/client_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/client_adversarial_tests.rs"] +mod adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_tls_mtproto_fallback_security_tests.rs"] +mod tls_mtproto_fallback_security_tests; + +#[cfg(test)] +#[path = "tests/client_tls_clienthello_size_security_tests.rs"] +mod tls_clienthello_size_security_tests; + +#[cfg(test)] +#[path = "tests/client_tls_clienthello_truncation_adversarial_tests.rs"] +mod tls_clienthello_truncation_adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_timing_profile_adversarial_tests.rs"] +mod timing_profile_adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_masking_budget_security_tests.rs"] +mod masking_budget_security_tests; + +#[cfg(test)] +#[path = 
"tests/client_masking_redteam_expected_fail_tests.rs"] +mod masking_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/client_masking_hard_adversarial_tests.rs"] +mod masking_hard_adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_masking_stress_adversarial_tests.rs"] +mod masking_stress_adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_masking_blackhat_campaign_tests.rs"] +mod masking_blackhat_campaign_tests; + +#[cfg(test)] +#[path = "tests/client_masking_diagnostics_security_tests.rs"] +mod masking_diagnostics_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_shape_hardening_security_tests.rs"] +mod masking_shape_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_shape_hardening_adversarial_tests.rs"] +mod masking_shape_hardening_adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs"] +mod masking_shape_hardening_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs"] +mod masking_shape_classifier_fuzz_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"] +mod masking_probe_evasion_blackhat_tests; + +#[cfg(test)] +#[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] +mod beobachten_ttl_bounds_security_tests; + +#[cfg(test)] +#[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"] +mod tls_record_wrap_hardening_security_tests; diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 108949c..5d9c450 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -1,7 +1,11 @@ +use std::collections::HashSet; +use std::ffi::OsString; use std::fs::OpenOptions; use std::io::Write; use std::net::SocketAddr; +use std::path::{Component, Path, PathBuf}; use std::sync::Arc; +use std::sync::{Mutex, OnceLock}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, 
ReadHalf, WriteHalf, split}; use tokio::sync::watch; @@ -17,11 +21,209 @@ use crate::proxy::route_mode::{ ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; -use crate::proxy::adaptive_buffers; -use crate::proxy::session_eviction::SessionLease; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; +#[cfg(unix)] +use nix::fcntl::{Flock, FlockArg, OFlag, openat}; +#[cfg(unix)] +use nix::sys::stat::Mode; + +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; + +const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; +static LOGGED_UNKNOWN_DCS: OnceLock>> = OnceLock::new(); +const MAX_SCOPE_HINT_LEN: usize = 64; + +fn validated_scope_hint(user: &str) -> Option<&str> { + let scope = user.strip_prefix("scope_")?; + if scope.is_empty() || scope.len() > MAX_SCOPE_HINT_LEN { + return None; + } + if scope + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-') + { + Some(scope) + } else { + None + } +} + +#[derive(Clone)] +struct SanitizedUnknownDcLogPath { + resolved_path: PathBuf, + allowed_parent: PathBuf, + file_name: OsString, +} + +// In tests, this function shares global mutable state. Callers that also use +// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions +// deterministic under parallel execution. +fn should_log_unknown_dc(dc_idx: i16) -> bool { + let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new())); + should_log_unknown_dc_with_set(set, dc_idx) +} + +fn should_log_unknown_dc_with_set(set: &Mutex>, dc_idx: i16) -> bool { + match set.lock() { + Ok(mut guard) => { + if guard.contains(&dc_idx) { + return false; + } + if guard.len() >= UNKNOWN_DC_LOG_DISTINCT_LIMIT { + return false; + } + guard.insert(dc_idx) + } + // Fail closed on poisoned state to avoid unbounded blocking log writes. 
+ Err(_) => false, + } +} + +fn sanitize_unknown_dc_log_path(path: &str) -> Option { + let candidate = Path::new(path); + if candidate.as_os_str().is_empty() { + return None; + } + if candidate + .components() + .any(|component| matches!(component, Component::ParentDir)) + { + return None; + } + + let cwd = std::env::current_dir().ok()?; + let file_name = candidate.file_name()?; + let parent = candidate.parent().unwrap_or_else(|| Path::new(".")); + let parent_path = if parent.is_absolute() { + parent.to_path_buf() + } else { + cwd.join(parent) + }; + let canonical_parent = parent_path.canonicalize().ok()?; + if !canonical_parent.is_dir() { + return None; + } + + Some(SanitizedUnknownDcLogPath { + resolved_path: canonical_parent.join(file_name), + allowed_parent: canonical_parent, + file_name: file_name.to_os_string(), + }) +} + +fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool { + let Some(parent) = path.resolved_path.parent() else { + return false; + }; + let Ok(current_parent) = parent.canonicalize() else { + return false; + }; + if current_parent != path.allowed_parent { + return false; + } + + if let Ok(canonical_target) = path.resolved_path.canonicalize() { + let Some(target_parent) = canonical_target.parent() else { + return false; + }; + let Some(target_name) = canonical_target.file_name() else { + return false; + }; + if target_parent != path.allowed_parent || target_name != path.file_name { + return false; + } + } + + true +} + +#[cfg(test)] +fn open_unknown_dc_log_append(path: &Path) -> std::io::Result { + #[cfg(unix)] + { + OpenOptions::new() + .create(true) + .append(true) + .custom_flags(libc::O_NOFOLLOW) + .open(path) + } + #[cfg(not(unix))] + { + let _ = path; + Err(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "unknown_dc_file_log_enabled requires unix O_NOFOLLOW support", + )) + } +} + +fn open_unknown_dc_log_append_anchored( + path: &SanitizedUnknownDcLogPath, +) -> std::io::Result { + #[cfg(unix)] + 
{ + let parent = OpenOptions::new() + .read(true) + .custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC) + .open(&path.allowed_parent)?; + + let oflags = OFlag::O_CREAT + | OFlag::O_APPEND + | OFlag::O_WRONLY + | OFlag::O_NOFOLLOW + | OFlag::O_CLOEXEC; + let mode = Mode::from_bits_truncate(0o600); + let path_component = Path::new(path.file_name.as_os_str()); + let fd = openat(&parent, path_component, oflags, mode) + .map_err(|err| std::io::Error::from_raw_os_error(err as i32))?; + let file = std::fs::File::from(fd); + Ok(file) + } + #[cfg(not(unix))] + { + let _ = path; + Err(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "unknown_dc_file_log_enabled requires unix O_NOFOLLOW support", + )) + } +} + +fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> { + #[cfg(unix)] + { + let cloned = file.try_clone()?; + let mut locked = Flock::lock(cloned, FlockArg::LockExclusive) + .map_err(|(_, err)| std::io::Error::from_raw_os_error(err as i32))?; + let write_result = writeln!(&mut *locked, "dc_idx={dc_idx}"); + let _ = locked + .unlock() + .map_err(|(_, err)| std::io::Error::from_raw_os_error(err as i32))?; + write_result + } + #[cfg(not(unix))] + { + writeln!(file, "dc_idx={dc_idx}") + } +} + +#[cfg(test)] +fn clear_unknown_dc_log_cache_for_testing() { + if let Some(set) = LOGGED_UNKNOWN_DCS.get() + && let Ok(mut guard) = set.lock() + { + guard.clear(); + } +} + +#[cfg(test)] +fn unknown_dc_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} pub(crate) async fn handle_via_direct( client_reader: CryptoReader, @@ -35,7 +237,6 @@ pub(crate) async fn handle_via_direct( mut route_rx: watch::Receiver, route_snapshot: RouteCutoverState, session_id: u64, - session_lease: SessionLease, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, @@ -54,12 +255,15 @@ where "Connecting to Telegram DC" ); + let scope_hint = 
validated_scope_hint(user); + if user.starts_with("scope_") && scope_hint.is_none() { + warn!( + user = %user, + "Ignoring invalid scope hint and falling back to default upstream selection" + ); + } let tg_stream = upstream_manager - .connect( - dc_addr, - Some(success.dc_idx), - user.strip_prefix("scope_").filter(|s| !s.is_empty()), - ) + .connect(dc_addr, Some(success.dc_idx), scope_hint) .await?; debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); @@ -70,29 +274,19 @@ where debug!(peer = %success.peer, "TG handshake complete, starting relay"); stats.increment_user_connects(user); - stats.increment_user_curr_connects(user); - stats.increment_current_connections_direct(); - - let seed_tier = adaptive_buffers::seed_tier_for_user(user); - let (c2s_copy_buf, s2c_copy_buf) = adaptive_buffers::direct_copy_buffers_for_tier( - seed_tier, - config.general.direct_relay_copy_buf_c2s_bytes, - config.general.direct_relay_copy_buf_s2c_bytes, - ); + let _direct_connection_lease = stats.acquire_direct_connection_lease(); let relay_result = relay_bidirectional( client_reader, client_writer, tg_reader, tg_writer, - c2s_copy_buf, - s2c_copy_buf, + config.general.direct_relay_copy_buf_c2s_bytes, + config.general.direct_relay_copy_buf_s2c_bytes, user, - success.dc_idx, Arc::clone(&stats), + config.access.user_data_quota.get(user).copied(), buffer_pool, - session_lease, - seed_tier, ); tokio::pin!(relay_result); let relay_result = loop { @@ -122,9 +316,6 @@ where } }; - stats.decrement_current_connections_direct(); - stats.decrement_user_curr_connects(user); - match &relay_result { Ok(()) => debug!(user = %user, "Direct relay completed"), Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"), @@ -181,12 +372,19 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { && let Some(path) = &config.general.unknown_dc_log_path && let Ok(handle) = tokio::runtime::Handle::try_current() { - let path = path.clone(); - 
handle.spawn_blocking(move || { - if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { - let _ = writeln!(file, "dc_idx={dc_idx}"); + if let Some(path) = sanitize_unknown_dc_log_path(path) { + if should_log_unknown_dc(dc_idx) { + handle.spawn_blocking(move || { + if unknown_dc_log_path_is_still_safe(&path) + && let Ok(mut file) = open_unknown_dc_log_append_anchored(&path) + { + let _ = append_unknown_dc_line(&mut file, dc_idx); + } + }); } - }); + } else { + warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path"); + } } } @@ -194,7 +392,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { default_dc - 1 } else { - 1 + 0 }; info!( @@ -222,8 +420,6 @@ where let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( success.proto_tag, success.dc_idx, - &success.dec_key, - success.dec_iv, &success.enc_key, success.enc_iv, rng, @@ -249,3 +445,19 @@ where CryptoWriter::new(write_half, tg_encryptor, max_pending), )) } + +#[cfg(test)] +#[path = "tests/direct_relay_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/direct_relay_business_logic_tests.rs"] +mod business_logic_tests; + +#[cfg(test)] +#[path = "tests/direct_relay_common_mistakes_tests.rs"] +mod common_mistakes_tests; + +#[cfg(test)] +#[path = "tests/direct_relay_subtle_adversarial_tests.rs"] +mod subtle_adversarial_tests; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 296432f..5632977 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -2,22 +2,479 @@ #![allow(dead_code)] +use dashmap::DashMap; +use dashmap::mapref::entry::Entry; +use std::collections::HashSet; +use std::collections::hash_map::RandomState; +use std::hash::{BuildHasher, Hash, Hasher}; use std::net::SocketAddr; +use std::net::{IpAddr, Ipv6Addr}; use std::sync::Arc; -use std::time::Duration; +use std::sync::{Mutex, OnceLock}; +use 
std::time::{Duration, Instant}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tracing::{debug, warn, trace}; -use zeroize::Zeroize; +use tracing::{debug, trace, warn}; +use zeroize::{Zeroize, Zeroizing}; -use crate::crypto::{sha256, AesCtr, SecureRandom}; -use rand::Rng; +use crate::config::ProxyConfig; +use crate::crypto::{AesCtr, SecureRandom, sha256}; +use crate::error::{HandshakeResult, ProxyError}; use crate::protocol::constants::*; use crate::protocol::tls; -use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter}; -use crate::error::{ProxyError, HandshakeResult}; use crate::stats::ReplayChecker; -use crate::config::ProxyConfig; +use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::tls_front::{TlsFrontCache, emulator}; +use rand::RngExt; + +const ACCESS_SECRET_BYTES: usize = 16; +static INVALID_SECRET_WARNED: OnceLock>> = OnceLock::new(); +#[cfg(test)] +const WARNED_SECRET_MAX_ENTRIES: usize = 64; +#[cfg(not(test))] +const WARNED_SECRET_MAX_ENTRIES: usize = 1_024; + +const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60; +#[cfg(test)] +const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256; +#[cfg(not(test))] +const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536; +const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024; +const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4; +const AUTH_PROBE_SATURATION_GRACE_FAILS: u32 = 2; + +#[cfg(test)] +const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1; +#[cfg(not(test))] +const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 25; + +#[cfg(test)] +const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 16; +#[cfg(not(test))] +const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 1_000; + +#[derive(Clone, Copy)] +struct AuthProbeState { + fail_streak: u32, + blocked_until: Instant, + last_seen: Instant, +} + +#[derive(Clone, Copy)] +struct AuthProbeSaturationState { + fail_streak: u32, + blocked_until: Instant, + last_seen: Instant, +} + +static AUTH_PROBE_STATE: OnceLock> = OnceLock::new(); +static AUTH_PROBE_SATURATION_STATE: 
OnceLock>> = + OnceLock::new(); +static AUTH_PROBE_EVICTION_HASHER: OnceLock = OnceLock::new(); + +fn auth_probe_state_map() -> &'static DashMap { + AUTH_PROBE_STATE.get_or_init(DashMap::new) +} + +fn auth_probe_saturation_state() -> &'static Mutex> { + AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None)) +} + +fn auth_probe_saturation_state_lock() +-> std::sync::MutexGuard<'static, Option> { + auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr { + match peer_ip { + IpAddr::V4(ip) => IpAddr::V4(ip), + IpAddr::V6(ip) => { + let [a, b, c, d, _, _, _, _] = ip.segments(); + IpAddr::V6(Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0)) + } + } +} + +fn auth_probe_backoff(fail_streak: u32) -> Duration { + if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS { + return Duration::ZERO; + } + let shift = (fail_streak - AUTH_PROBE_BACKOFF_START_FAILS).min(10); + let multiplier = 1u64.checked_shl(shift).unwrap_or(u64::MAX); + let ms = AUTH_PROBE_BACKOFF_BASE_MS + .saturating_mul(multiplier) + .min(AUTH_PROBE_BACKOFF_MAX_MS); + Duration::from_millis(ms) +} + +fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool { + let retention = Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS); + now.duration_since(state.last_seen) > retention +} + +fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize { + let hasher_state = AUTH_PROBE_EVICTION_HASHER.get_or_init(RandomState::new); + let mut hasher = hasher_state.build_hasher(); + peer_ip.hash(&mut hasher); + now.hash(&mut hasher); + hasher.finish() as usize +} + +fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = auth_probe_state_map(); + let Some(entry) = state.get(&peer_ip) else { + return false; + }; + if auth_probe_state_expired(&entry, now) { + drop(entry); + state.remove(&peer_ip); + return false; + } + now < 
entry.blocked_until +} + +fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = auth_probe_state_map(); + let Some(entry) = state.get(&peer_ip) else { + return false; + }; + if auth_probe_state_expired(&entry, now) { + drop(entry); + state.remove(&peer_ip); + return false; + } + + entry.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS +} + +fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bool { + if !auth_probe_is_throttled(peer_ip, now) { + return false; + } + + if !auth_probe_saturation_is_throttled(now) { + return true; + } + + auth_probe_saturation_grace_exhausted(peer_ip, now) +} + +fn auth_probe_saturation_is_throttled(now: Instant) -> bool { + let mut guard = auth_probe_saturation_state_lock(); + + let Some(state) = guard.as_mut() else { + return false; + }; + + if now.duration_since(state.last_seen) > Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) { + *guard = None; + return false; + } + + if now < state.blocked_until { + return true; + } + + false +} + +fn auth_probe_note_saturation(now: Instant) { + let mut guard = auth_probe_saturation_state_lock(); + + match guard.as_mut() { + Some(state) + if now.duration_since(state.last_seen) + <= Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) => + { + state.fail_streak = state.fail_streak.saturating_add(1); + state.last_seen = now; + state.blocked_until = now + auth_probe_backoff(state.fail_streak); + } + _ => { + let fail_streak = AUTH_PROBE_BACKOFF_START_FAILS; + *guard = Some(AuthProbeSaturationState { + fail_streak, + blocked_until: now + auth_probe_backoff(fail_streak), + last_seen: now, + }); + } + } +} + +fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = auth_probe_state_map(); + auth_probe_record_failure_with_state(state, peer_ip, now); +} + +fn 
auth_probe_record_failure_with_state( + state: &DashMap, + peer_ip: IpAddr, + now: Instant, +) { + let make_new_state = || AuthProbeState { + fail_streak: 1, + blocked_until: now + auth_probe_backoff(1), + last_seen: now, + }; + + let update_existing = |entry: &mut AuthProbeState| { + if auth_probe_state_expired(entry, now) { + *entry = make_new_state(); + } else { + entry.fail_streak = entry.fail_streak.saturating_add(1); + entry.last_seen = now; + entry.blocked_until = now + auth_probe_backoff(entry.fail_streak); + } + }; + + match state.entry(peer_ip) { + Entry::Occupied(mut entry) => { + update_existing(entry.get_mut()); + return; + } + Entry::Vacant(_) => {} + } + + if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + let mut rounds = 0usize; + while state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + rounds += 1; + if rounds > 8 { + auth_probe_note_saturation(now); + let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; + for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + } + + let Some((evict_key, _, _)) = eviction_candidate else { + return; + }; + state.remove(&evict_key); + break; + } + + let mut stale_keys = Vec::new(); + let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; + let state_len = state.len(); + let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); + let start_offset = if state_len == 0 { + 0 + } else { + auth_probe_eviction_offset(peer_ip, now) % state_len + }; + + let mut scanned = 0usize; + for entry in state.iter().skip(start_offset) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + 
match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + scanned += 1; + if scanned >= scan_limit { + break; + } + } + + if scanned < scan_limit { + for entry in state.iter().take(scan_limit - scanned) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + } + } + + for stale_key in stale_keys { + state.remove(&stale_key); + } + + if state.len() < AUTH_PROBE_TRACK_MAX_ENTRIES { + break; + } + + let Some((evict_key, _, _)) = eviction_candidate else { + auth_probe_note_saturation(now); + return; + }; + state.remove(&evict_key); + auth_probe_note_saturation(now); + } + } + + match state.entry(peer_ip) { + Entry::Occupied(mut entry) => { + update_existing(entry.get_mut()); + } + Entry::Vacant(entry) => { + entry.insert(make_new_state()); + } + } +} + +fn auth_probe_record_success(peer_ip: IpAddr) { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = auth_probe_state_map(); + state.remove(&peer_ip); +} + +#[cfg(test)] +fn clear_auth_probe_state_for_testing() { + if let Some(state) = AUTH_PROBE_STATE.get() { + state.clear(); + } + if AUTH_PROBE_SATURATION_STATE.get().is_some() { + let mut guard = auth_probe_saturation_state_lock(); + *guard = None; + } +} + +#[cfg(test)] +fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = 
AUTH_PROBE_STATE.get()?; + state.get(&peer_ip).map(|entry| entry.fail_streak) +} + +#[cfg(test)] +fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool { + auth_probe_is_throttled(peer_ip, Instant::now()) +} + +#[cfg(test)] +fn auth_probe_saturation_is_throttled_for_testing() -> bool { + auth_probe_saturation_is_throttled(Instant::now()) +} + +#[cfg(test)] +fn auth_probe_saturation_is_throttled_at_for_testing(now: Instant) -> bool { + auth_probe_saturation_is_throttled(now) +} + +#[cfg(test)] +fn auth_probe_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +fn clear_warned_secrets_for_testing() { + if let Some(warned) = INVALID_SECRET_WARNED.get() + && let Ok(mut guard) = warned.lock() + { + guard.clear(); + } +} + +#[cfg(test)] +fn warned_secrets_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option) { + let key = (name.to_string(), reason.to_string()); + let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new())); + let should_warn = match warned.lock() { + Ok(mut guard) => { + if !guard.contains(&key) && guard.len() >= WARNED_SECRET_MAX_ENTRIES { + false + } else { + guard.insert(key) + } + } + Err(_) => true, + }; + + if !should_warn { + return; + } + + match got { + Some(actual) => { + warn!( + user = %name, + expected = expected, + got = actual, + "Skipping user: access secret has unexpected length" + ); + } + None => { + warn!( + user = %name, + "Skipping user: access secret is not valid hex" + ); + } + } +} + +fn decode_user_secret(name: &str, secret_hex: &str) -> Option> { + match hex::decode(secret_hex) { + Ok(bytes) if bytes.len() == ACCESS_SECRET_BYTES => Some(bytes), + Ok(bytes) => { + warn_invalid_secret_once( + name, + "invalid_length", + ACCESS_SECRET_BYTES, + 
Some(bytes.len()), + ); + None + } + Err(_) => { + warn_invalid_secret_once(name, "invalid_hex", ACCESS_SECRET_BYTES, None); + None + } + } +} + +// Decide whether a client-supplied proto tag is allowed given the configured +// proxy modes and the transport that carried the handshake. +// +// A common mistake is to treat `modes.tls` and `modes.secure` as interchangeable +// even though they correspond to different transport profiles: `modes.tls` is +// for the TLS-fronted (EE-TLS) path, while `modes.secure` is for direct MTProto +// over TCP (DD). Enforcing this separation prevents an attacker from using a +// TLS-capable client to bypass the operator intent for the direct MTProto mode, +// and vice versa. +fn mode_enabled_for_proto(config: &ProxyConfig, proto_tag: ProtoTag, is_tls: bool) -> bool { + match proto_tag { + ProtoTag::Secure => { + if is_tls { + config.general.modes.tls + } else { + config.general.modes.secure + } + } + ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, + } +} fn decode_user_secrets( config: &ProxyConfig, @@ -27,7 +484,7 @@ fn decode_user_secrets( if let Some(preferred) = preferred_user && let Some(secret_hex) = config.access.users.get(preferred) - && let Ok(bytes) = hex::decode(secret_hex) + && let Some(bytes) = decode_user_secret(preferred, secret_hex) { secrets.push((preferred.to_string(), bytes)); } @@ -36,7 +493,7 @@ fn decode_user_secrets( if preferred_user.is_some_and(|preferred| preferred == name.as_str()) { continue; } - if let Ok(bytes) = hex::decode(secret_hex) { + if let Some(bytes) = decode_user_secret(name, secret_hex) { secrets.push((name.clone(), bytes)); } } @@ -44,11 +501,29 @@ fn decode_user_secrets( secrets } +async fn maybe_apply_server_hello_delay(config: &ProxyConfig) { + if config.censorship.server_hello_delay_max_ms == 0 { + return; + } + + let min = config.censorship.server_hello_delay_min_ms; + let max = config.censorship.server_hello_delay_max_ms.max(min); + let delay_ms = if max == 
min { + max + } else { + rand::rng().random_range(min..=max) + }; + + if delay_ms > 0 { + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + } +} + /// Result of successful handshake /// /// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is /// zeroized on drop. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct HandshakeSuccess { /// Authenticated user name pub user: String, @@ -59,7 +534,7 @@ pub struct HandshakeSuccess { /// Decryption key and IV (for reading from client) pub dec_key: [u8; 32], pub dec_iv: u128, - /// Encryption key and IV (for writing to client) + /// Encryption key and IV (for writing to client) pub enc_key: [u8; 32], pub enc_iv: u128, /// Client address @@ -94,30 +569,35 @@ where { debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); + let throttle_now = Instant::now(); + if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) { + maybe_apply_server_hello_delay(config).await; + debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle"); + return HandshakeResult::BadClient { reader, writer }; + } + if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "TLS handshake too short"); return HandshakeResult::BadClient { reader, writer }; } - let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; - let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; + let client_sni = tls::extract_sni_from_client_hello(handshake); + let secrets = decode_user_secrets(config, client_sni.as_deref()); - if replay_checker.check_and_add_tls_digest(digest_half) { - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } - - let secrets = decode_user_secrets(config, None); - - let validation = match tls::validate_tls_handshake( + let validation = 
match tls::validate_tls_handshake_with_replay_window( handshake, &secrets, config.access.ignore_time_skew, + config.access.replay_window_secs, ) { Some(v) => v, None => { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; debug!( - peer = %peer, + peer = %peer, ignore_time_skew = config.access.ignore_time_skew, "TLS handshake validation failed - no matching user or time skew" ); @@ -125,16 +605,29 @@ where } }; + // Replay tracking is applied only after successful authentication to avoid + // letting unauthenticated probes evict valid entries from the replay cache. + let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_and_add_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, - None => return HandshakeResult::BadClient { reader, writer }, + None => { + maybe_apply_server_hello_delay(config).await; + return HandshakeResult::BadClient { reader, writer }; + } }; let cached = if config.censorship.tls_emulation { if let Some(cache) = tls_cache.as_ref() { - let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) { - if cache.contains_domain(&sni).await { - sni + let selected_domain = if let Some(sni) = client_sni.as_ref() { + if cache.contains_domain(sni).await { + sni.clone() } else { config.censorship.tls_domain.clone() } @@ -166,6 +659,10 @@ where Some(b"h2".to_vec()) } else if alpn_list.iter().any(|p| p == b"http/1.1") { Some(b"http/1.1".to_vec()) + } else if !alpn_list.is_empty() { + maybe_apply_server_hello_delay(config).await; + debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback"); + return 
HandshakeResult::BadClient { reader, writer }; } else { None } @@ -196,19 +693,9 @@ where ) }; - // Optional anti-fingerprint delay before sending ServerHello. - if config.censorship.server_hello_delay_max_ms > 0 { - let min = config.censorship.server_hello_delay_min_ms; - let max = config.censorship.server_hello_delay_max_ms.max(min); - let delay_ms = if max == min { - max - } else { - rand::rng().random_range(min..=max) - }; - if delay_ms > 0 { - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - } - } + // Apply the same optional delay budget used by reject paths to reduce + // distinguishability between success and fail-closed handshakes. + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); @@ -228,6 +715,8 @@ where "TLS handshake successful" ); + auth_probe_record_success(peer.ip()); + HandshakeResult::Success(( FakeTlsReader::new(reader), FakeTlsWriter::new(writer), @@ -250,75 +739,93 @@ where R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send, { - trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); + let handshake_fingerprint = { + let digest = sha256(&handshake[..8]); + hex::encode(&digest[..4]) + }; + trace!( + peer = %peer, + handshake_fingerprint = %handshake_fingerprint, + "MTProto handshake prefix" + ); - let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; - - if replay_checker.check_and_add_handshake(dec_prekey_iv) { - warn!(peer = %peer, "MTProto replay attack detected"); + let throttle_now = Instant::now(); + if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) { + maybe_apply_server_hello_delay(config).await; + debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle"); return HandshakeResult::BadClient { reader, writer }; } + let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + let enc_prekey_iv: Vec = 
dec_prekey_iv.iter().rev().copied().collect(); let decoded_users = decode_user_secrets(config, preferred_user); for (user, secret) in decoded_users { - let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; - let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(&secret); let dec_key = sha256(&dec_key_input); - let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); let mut decryptor = AesCtr::new(&dec_key, dec_iv); let decrypted = decryptor.decrypt(handshake); - let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] - .try_into() - .unwrap(); + let tag_bytes: [u8; 4] = [ + decrypted[PROTO_TAG_POS], + decrypted[PROTO_TAG_POS + 1], + decrypted[PROTO_TAG_POS + 2], + decrypted[PROTO_TAG_POS + 3], + ]; let proto_tag = match ProtoTag::from_bytes(tag_bytes) { Some(tag) => tag, None => continue, }; - let mode_ok = match proto_tag { - ProtoTag::Secure => { - if is_tls { - config.general.modes.tls || config.general.modes.secure - } else { - config.general.modes.secure || config.general.modes.tls - } - } - ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, - }; + let mode_ok = mode_enabled_for_proto(config, proto_tag, is_tls); if !mode_ok { debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled"); continue; } - let dc_idx = i16::from_le_bytes( - decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() - ); + let dc_idx = i16::from_le_bytes([decrypted[DC_IDX_POS], decrypted[DC_IDX_POS + 1]]); let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; - let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); 
+ let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(&secret); let enc_key = sha256(&enc_key_input); - let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); + let mut enc_iv_arr = [0u8; IV_LEN]; + enc_iv_arr.copy_from_slice(enc_iv_bytes); + let enc_iv = u128::from_be_bytes(enc_iv_arr); let encryptor = AesCtr::new(&enc_key, enc_iv); + // Apply replay tracking only after successful authentication. + // + // This ordering prevents an attacker from producing invalid handshakes that + // still collide with a valid handshake's replay slot and thus evict a valid + // entry from the cache. We accept the cost of performing the full + // authentication check first to avoid poisoning the replay cache. + if replay_checker.check_and_add_handshake(dec_prekey_iv) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, user = %user, "MTProto replay attack detected"); + return HandshakeResult::BadClient { reader, writer }; + } + let success = HandshakeSuccess { user: user.clone(), dc_idx, @@ -340,6 +847,8 @@ where "MTProto handshake successful" ); + auth_probe_record_success(peer.ip()); + let max_pending = config.general.crypto_pending_buffer; return HandshakeResult::Success(( CryptoReader::new(reader, decryptor), @@ -348,16 +857,16 @@ where )); } + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "MTProto handshake: no matching user found"); HandshakeResult::BadClient { reader, writer } } /// Generate nonce for Telegram connection pub fn generate_tg_nonce( - proto_tag: ProtoTag, + proto_tag: ProtoTag, dc_idx: i16, - _client_dec_key: &[u8; 32], - _client_dec_iv: u128, client_enc_key: &[u8; 32], client_enc_iv: u128, rng: &SecureRandom, @@ -365,22 +874,30 @@ pub fn generate_tg_nonce( ) -> ([u8; HANDSHAKE_LEN], [u8; 
32], u128, [u8; 32], u128) { loop { let bytes = rng.bytes(HANDSHAKE_LEN); - let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); + let Ok(mut nonce): Result<[u8; HANDSHAKE_LEN], _> = bytes.try_into() else { + continue; + }; - if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; } + if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { + continue; + } - let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); - if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; } + let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; + if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { + continue; + } - let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); - if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; } + let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]]; + if RESERVED_NONCE_CONTINUES.contains(&continue_four) { + continue; + } nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); // CRITICAL: write dc_idx so upstream DC knows where to route nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); if fast_mode { - let mut key_iv = Vec::with_capacity(KEY_LEN + IV_LEN); + let mut key_iv = Zeroizing::new(Vec::with_capacity(KEY_LEN + IV_LEN)); key_iv.extend_from_slice(client_enc_key); key_iv.extend_from_slice(&client_enc_iv.to_be_bytes()); key_iv.reverse(); // Python/C behavior: reversed enc_key+enc_iv in nonce @@ -388,13 +905,19 @@ pub fn generate_tg_nonce( } let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; - let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); + let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::>()); - let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); - let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + let mut tg_enc_key = [0u8; 32]; + tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut tg_enc_iv_arr = [0u8; IV_LEN]; + 
tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let tg_enc_iv = u128::from_be_bytes(tg_enc_iv_arr); - let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); - let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); + let mut tg_dec_key = [0u8; 32]; + tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut tg_dec_iv_arr = [0u8; IV_LEN]; + tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let tg_dec_iv = u128::from_be_bytes(tg_dec_iv_arr); return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv); } @@ -403,21 +926,29 @@ pub fn generate_tg_nonce( /// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, AesCtr, AesCtr) { let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; - let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); + let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::>()); - let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); - let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + let mut enc_key = [0u8; 32]; + enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut enc_iv_arr = [0u8; IV_LEN]; + enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let enc_iv = u128::from_be_bytes(enc_iv_arr); - let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); - let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); + let mut dec_key = [0u8; 32]; + dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let dec_iv = u128::from_be_bytes(dec_iv_arr); let mut encryptor = AesCtr::new(&enc_key, enc_iv); - let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4 + let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4 let mut result = nonce[..PROTO_TAG_POS].to_vec(); 
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); let decryptor = AesCtr::new(&dec_key, dec_iv); + enc_key.zeroize(); + dec_key.zeroize(); (result, encryptor, decryptor) } @@ -429,80 +960,31 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { } #[cfg(test)] -mod tests { - use super::*; +#[path = "tests/handshake_security_tests.rs"] +mod security_tests; - #[test] - fn test_generate_tg_nonce() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; - let client_enc_key = [0x24u8; 32]; - let client_enc_iv = 54321u128; +#[cfg(test)] +#[path = "tests/handshake_adversarial_tests.rs"] +mod adversarial_tests; - let rng = SecureRandom::new(); - let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = - generate_tg_nonce( - ProtoTag::Secure, - 2, - &client_dec_key, - client_dec_iv, - &client_enc_key, - client_enc_iv, - &rng, - false, - ); +#[cfg(test)] +#[path = "tests/handshake_fuzz_security_tests.rs"] +mod fuzz_security_tests; - assert_eq!(nonce.len(), HANDSHAKE_LEN); +#[cfg(test)] +#[path = "tests/handshake_saturation_poison_security_tests.rs"] +mod saturation_poison_security_tests; - let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); - assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); - } +#[cfg(test)] +#[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"] +mod auth_probe_hardening_adversarial_tests; - #[test] - fn test_encrypt_tg_nonce() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; - let client_enc_key = [0x24u8; 32]; - let client_enc_iv = 54321u128; +/// Compile-time guard: HandshakeSuccess holds cryptographic key material and +/// must never be Copy. A Copy impl would allow silent key duplication, +/// undermining the zeroize-on-drop guarantee. 
+mod compile_time_security_checks { + use super::HandshakeSuccess; + use static_assertions::assert_not_impl_all; - let rng = SecureRandom::new(); - let (nonce, _, _, _, _) = - generate_tg_nonce( - ProtoTag::Secure, - 2, - &client_dec_key, - client_dec_iv, - &client_enc_key, - client_enc_iv, - &rng, - false, - ); - - let encrypted = encrypt_tg_nonce(&nonce); - - assert_eq!(encrypted.len(), HANDSHAKE_LEN); - assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); - assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); - } - - #[test] - fn test_handshake_success_zeroize_on_drop() { - let success = HandshakeSuccess { - user: "test".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Secure, - dec_key: [0xAA; 32], - dec_iv: 0xBBBBBBBB, - enc_key: [0xCC; 32], - enc_iv: 0xDDDDDDDD, - peer: "127.0.0.1:1234".parse().unwrap(), - is_tls: true, - }; - - assert_eq!(success.dec_key, [0xAA; 32]); - assert_eq!(success.enc_key, [0xCC; 32]); - - drop(success); - // Drop impl zeroizes key material without panic - } + assert_not_impl_all!(HandshakeSuccess: Copy, Clone); } diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 318071b..3639db1 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -1,32 +1,231 @@ //! 
Masking - forward unrecognized traffic to mask host -use std::str; -use std::net::SocketAddr; -use std::time::Duration; -use tokio::net::TcpStream; -#[cfg(unix)] -use tokio::net::UnixStream; -use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; -use tokio::time::timeout; -use tracing::debug; use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; +use rand::{Rng, RngExt}; +use std::net::SocketAddr; +use std::str; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; +use tokio::time::{Instant, timeout}; +use tracing::debug; +#[cfg(not(test))] const MASK_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const MASK_TIMEOUT: Duration = Duration::from_millis(50); /// Maximum duration for the entire masking relay. /// Limits resource consumption from slow-loris attacks and port scanners. 
+#[cfg(not(test))] const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60); +#[cfg(test)] +const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200); +#[cfg(not(test))] +const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100); const MASK_BUFFER_SIZE: usize = 8192; +struct CopyOutcome { + total: usize, + ended_by_eof: bool, +} + +async fn copy_with_idle_timeout(reader: &mut R, writer: &mut W) -> CopyOutcome +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut total = 0usize; + let mut ended_by_eof = false; + loop { + let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await; + let n = match read_res { + Ok(Ok(n)) => n, + Ok(Err(_)) | Err(_) => break, + }; + if n == 0 { + ended_by_eof = true; + break; + } + total = total.saturating_add(n); + + let write_res = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.write_all(&buf[..n])).await; + match write_res { + Ok(Ok(())) => {} + Ok(Err(_)) | Err(_) => break, + } + } + CopyOutcome { + total, + ended_by_eof, + } +} + +fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize { + if total == 0 || floor == 0 || cap < floor { + return total; + } + + if total >= cap { + return total; + } + + let mut bucket = floor; + while bucket < total { + match bucket.checked_mul(2) { + Some(next) => bucket = next, + None => return total, + } + if bucket > cap { + return cap; + } + } + bucket +} + +async fn maybe_write_shape_padding( + mask_write: &mut W, + total_sent: usize, + enabled: bool, + floor: usize, + cap: usize, + above_cap_blur: bool, + above_cap_blur_max_bytes: usize, + aggressive_mode: bool, +) where + W: AsyncWrite + Unpin, +{ + if !enabled { + return; + } + + let target_total = if total_sent >= cap && above_cap_blur && above_cap_blur_max_bytes > 0 { + let mut rng = rand::rng(); + let extra = if aggressive_mode { + 
rng.random_range(1..=above_cap_blur_max_bytes) + } else { + rng.random_range(0..=above_cap_blur_max_bytes) + }; + total_sent.saturating_add(extra) + } else { + next_mask_shape_bucket(total_sent, floor, cap) + }; + + if target_total <= total_sent { + return; + } + + let mut remaining = target_total - total_sent; + let mut pad_chunk = [0u8; 1024]; + let deadline = Instant::now() + MASK_TIMEOUT; + + while remaining > 0 { + let now = Instant::now(); + if now >= deadline { + return; + } + + let write_len = remaining.min(pad_chunk.len()); + { + let mut rng = rand::rng(); + rng.fill_bytes(&mut pad_chunk[..write_len]); + } + let write_budget = deadline.saturating_duration_since(now); + match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await { + Ok(Ok(())) => {} + Ok(Err(_)) | Err(_) => return, + } + remaining -= write_len; + } + + let now = Instant::now(); + if now >= deadline { + return; + } + let flush_budget = deadline.saturating_duration_since(now); + let _ = timeout(flush_budget, mask_write.flush()).await; +} + +async fn write_proxy_header_with_timeout(mask_write: &mut W, header: &[u8]) -> bool +where + W: AsyncWrite + Unpin, +{ + match timeout(MASK_TIMEOUT, mask_write.write_all(header)).await { + Ok(Ok(())) => true, + Ok(Err(_)) => false, + Err(_) => { + debug!("Timeout writing proxy protocol header to mask backend"); + false + } + } +} + +async fn consume_client_data_with_timeout(reader: R) +where + R: AsyncRead + Unpin, +{ + if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)) + .await + .is_err() + { + debug!("Timed out while consuming client data on masking fallback path"); + } +} + +async fn wait_mask_connect_budget(started: Instant) { + let elapsed = started.elapsed(); + if elapsed < MASK_TIMEOUT { + tokio::time::sleep(MASK_TIMEOUT - elapsed).await; + } +} + +fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration { + if config.censorship.mask_timing_normalization_enabled { + let floor = 
config.censorship.mask_timing_normalization_floor_ms; + let ceiling = config.censorship.mask_timing_normalization_ceiling_ms; + if ceiling > floor { + let mut rng = rand::rng(); + return Duration::from_millis(rng.random_range(floor..=ceiling)); + } + return Duration::from_millis(floor); + } + + MASK_TIMEOUT +} + +async fn wait_mask_connect_budget_if_needed(started: Instant, config: &ProxyConfig) { + if config.censorship.mask_timing_normalization_enabled { + return; + } + + wait_mask_connect_budget(started).await; +} + +async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) { + let target = mask_outcome_target_budget(config); + let elapsed = started.elapsed(); + if elapsed < target { + tokio::time::sleep(target - elapsed).await; + } +} + /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request if data.len() > 4 - && (data.starts_with(b"GET ") || data.starts_with(b"POST") || - data.starts_with(b"HEAD") || data.starts_with(b"PUT ") || - data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS")) + && (data.starts_with(b"GET ") + || data.starts_with(b"POST") + || data.starts_with(b"HEAD") + || data.starts_with(b"PUT ") + || data.starts_with(b"DELETE") + || data.starts_with(b"OPTIONS")) { return "HTTP"; } @@ -49,6 +248,33 @@ fn detect_client_type(data: &[u8]) -> &'static str { "unknown" } +fn build_mask_proxy_header( + version: u8, + peer: SocketAddr, + local_addr: SocketAddr, +) -> Option> { + match version { + 0 => None, + 2 => Some( + ProxyProtocolV2Builder::new() + .with_addrs(peer, local_addr) + .build(), + ), + _ => { + let header = match (peer, local_addr) { + (SocketAddr::V4(src), SocketAddr::V4(dst)) => ProxyProtocolV1Builder::new() + .tcp4(src.into(), dst.into()) + .build(), + (SocketAddr::V6(src), SocketAddr::V6(dst)) => ProxyProtocolV1Builder::new() + .tcp6(src.into(), dst.into()) + .build(), + _ => ProxyProtocolV1Builder::new().build(), + }; + Some(header) + } + } +} + 
/// Handle a bad client by forwarding to mask host pub async fn handle_bad_client( reader: R, @@ -58,8 +284,7 @@ pub async fn handle_bad_client( local_addr: SocketAddr, config: &ProxyConfig, beobachten: &BeobachtenStore, -) -where +) where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, { @@ -71,13 +296,15 @@ where if !config.censorship.mask { // Masking disabled, just consume data - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; return; } // Connect via Unix socket or TCP #[cfg(unix)] if let Some(ref sock_path) = config.censorship.mask_unix_sock { + let outcome_started = Instant::now(); + let connect_started = Instant::now(); debug!( client_type = client_type, sock = %sock_path, @@ -89,45 +316,59 @@ where match connect_result { Ok(Ok(stream)) => { let (mask_read, mut mask_write) = stream.into_split(); - let proxy_header: Option> = match config.censorship.mask_proxy_protocol { - 0 => None, - version => { - let header = match version { - 2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(), - _ => match (peer, local_addr) { - (SocketAddr::V4(src), SocketAddr::V4(dst)) => - ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(), - (SocketAddr::V6(src), SocketAddr::V6(dst)) => - ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(), - _ => - ProxyProtocolV1Builder::new().build(), - }, - }; - Some(header) - } - }; - if let Some(header) = proxy_header { - if mask_write.write_all(&header).await.is_err() { - return; - } + let proxy_header = build_mask_proxy_header( + config.censorship.mask_proxy_protocol, + peer, + local_addr, + ); + if let Some(header) = proxy_header + && !write_proxy_header_with_timeout(&mut mask_write, &header).await + { + wait_mask_outcome_budget(outcome_started, config).await; + return; } - if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() { + if timeout( + MASK_RELAY_TIMEOUT, 
+ relay_to_mask( + reader, + writer, + mask_read, + mask_write, + initial_data, + config.censorship.mask_shape_hardening, + config.censorship.mask_shape_bucket_floor_bytes, + config.censorship.mask_shape_bucket_cap_bytes, + config.censorship.mask_shape_above_cap_blur, + config.censorship.mask_shape_above_cap_blur_max_bytes, + config.censorship.mask_shape_hardening_aggressive_mode, + ), + ) + .await + .is_err() + { debug!("Mask relay timed out (unix socket)"); } + wait_mask_outcome_budget(outcome_started, config).await; } Ok(Err(e)) => { + wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started, config).await; } } return; } - let mask_host = config.censorship.mask_host.as_deref() + let mask_host = config + .censorship + .mask_host + .as_deref() .unwrap_or(&config.censorship.tls_domain); let mask_port = config.censorship.mask_port; @@ -143,44 +384,54 @@ where let mask_addr = resolve_socket_addr(mask_host, mask_port) .map(|addr| addr.to_string()) .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port)); + let outcome_started = Instant::now(); + let connect_started = Instant::now(); let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { Ok(Ok(stream)) => { - let proxy_header: Option> = match config.censorship.mask_proxy_protocol { - 0 => None, - version => { - let header = match version { - 2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(), - _ => match (peer, local_addr) { - (SocketAddr::V4(src), SocketAddr::V4(dst)) => - ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(), - 
(SocketAddr::V6(src), SocketAddr::V6(dst)) => - ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(), - _ => - ProxyProtocolV1Builder::new().build(), - }, - }; - Some(header) - } - }; + let proxy_header = + build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr); let (mask_read, mut mask_write) = stream.into_split(); - if let Some(header) = proxy_header { - if mask_write.write_all(&header).await.is_err() { - return; - } + if let Some(header) = proxy_header + && !write_proxy_header_with_timeout(&mut mask_write, &header).await + { + wait_mask_outcome_budget(outcome_started, config).await; + return; } - if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() { + if timeout( + MASK_RELAY_TIMEOUT, + relay_to_mask( + reader, + writer, + mask_read, + mask_write, + initial_data, + config.censorship.mask_shape_hardening, + config.censorship.mask_shape_bucket_floor_bytes, + config.censorship.mask_shape_bucket_cap_bytes, + config.censorship.mask_shape_above_cap_blur, + config.censorship.mask_shape_above_cap_blur_max_bytes, + config.censorship.mask_shape_hardening_aggressive_mode, + ), + ) + .await + .is_err() + { debug!("Mask relay timed out"); } + wait_mask_outcome_budget(outcome_started, config).await; } Ok(Err(e)) => { + wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask host"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started, config).await; } } } @@ -192,8 +443,13 @@ async fn relay_to_mask( mut mask_read: MR, mut mask_write: MW, initial_data: &[u8], -) -where + shape_hardening_enabled: bool, + shape_bucket_floor_bytes: usize, + 
shape_bucket_cap_bytes: usize, + shape_above_cap_blur: bool, + shape_above_cap_blur_max_bytes: usize, + shape_hardening_aggressive_mode: bool, +) where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, MR: AsyncRead + Unpin + Send + 'static, @@ -203,47 +459,36 @@ where if mask_write.write_all(initial_data).await.is_err() { return; } - - // Relay traffic - let c2m = tokio::spawn(async move { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - loop { - match reader.read(&mut buf).await { - Ok(0) | Err(_) => { - let _ = mask_write.shutdown().await; - break; - } - Ok(n) => { - if mask_write.write_all(&buf[..n]).await.is_err() { - break; - } - } - } - } - }); - - let m2c = tokio::spawn(async move { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - loop { - match mask_read.read(&mut buf).await { - Ok(0) | Err(_) => { - let _ = writer.shutdown().await; - break; - } - Ok(n) => { - if writer.write_all(&buf[..n]).await.is_err() { - break; - } - } - } - } - }); - - // Wait for either to complete - tokio::select! 
{ - _ = c2m => {} - _ = m2c => {} + if mask_write.flush().await.is_err() { + return; } + + let (upstream_copy, downstream_copy) = tokio::join!( + async { copy_with_idle_timeout(&mut reader, &mut mask_write).await }, + async { copy_with_idle_timeout(&mut mask_read, &mut writer).await } + ); + + let total_sent = initial_data.len().saturating_add(upstream_copy.total); + + let should_shape = shape_hardening_enabled + && !initial_data.is_empty() + && (upstream_copy.ended_by_eof + || (shape_hardening_aggressive_mode && downstream_copy.total == 0)); + + maybe_write_shape_padding( + &mut mask_write, + total_sent, + should_shape, + shape_bucket_floor_bytes, + shape_bucket_cap_bytes, + shape_above_cap_blur, + shape_above_cap_blur_max_bytes, + shape_hardening_aggressive_mode, + ) + .await; + + let _ = mask_write.shutdown().await; + let _ = writer.shutdown().await; } /// Just consume all data from client without responding @@ -255,3 +500,51 @@ async fn consume_client_data(mut reader: R) { } } } + +#[cfg(test)] +#[path = "tests/masking_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/masking_adversarial_tests.rs"] +mod adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_shape_hardening_adversarial_tests.rs"] +mod masking_shape_hardening_adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_shape_above_cap_blur_security_tests.rs"] +mod masking_shape_above_cap_blur_security_tests; + +#[cfg(test)] +#[path = "tests/masking_timing_normalization_security_tests.rs"] +mod masking_timing_normalization_security_tests; + +#[cfg(test)] +#[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"] +mod masking_ab_envelope_blur_integration_security_tests; + +#[cfg(test)] +#[path = "tests/masking_shape_guard_security_tests.rs"] +mod masking_shape_guard_security_tests; + +#[cfg(test)] +#[path = "tests/masking_shape_guard_adversarial_tests.rs"] +mod masking_shape_guard_adversarial_tests; + +#[cfg(test)] +#[path = 
"tests/masking_shape_classifier_resistance_adversarial_tests.rs"] +mod masking_shape_classifier_resistance_adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_shape_bypass_blackhat_tests.rs"] +mod masking_shape_bypass_blackhat_tests; + +#[cfg(test)] +#[path = "tests/masking_aggressive_mode_security_tests.rs"] +mod masking_aggressive_mode_security_tests; + +#[cfg(test)] +#[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"] +mod masking_timing_sidechannel_redteam_expected_fail_tests; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 102b06c..d0f5ffb 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,44 +1,68 @@ -use std::collections::HashMap; -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; +use std::collections::hash_map::RandomState; +use std::collections::{BTreeSet, HashMap}; +use std::hash::{BuildHasher, Hash}; use std::net::{IpAddr, SocketAddr}; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::{Arc, Mutex, OnceLock}; use std::time::{Duration, Instant}; -use bytes::Bytes; +use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{mpsc, oneshot, watch}; -use tracing::{debug, trace, warn}; +use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch}; +use tokio::time::timeout; +use tracing::{debug, info, trace, warn}; use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; -use crate::protocol::constants::{*, secure_padding_len}; +use crate::protocol::constants::{secure_padding_len, *}; use crate::proxy::handshake::HandshakeSuccess; use crate::proxy::route_mode::{ - RelayRouteMode, RouteCutoverState, ROUTE_SWITCH_ERROR_MSG, affected_cutover_state, + ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; -use 
crate::proxy::adaptive_buffers::{self, AdaptiveTier}; -use crate::proxy::session_eviction::SessionLease; -use crate::stats::Stats; -use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats}; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; enum C2MeCommand { - Data { payload: Bytes, flags: u32 }, + Data { payload: PooledBuffer, flags: u32 }, Close, } const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60); +const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536; +const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024; +const DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL: Duration = Duration::from_millis(1000); const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; +const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1); +#[cfg(test)] +const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); +#[cfg(not(test))] +const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5); const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; -static DESYNC_DEDUP: OnceLock>> = OnceLock::new(); +const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; +const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024; +#[cfg(test)] +const QUOTA_USER_LOCKS_MAX: usize = 64; +#[cfg(not(test))] +const QUOTA_USER_LOCKS_MAX: usize = 4_096; +#[cfg(test)] +const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; +#[cfg(not(test))] +const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; +static DESYNC_DEDUP: OnceLock> = OnceLock::new(); +static DESYNC_HASHER: OnceLock = OnceLock::new(); +static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock>> = OnceLock::new(); +static DESYNC_DEDUP_EVER_SATURATED: OnceLock = 
OnceLock::new(); +static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); +static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); +static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock> = OnceLock::new(); +static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0); struct RelayForensicsState { trace_id: u64, @@ -52,17 +76,208 @@ struct RelayForensicsState { desync_all_full: bool, } +#[derive(Default)] +struct RelayIdleCandidateRegistry { + by_conn_id: HashMap, + ordered: BTreeSet<(u64, u64)>, + pressure_event_seq: u64, + pressure_consumed_seq: u64, +} + +#[derive(Clone, Copy)] +struct RelayIdleCandidateMeta { + mark_order_seq: u64, + mark_pressure_seq: u64, +} + +fn relay_idle_candidate_registry() -> &'static Mutex { + RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default())) +} + +fn mark_relay_idle_candidate(conn_id: u64) -> bool { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return false; + }; + + if guard.by_conn_id.contains_key(&conn_id) { + return false; + } + + let mark_order_seq = RELAY_IDLE_MARK_SEQ + .fetch_add(1, Ordering::Relaxed) + .saturating_add(1); + let meta = RelayIdleCandidateMeta { + mark_order_seq, + mark_pressure_seq: guard.pressure_event_seq, + }; + guard.by_conn_id.insert(conn_id, meta); + guard.ordered.insert((meta.mark_order_seq, conn_id)); + true +} + +fn clear_relay_idle_candidate(conn_id: u64) { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return; + }; + + if let Some(meta) = guard.by_conn_id.remove(&conn_id) { + guard.ordered.remove(&(meta.mark_order_seq, conn_id)); + } +} + +#[cfg(test)] +fn oldest_relay_idle_candidate() -> Option { + let Ok(guard) = relay_idle_candidate_registry().lock() else { + return None; + }; + guard.ordered.iter().next().map(|(_, conn_id)| *conn_id) +} + +fn note_relay_pressure_event() { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return; + }; + guard.pressure_event_seq = 
guard.pressure_event_seq.wrapping_add(1); +} + +fn relay_pressure_event_seq() -> u64 { + let Ok(guard) = relay_idle_candidate_registry().lock() else { + return 0; + }; + guard.pressure_event_seq +} + +fn maybe_evict_idle_candidate_on_pressure( + conn_id: u64, + seen_pressure_seq: &mut u64, + stats: &Stats, +) -> bool { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return false; + }; + + let latest_pressure_seq = guard.pressure_event_seq; + if latest_pressure_seq == *seen_pressure_seq { + return false; + } + *seen_pressure_seq = latest_pressure_seq; + + if latest_pressure_seq == guard.pressure_consumed_seq { + return false; + } + + if guard.ordered.is_empty() { + guard.pressure_consumed_seq = latest_pressure_seq; + return false; + } + + let oldest = guard + .ordered + .iter() + .next() + .map(|(_, candidate_conn_id)| *candidate_conn_id); + if oldest != Some(conn_id) { + return false; + } + + let Some(candidate_meta) = guard.by_conn_id.get(&conn_id).copied() else { + return false; + }; + + // Pressure events that happened before candidate soft-mark are stale for this candidate. 
+ if latest_pressure_seq == candidate_meta.mark_pressure_seq { + return false; + } + + if let Some(meta) = guard.by_conn_id.remove(&conn_id) { + guard.ordered.remove(&(meta.mark_order_seq, conn_id)); + } + guard.pressure_consumed_seq = latest_pressure_seq; + stats.increment_relay_pressure_evict_total(); + true +} + +#[cfg(test)] +fn clear_relay_idle_pressure_state_for_testing() { + if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get() + && let Ok(mut guard) = registry.lock() + { + guard.by_conn_id.clear(); + guard.ordered.clear(); + guard.pressure_event_seq = 0; + guard.pressure_consumed_seq = 0; + } + RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed); +} + #[derive(Clone, Copy)] struct MeD2cFlushPolicy { max_frames: usize, max_bytes: usize, max_delay: Duration, ack_flush_immediate: bool, + quota_soft_overshoot_bytes: u64, + frame_buf_shrink_threshold_bytes: usize, +} + +#[derive(Clone, Copy)] +struct RelayClientIdlePolicy { + enabled: bool, + soft_idle: Duration, + hard_idle: Duration, + grace_after_downstream_activity: Duration, + legacy_frame_read_timeout: Duration, +} + +impl RelayClientIdlePolicy { + fn from_config(config: &ProxyConfig) -> Self { + Self { + enabled: config.timeouts.relay_idle_policy_v2_enabled, + soft_idle: Duration::from_secs(config.timeouts.relay_client_idle_soft_secs.max(1)), + hard_idle: Duration::from_secs(config.timeouts.relay_client_idle_hard_secs.max(1)), + grace_after_downstream_activity: Duration::from_secs( + config + .timeouts + .relay_idle_grace_after_downstream_activity_secs, + ), + legacy_frame_read_timeout: Duration::from_secs(config.timeouts.client_handshake.max(1)), + } + } + + #[cfg(test)] + fn disabled(frame_read_timeout: Duration) -> Self { + Self { + enabled: false, + soft_idle: Duration::from_secs(0), + hard_idle: Duration::from_secs(0), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: frame_read_timeout, + } + } +} + +struct RelayClientIdleState { + last_client_frame_at: 
Instant, + soft_idle_marked: bool, +} + +impl RelayClientIdleState { + fn new(now: Instant) -> Self { + Self { + last_client_frame_at: now, + soft_idle_marked: false, + } + } + + fn on_client_frame(&mut self, now: Instant) { + self.last_client_frame_at = now; + self.soft_idle_marked = false; + } } impl MeD2cFlushPolicy { - fn from_config(config: &ProxyConfig, tier: AdaptiveTier) -> Self { - let base = Self { + fn from_config(config: &ProxyConfig) -> Self { + Self { max_frames: config .general .me_d2c_flush_batch_max_frames @@ -73,26 +288,18 @@ impl MeD2cFlushPolicy { .max(ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN), max_delay: Duration::from_micros(config.general.me_d2c_flush_batch_max_delay_us), ack_flush_immediate: config.general.me_d2c_ack_flush_immediate, - }; - let (max_frames, max_bytes, max_delay) = adaptive_buffers::me_flush_policy_for_tier( - tier, - base.max_frames, - base.max_bytes, - base.max_delay, - ); - Self { - max_frames, - max_bytes, - max_delay, - ack_flush_immediate: base.ack_flush_immediate, + quota_soft_overshoot_bytes: config.general.me_quota_soft_overshoot_bytes, + frame_buf_shrink_threshold_bytes: config + .general + .me_d2c_frame_buf_shrink_threshold_bytes + .max(4096), } } } fn hash_value(value: &T) -> u64 { - let mut hasher = DefaultHasher::new(); - value.hash(&mut hasher); - hasher.finish() + let state = DESYNC_HASHER.get_or_init(RandomState::new); + state.hash_one(value) } fn hash_ip(ip: IpAddr) -> u64 { @@ -104,26 +311,129 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { return true; } - let dedup = DESYNC_DEDUP.get_or_init(|| Mutex::new(HashMap::new())); - let mut guard = dedup.lock().expect("desync dedup mutex poisoned"); - guard.retain(|_, seen_at| now.duration_since(*seen_at) < DESYNC_DEDUP_WINDOW); + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let saturated_before = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES; + let ever_saturated = DESYNC_DEDUP_EVER_SATURATED.get_or_init(|| AtomicBool::new(false)); + 
if saturated_before { + ever_saturated.store(true, Ordering::Relaxed); + } - match guard.get_mut(&key) { - Some(seen_at) => { - if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW { - *seen_at = now; + if let Some(mut seen_at) = dedup.get_mut(&key) { + if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW { + *seen_at = now; + return true; + } + return false; + } + + if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { + let mut stale_keys = Vec::new(); + let mut oldest_candidate: Option<(u64, Instant)> = None; + for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) { + let key = *entry.key(); + let seen_at = *entry.value(); + + match oldest_candidate { + Some((_, oldest_seen)) if seen_at >= oldest_seen => {} + _ => oldest_candidate = Some((key, seen_at)), + } + + if now.duration_since(seen_at) >= DESYNC_DEDUP_WINDOW { + stale_keys.push(*entry.key()); + } + } + for stale_key in stale_keys { + dedup.remove(&stale_key); + } + if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { + let Some((evict_key, _)) = oldest_candidate else { + return false; + }; + dedup.remove(&evict_key); + dedup.insert(key, now); + return should_emit_full_desync_full_cache(now); + } + } + + dedup.insert(key, now); + let saturated_after = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES; + // Preserve the first sequential insert that reaches capacity as a normal + // emit, while still gating concurrent newcomer churn after the cache has + // ever been observed at saturation. 
+ let was_ever_saturated = if saturated_after { + ever_saturated.swap(true, Ordering::Relaxed) + } else { + ever_saturated.load(Ordering::Relaxed) + }; + + if saturated_before || (saturated_after && was_ever_saturated) { + should_emit_full_desync_full_cache(now) + } else { + true + } +} + +fn should_emit_full_desync_full_cache(now: Instant) -> bool { + let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None)); + let Ok(mut last_emit_at) = gate.lock() else { + return false; + }; + + match *last_emit_at { + None => { + *last_emit_at = Some(now); + true + } + Some(last) => { + let Some(elapsed) = now.checked_duration_since(last) else { + *last_emit_at = Some(now); + return true; + }; + if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL { + *last_emit_at = Some(now); true } else { false } } - None => { - guard.insert(key, now); - true + } +} + +#[cfg(test)] +fn clear_desync_dedup_for_testing() { + if let Some(dedup) = DESYNC_DEDUP.get() { + dedup.clear(); + } + if let Some(ever_saturated) = DESYNC_DEDUP_EVER_SATURATED.get() { + ever_saturated.store(false, Ordering::Relaxed); + } + if let Some(last_emit_at) = DESYNC_FULL_CACHE_LAST_EMIT_AT.get() { + match last_emit_at.lock() { + Ok(mut guard) => { + *guard = None; + } + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + *guard = None; + last_emit_at.clear_poison(); + } } } } +#[cfg(test)] +fn desync_dedup_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +fn desync_forensics_len_bytes(len: usize) -> ([u8; 4], bool) { + match u32::try_from(len) { + Ok(value) => (value.to_le_bytes(), false), + Err(_) => (u32::MAX.to_le_bytes(), true), + } +} + fn report_desync_frame_too_large( state: &RelayForensicsState, proto_tag: ProtoTag, @@ -133,7 +443,8 @@ fn report_desync_frame_too_large( raw_len_bytes: Option<[u8; 4]>, stats: &Stats, ) -> ProxyError { - let len_buf = raw_len_bytes.unwrap_or((len as u32).to_le_bytes()); + 
let (fallback_len_buf, len_buf_truncated) = desync_forensics_len_bytes(len); + let len_buf = raw_len_bytes.unwrap_or(fallback_len_buf); let looks_like_tls = raw_len_bytes .map(|b| b[0] == 0x16 && b[1] == 0x03) .unwrap_or(false); @@ -152,6 +463,7 @@ fn report_desync_frame_too_large( let bytes_me2c = state.bytes_me2c.load(Ordering::Relaxed); stats.increment_desync_total(); + stats.increment_relay_protocol_desync_close_total(); stats.observe_desync_frames_ok(frame_counter); if emit_full { stats.increment_desync_full_logged(); @@ -168,6 +480,7 @@ fn report_desync_frame_too_large( bytes_me2c, raw_len = len, raw_len_hex = format_args!("0x{:08x}", len), + raw_len_bytes_truncated = len_buf_truncated, raw_bytes = format_args!( "{:02x} {:02x} {:02x} {:02x}", len_buf[0], len_buf[1], len_buf[2], len_buf[3] @@ -210,8 +523,7 @@ fn report_desync_frame_too_large( ProxyError::Proxy(format!( "Frame too large: {len} (max {max_frame}), frames_ok={frame_counter}, conn_id={}, trace_id=0x{:016x}", - state.conn_id, - state.trace_id + state.conn_id, state.trace_id )) } @@ -219,23 +531,155 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET } +fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option) -> bool { + quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota) +} + +#[cfg_attr(not(test), allow(dead_code))] +fn quota_would_be_exceeded_for_user( + stats: &Stats, + user: &str, + quota_limit: Option, + bytes: u64, +) -> bool { + quota_limit.is_some_and(|quota| { + let used = stats.get_user_total_octets(user); + used >= quota || bytes > quota.saturating_sub(used) + }) +} + +fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { + limit.saturating_add(overshoot) +} + +fn quota_exceeded_for_user_soft( + stats: &Stats, + user: &str, + quota_limit: Option, + overshoot: u64, +) -> bool { + quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= 
quota_soft_cap(quota, overshoot)) +} + +fn quota_would_be_exceeded_for_user_soft( + stats: &Stats, + user: &str, + quota_limit: Option, + bytes: u64, + overshoot: u64, +) -> bool { + quota_limit.is_some_and(|quota| { + let cap = quota_soft_cap(quota, overshoot); + let used = stats.get_user_total_octets(user); + used >= cap || bytes > cap.saturating_sub(used) + }) +} + +fn classify_me_d2c_flush_reason( + flush_immediately: bool, + batch_frames: usize, + max_frames: usize, + batch_bytes: usize, + max_bytes: usize, + max_delay_fired: bool, +) -> MeD2cFlushReason { + if flush_immediately { + return MeD2cFlushReason::AckImmediate; + } + if batch_frames >= max_frames { + return MeD2cFlushReason::BatchFrames; + } + if batch_bytes >= max_bytes { + return MeD2cFlushReason::BatchBytes; + } + if max_delay_fired { + return MeD2cFlushReason::MaxDelay; + } + MeD2cFlushReason::QueueDrain +} + +fn observe_me_d2c_flush_event( + stats: &Stats, + reason: MeD2cFlushReason, + batch_frames: usize, + batch_bytes: usize, + flush_duration_us: Option, +) { + stats.increment_me_d2c_flush_reason(reason); + if batch_frames > 0 || batch_bytes > 0 { + stats.increment_me_d2c_batches_total(); + stats.add_me_d2c_batch_frames_total(batch_frames as u64); + stats.add_me_d2c_batch_bytes_total(batch_bytes as u64); + stats.observe_me_d2c_batch_frames(batch_frames as u64); + stats.observe_me_d2c_batch_bytes(batch_bytes as u64); + } + if let Some(duration_us) = flush_duration_us { + stats.observe_me_d2c_flush_duration_us(duration_us); + } +} + +#[cfg(test)] +fn quota_user_lock_test_guard() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { + quota_user_lock_test_guard() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn quota_overflow_user_lock(user: &str) -> Arc> { + let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { + 
(0..QUOTA_OVERFLOW_LOCK_STRIPES) + .map(|_| Arc::new(AsyncMutex::new(()))) + .collect() + }); + + let hash = crc32fast::hash(user.as_bytes()) as usize; + Arc::clone(&stripes[hash % stripes.len()]) +} + +fn quota_user_lock(user: &str) -> Arc> { + let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + if let Some(existing) = locks.get(user) { + return Arc::clone(existing.value()); + } + + if locks.len() >= QUOTA_USER_LOCKS_MAX { + locks.retain(|_, value| Arc::strong_count(value) > 1); + } + + if locks.len() >= QUOTA_USER_LOCKS_MAX { + return quota_overflow_user_lock(user); + } + + let created = Arc::new(AsyncMutex::new(())); + match locks.entry(user.to_string()) { + dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), + dashmap::mapref::entry::Entry::Vacant(entry) => { + entry.insert(Arc::clone(&created)); + created + } + } +} + async fn enqueue_c2me_command( tx: &mpsc::Sender, cmd: C2MeCommand, - send_timeout: Duration, ) -> std::result::Result<(), mpsc::error::SendError> { match tx.try_send(cmd) { Ok(()) => Ok(()), Err(mpsc::error::TrySendError::Closed(cmd)) => Err(mpsc::error::SendError(cmd)), Err(mpsc::error::TrySendError::Full(cmd)) => { + note_relay_pressure_event(); // Cooperative yield reduces burst catch-up when the per-conn queue is near saturation. 
if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS { tokio::task::yield_now().await; } - if send_timeout.is_zero() { - return tx.send(cmd).await; - } - match tokio::time::timeout(send_timeout, tx.reserve()).await { + match timeout(C2ME_SEND_TIMEOUT, tx.reserve()).await { Ok(Ok(permit)) => { permit.send(cmd); Ok(()) @@ -254,23 +698,22 @@ pub(crate) async fn handle_via_middle_proxy( me_pool: Arc, stats: Arc, config: Arc, - _buffer_pool: Arc, + buffer_pool: Arc, local_addr: SocketAddr, rng: Arc, mut route_rx: watch::Receiver, route_snapshot: RouteCutoverState, session_id: u64, - session_lease: SessionLease, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, { let user = success.user.clone(); + let quota_limit = config.access.user_data_quota.get(&user).copied(); let peer = success.peer; let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); - let seed_tier = adaptive_buffers::seed_tier_for_user(&user); debug!( user = %user, @@ -283,7 +726,7 @@ where ); let (conn_id, me_rx) = me_pool.registry().register().await; - let trace_id = conn_id; + let trace_id = session_id; let bytes_me2c = Arc::new(AtomicU64::new(0)); let mut forensics = RelayForensicsState { trace_id, @@ -298,14 +741,11 @@ where }; stats.increment_user_connects(&user); - stats.increment_user_curr_connects(&user); - stats.increment_current_connections_me(); + let _me_connection_lease = stats.acquire_me_connection_lease(); - if let Some(cutover) = affected_cutover_state( - &route_rx, - RelayRouteMode::Middle, - route_snapshot.generation, - ) { + if let Some(cutover) = + affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation) + { let delay = cutover_stagger_delay(session_id, cutover.generation); warn!( conn_id, @@ -317,20 +757,9 @@ where tokio::time::sleep(delay).await; let _ = me_pool.send_close(conn_id).await; me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); 
- stats.decrement_user_curr_connects(&user); return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); } - if session_lease.is_stale() { - stats.increment_reconnect_stale_close_total(); - let _ = me_pool.send_close(conn_id).await; - me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); - stats.decrement_user_curr_connects(&user); - return Err(ProxyError::Proxy("Session evicted by reconnect".to_string())); - } - // Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable) let user_tag: Option> = config .access @@ -361,56 +790,33 @@ where let translated_local_addr = me_pool.translate_our_addr(local_addr); let frame_limit = config.general.max_client_frame; + let relay_idle_policy = RelayClientIdlePolicy::from_config(&config); + let session_started_at = forensics.started_at; + let mut relay_idle_state = RelayClientIdleState::new(session_started_at); + let last_downstream_activity_ms = Arc::new(AtomicU64::new(0)); let c2me_channel_capacity = config .general .me_c2me_channel_capacity .max(C2ME_CHANNEL_CAPACITY_FALLBACK); - let c2me_send_timeout = Duration::from_millis(config.general.me_c2me_send_timeout_ms); let (c2me_tx, mut c2me_rx) = mpsc::channel::(c2me_channel_capacity); let me_pool_c2me = me_pool.clone(); - let effective_tag = effective_tag; let c2me_sender = tokio::spawn(async move { let mut sent_since_yield = 0usize; while let Some(cmd) = c2me_rx.recv().await { match cmd { C2MeCommand::Data { payload, flags } => { - if c2me_send_timeout.is_zero() { - me_pool_c2me - .send_proxy_req( - conn_id, - success.dc_idx, - peer, - translated_local_addr, - payload.as_ref(), - flags, - effective_tag.as_deref(), - ) - .await?; - } else { - match tokio::time::timeout( - c2me_send_timeout, - me_pool_c2me.send_proxy_req( - conn_id, - success.dc_idx, - peer, - translated_local_addr, - payload.as_ref(), - flags, - effective_tag.as_deref(), - ), + me_pool_c2me + .send_proxy_req( + conn_id, + success.dc_idx, + 
peer, + translated_local_addr, + payload.as_ref(), + flags, + effective_tag.as_deref(), ) - .await - { - Ok(send_result) => send_result?, - Err(_) => { - return Err(ProxyError::Proxy(format!( - "ME send timeout after {}ms", - c2me_send_timeout.as_millis() - ))); - } - } - } + .await?; sent_since_yield = sent_since_yield.saturating_add(1); if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) { sent_since_yield = 0; @@ -431,8 +837,9 @@ where let stats_clone = stats.clone(); let rng_clone = rng.clone(); let user_clone = user.clone(); + let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); - let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config, seed_tier); + let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); let me_writer = tokio::spawn(async move { let mut writer = crypto_writer; let mut frame_buf = Vec::with_capacity(16 * 1024); @@ -447,7 +854,10 @@ where let mut batch_frames = 0usize; let mut batch_bytes = 0usize; let mut flush_immediately; + let mut max_delay_fired = false; + let first_is_downstream_activity = + matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_)); match process_me_writer_response( first, &mut writer, @@ -456,18 +866,42 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, false, ).await? 
{ MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if first_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately = immediate; } MeWriterResponseOutcome::Close => { + let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() { + Some(Instant::now()) + } else { + None + }; let _ = writer.flush().await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); return Ok(()); } } @@ -480,6 +914,8 @@ where break; }; + let next_is_downstream_activity = + matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); match process_me_writer_response( next, &mut writer, @@ -488,18 +924,44 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, true, ).await? 
{ MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if next_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately |= immediate; } MeWriterResponseOutcome::Close => { + let flush_started_at = + if stats_clone.telemetry_policy().me_level.allows_debug() { + Some(Instant::now()) + } else { + None + }; let _ = writer.flush().await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) + as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); return Ok(()); } } @@ -510,8 +972,11 @@ where && batch_frames < d2c_flush_policy.max_frames && batch_bytes < d2c_flush_policy.max_bytes { + stats_clone.increment_me_d2c_batch_timeout_armed_total(); match tokio::time::timeout(d2c_flush_policy.max_delay, me_rx_task.recv()).await { Ok(Some(next)) => { + let next_is_downstream_activity = + matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); match process_me_writer_response( next, &mut writer, @@ -520,18 +985,47 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, true, ).await? 
{ MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if next_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately |= immediate; } MeWriterResponseOutcome::Close => { + let flush_started_at = if stats_clone + .telemetry_policy() + .me_level + .allows_debug() + { + Some(Instant::now()) + } else { + None + }; let _ = writer.flush().await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) + as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); return Ok(()); } } @@ -544,6 +1038,8 @@ where break; }; + let extra_is_downstream_activity = + matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_)); match process_me_writer_response( extra, &mut writer, @@ -552,18 +1048,47 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, true, ).await? 
{ MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if extra_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately |= immediate; } MeWriterResponseOutcome::Close => { + let flush_started_at = if stats_clone + .telemetry_policy() + .me_level + .allows_debug() + { + Some(Instant::now()) + } else { + None + }; let _ = writer.flush().await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) + as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); return Ok(()); } } @@ -573,11 +1098,50 @@ where debug!(conn_id, "ME channel closed"); return Err(ProxyError::Proxy("ME connection lost".into())); } - Err(_) => {} + Err(_) => { + max_delay_fired = true; + stats_clone.increment_me_d2c_batch_timeout_fired_total(); + } } } + let flush_reason = classify_me_d2c_flush_reason( + flush_immediately, + batch_frames, + d2c_flush_policy.max_frames, + batch_bytes, + d2c_flush_policy.max_bytes, + max_delay_fired, + ); + let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() { + Some(Instant::now()) + } else { + None + }; writer.flush().await.map_err(ProxyError::Io)?; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + flush_reason, + batch_frames, + batch_bytes, + flush_duration_us, + ); + let shrink_threshold = d2c_flush_policy.frame_buf_shrink_threshold_bytes; + let shrink_trigger = shrink_threshold + .saturating_mul(ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR); + if frame_buf.capacity() > shrink_trigger { + let 
cap_before = frame_buf.capacity(); + frame_buf.shrink_to(shrink_threshold); + let cap_after = frame_buf.capacity(); + let bytes_freed = cap_before.saturating_sub(cap_after) as u64; + stats_clone.observe_me_d2c_frame_buf_shrink(bytes_freed); + } } _ = &mut stop_rx => { debug!(conn_id, "ME writer stop signal"); @@ -591,18 +1155,31 @@ where let mut client_closed = false; let mut frame_counter: u64 = 0; let mut route_watch_open = true; + let mut seen_pressure_seq = relay_pressure_event_seq(); loop { - if session_lease.is_stale() { - stats.increment_reconnect_stale_close_total(); - let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close, c2me_send_timeout).await; - main_result = Err(ProxyError::Proxy("Session evicted by reconnect".to_string())); + if relay_idle_policy.enabled + && maybe_evict_idle_candidate_on_pressure( + conn_id, + &mut seen_pressure_seq, + stats.as_ref(), + ) + { + info!( + conn_id, + trace_id = format_args!("0x{:016x}", trace_id), + user = %user, + "Middle-relay pressure eviction for idle-candidate session" + ); + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; + main_result = Err(ProxyError::Proxy( + "middle-relay session evicted under pressure (idle-candidate)".to_string(), + )); break; } - if let Some(cutover) = affected_cutover_state( - &route_rx, - RelayRouteMode::Middle, - route_snapshot.generation, - ) { + + if let Some(cutover) = + affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation) + { let delay = cutover_stagger_delay(session_id, cutover.generation); warn!( conn_id, @@ -612,7 +1189,7 @@ where "Cutover affected middle session, closing client connection" ); tokio::time::sleep(delay).await; - let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close, c2me_send_timeout).await; + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; main_result = Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); break; } @@ -623,13 +1200,18 @@ where route_watch_open = false; } } - 
payload_result = read_client_payload( + payload_result = read_client_payload_with_idle_policy( &mut crypto_reader, proto_tag, frame_limit, + &buffer_pool, &forensics, &mut frame_counter, &stats, + &relay_idle_policy, + &mut relay_idle_state, + last_downstream_activity_ms.as_ref(), + session_started_at, ) => { match payload_result { Ok(Some((payload, quickack))) => { @@ -637,7 +1219,19 @@ where forensics.bytes_c2me = forensics .bytes_c2me .saturating_add(payload.len() as u64); - stats.add_user_octets_from(&user, payload.len() as u64); + if let Some(limit) = quota_limit { + let quota_lock = quota_user_lock(&user); + let _quota_guard = quota_lock.lock().await; + stats.add_user_octets_from(&user, payload.len() as u64); + if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { + main_result = Err(ProxyError::DataQuotaExceeded { + user: user.clone(), + }); + break; + } + } else { + stats.add_user_octets_from(&user, payload.len() as u64); + } let mut flags = proto_flags; if quickack { flags |= RPC_FLAG_QUICKACK; @@ -646,13 +1240,9 @@ where flags |= RPC_FLAG_NOT_ENCRYPTED; } // Keep client read loop lightweight: route heavy ME send path via a dedicated task. 
- if enqueue_c2me_command( - &c2me_tx, - C2MeCommand::Data { payload, flags }, - c2me_send_timeout, - ) - .await - .is_err() + if enqueue_c2me_command(&c2me_tx, C2MeCommand::Data { payload, flags }) + .await + .is_err() { main_result = Err(ProxyError::Proxy("ME sender channel closed".into())); break; @@ -661,12 +1251,7 @@ where Ok(None) => { debug!(conn_id, "Client EOF"); client_closed = true; - let _ = enqueue_c2me_command( - &c2me_tx, - C2MeCommand::Close, - c2me_send_timeout, - ) - .await; + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; break; } Err(e) => { @@ -715,41 +1300,204 @@ where frames_ok = frame_counter, "ME relay cleanup" ); - adaptive_buffers::record_user_tier(&user, seed_tier); + clear_relay_idle_candidate(conn_id); me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); - stats.decrement_user_curr_connects(&user); result } -async fn read_client_payload( +async fn read_client_payload_with_idle_policy( client_reader: &mut CryptoReader, proto_tag: ProtoTag, max_frame: usize, + buffer_pool: &Arc, forensics: &RelayForensicsState, frame_counter: &mut u64, stats: &Stats, -) -> Result> + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, +) -> Result> where R: AsyncRead + Unpin + Send + 'static, { + async fn read_exact_with_policy( + client_reader: &mut CryptoReader, + buf: &mut [u8], + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, + forensics: &RelayForensicsState, + stats: &Stats, + read_label: &'static str, + ) -> Result<()> + where + R: AsyncRead + Unpin + Send + 'static, + { + fn hard_deadline( + idle_policy: &RelayClientIdlePolicy, + idle_state: &RelayClientIdleState, + session_started_at: Instant, + last_downstream_activity_ms: u64, + ) -> Instant { + let mut deadline = 
idle_state.last_client_frame_at + idle_policy.hard_idle; + if idle_policy.grace_after_downstream_activity.is_zero() { + return deadline; + } + + let downstream_at = + session_started_at + Duration::from_millis(last_downstream_activity_ms); + if downstream_at > idle_state.last_client_frame_at { + let grace_deadline = downstream_at + idle_policy.grace_after_downstream_activity; + if grace_deadline > deadline { + deadline = grace_deadline; + } + } + deadline + } + + let mut filled = 0usize; + while filled < buf.len() { + let timeout_window = if idle_policy.enabled { + let now = Instant::now(); + let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); + let hard_deadline = + hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms); + if now >= hard_deadline { + clear_relay_idle_candidate(forensics.conn_id); + stats.increment_relay_idle_hard_close_total(); + let client_idle_secs = now + .saturating_duration_since(idle_state.last_client_frame_at) + .as_secs(); + let downstream_idle_secs = now + .saturating_duration_since( + session_started_at + Duration::from_millis(downstream_ms), + ) + .as_secs(); + warn!( + trace_id = format_args!("0x{:016x}", forensics.trace_id), + conn_id = forensics.conn_id, + user = %forensics.user, + read_label, + client_idle_secs, + downstream_idle_secs, + soft_idle_secs = idle_policy.soft_idle.as_secs(), + hard_idle_secs = idle_policy.hard_idle.as_secs(), + grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), + "Middle-relay hard idle close" + ); + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!( + "middle-relay hard idle timeout while reading {read_label}: client_idle_secs={client_idle_secs}, downstream_idle_secs={downstream_idle_secs}, soft_idle_secs={}, hard_idle_secs={}, grace_secs={}", + idle_policy.soft_idle.as_secs(), + idle_policy.hard_idle.as_secs(), + idle_policy.grace_after_downstream_activity.as_secs(), + ), + ))); + } + + if 
!idle_state.soft_idle_marked + && now.saturating_duration_since(idle_state.last_client_frame_at) + >= idle_policy.soft_idle + { + idle_state.soft_idle_marked = true; + if mark_relay_idle_candidate(forensics.conn_id) { + stats.increment_relay_idle_soft_mark_total(); + } + info!( + trace_id = format_args!("0x{:016x}", forensics.trace_id), + conn_id = forensics.conn_id, + user = %forensics.user, + read_label, + soft_idle_secs = idle_policy.soft_idle.as_secs(), + hard_idle_secs = idle_policy.hard_idle.as_secs(), + grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), + "Middle-relay soft idle mark" + ); + } + + let soft_deadline = idle_state.last_client_frame_at + idle_policy.soft_idle; + let next_deadline = if idle_state.soft_idle_marked { + hard_deadline + } else { + soft_deadline.min(hard_deadline) + }; + let mut remaining = next_deadline.saturating_duration_since(now); + if remaining.is_zero() { + remaining = Duration::from_millis(1); + } + remaining.min(RELAY_IDLE_IO_POLL_MAX) + } else { + idle_policy.legacy_frame_read_timeout + }; + + let read_result = timeout(timeout_window, client_reader.read(&mut buf[filled..])).await; + match read_result { + Ok(Ok(0)) => { + return Err(ProxyError::Io(std::io::Error::from( + std::io::ErrorKind::UnexpectedEof, + ))); + } + Ok(Ok(n)) => { + filled = filled.saturating_add(n); + } + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) if !idle_policy.enabled => { + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!( + "middle-relay client frame read timeout while reading {read_label}" + ), + ))); + } + Err(_) => {} + } + } + + Ok(()) + } + loop { let (len, quickack, raw_len_bytes) = match proto_tag { ProtoTag::Abridged => { let mut first = [0u8; 1]; - match client_reader.read_exact(&mut first).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + match read_exact_with_policy( + 
client_reader, + &mut first, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "abridged.first_len_byte", + ) + .await + { + Ok(()) => {} + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), } let quickack = (first[0] & 0x80) != 0; let len_words = if (first[0] & 0x7f) == 0x7f { let mut ext = [0u8; 3]; - client_reader - .read_exact(&mut ext) - .await - .map_err(ProxyError::Io)?; + read_exact_with_policy( + client_reader, + &mut ext, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "abridged.extended_len", + ) + .await?; u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize } else { (first[0] & 0x7f) as usize @@ -762,10 +1510,24 @@ where } ProtoTag::Intermediate | ProtoTag::Secure => { let mut len_buf = [0u8; 4]; - match client_reader.read_exact(&mut len_buf).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + match read_exact_with_policy( + client_reader, + &mut len_buf, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "len_prefix", + ) + .await + { + Ok(()) => {} + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), } let quickack = (len_buf[3] & 0x80) != 0; ( @@ -788,6 +1550,7 @@ where proto = ?proto_tag, "Frame too small — corrupt or probe" ); + stats.increment_relay_protocol_desync_close_total(); return Err(ProxyError::Proxy(format!("Frame too small: {len}"))); } @@ -808,6 +1571,7 @@ where Some(payload_len) => payload_len, None => { stats.increment_secure_padding_invalid(); + stats.increment_relay_protocol_desync_close_total(); return Err(ProxyError::Proxy(format!( "Invalid secure frame length: {len}" ))); @@ -817,21 +1581,98 @@ where len }; - let mut 
payload = vec![0u8; len]; - client_reader - .read_exact(&mut payload) - .await - .map_err(ProxyError::Io)?; + let mut payload = buffer_pool.get(); + payload.clear(); + let current_cap = payload.capacity(); + if current_cap < len { + payload.reserve(len - current_cap); + } + payload.resize(len, 0); + read_exact_with_policy( + client_reader, + &mut payload[..len], + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "payload", + ) + .await?; // Secure Intermediate: strip validated trailing padding bytes. if proto_tag == ProtoTag::Secure { payload.truncate(secure_payload_len); } *frame_counter += 1; - return Ok(Some((Bytes::from(payload), quickack))); + idle_state.on_client_frame(Instant::now()); + clear_relay_idle_candidate(forensics.conn_id); + return Ok(Some((payload, quickack))); } } +#[cfg(test)] +async fn read_client_payload_legacy( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, + max_frame: usize, + frame_read_timeout: Duration, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, +) -> Result> +where + R: AsyncRead + Unpin + Send + 'static, +{ + let now = Instant::now(); + let mut idle_state = RelayClientIdleState::new(now); + let last_downstream_activity_ms = AtomicU64::new(0); + let idle_policy = RelayClientIdlePolicy::disabled(frame_read_timeout); + read_client_payload_with_idle_policy( + client_reader, + proto_tag, + max_frame, + buffer_pool, + forensics, + frame_counter, + stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + now, + ) + .await +} + +#[cfg(test)] +async fn read_client_payload( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, + max_frame: usize, + frame_read_timeout: Duration, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, +) -> Result> +where + R: AsyncRead + Unpin + Send + 'static, +{ + read_client_payload_legacy( + client_reader, + proto_tag, + 
max_frame, + frame_read_timeout, + buffer_pool, + forensics, + frame_counter, + stats, + ) + .await +} + enum MeWriterResponseOutcome { Continue { frames: usize, @@ -849,6 +1690,8 @@ async fn process_me_writer_response( frame_buf: &mut Vec, stats: &Stats, user: &str, + quota_limit: Option, + quota_soft_overshoot_bytes: u64, bytes_me2c: &AtomicU64, conn_id: u64, ack_flush_immediate: bool, @@ -864,17 +1707,41 @@ where } else { trace!(conn_id, bytes = data.len(), flags, "ME->C data"); } + let data_len = data.len() as u64; + if quota_would_be_exceeded_for_user_soft( + stats, + user, + quota_limit, + data_len, + quota_soft_overshoot_bytes, + ) { + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + + let write_mode = + write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await?; + stats.increment_me_d2c_write_mode(write_mode); + bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); stats.add_user_octets_to(user, data.len() as u64); - write_client_payload( - client_writer, - proto_tag, - flags, - &data, - rng, - frame_buf, - ) - .await?; + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data.len() as u64); + + if quota_exceeded_for_user_soft( + stats, + user, + quota_limit, + quota_soft_overshoot_bytes, + ) { + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } Ok(MeWriterResponseOutcome::Continue { frames: 1, @@ -889,6 +1756,7 @@ where trace!(conn_id, confirm, "ME->C quickack"); } write_client_ack(client_writer, proto_tag, confirm).await?; + stats.increment_me_d2c_ack_frames_total(); Ok(MeWriterResponseOutcome::Continue { frames: 1, @@ -907,6 +1775,31 @@ where } } +fn compute_intermediate_secure_wire_len( + data_len: usize, + padding_len: usize, + quickack: bool, +) -> Result<(u32, usize)> { + 
let wire_len = data_len + .checked_add(padding_len) + .ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?; + if wire_len > 0x7fff_ffffusize { + return Err(ProxyError::Proxy(format!( + "Intermediate/Secure frame too large: {wire_len}" + ))); + } + + let total = 4usize + .checked_add(wire_len) + .ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?; + let mut len_val = u32::try_from(wire_len) + .map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?; + if quickack { + len_val |= 0x8000_0000; + } + Ok((len_val, total)) +} + async fn write_client_payload( client_writer: &mut CryptoWriter, proto_tag: ProtoTag, @@ -914,13 +1807,13 @@ async fn write_client_payload( data: &[u8], rng: &SecureRandom, frame_buf: &mut Vec, -) -> Result<()> +) -> Result where W: AsyncWrite + Unpin + Send + 'static, { let quickack = (flags & RPC_FLAG_QUICKACK) != 0; - match proto_tag { + let write_mode = match proto_tag { ProtoTag::Abridged => { if !data.len().is_multiple_of(4) { return Err(ProxyError::Proxy(format!( @@ -935,28 +1828,46 @@ where if quickack { first |= 0x80; } - frame_buf.clear(); - frame_buf.reserve(1 + data.len()); - frame_buf.push(first); - frame_buf.extend_from_slice(data); - client_writer - .write_all(frame_buf) - .await - .map_err(ProxyError::Io)?; + let wire_len = 1usize.saturating_add(data.len()); + if wire_len <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES { + frame_buf.clear(); + frame_buf.reserve(wire_len); + frame_buf.push(first); + frame_buf.extend_from_slice(data); + client_writer + .write_all(frame_buf.as_slice()) + .await + .map_err(ProxyError::Io)?; + MeD2cWriteMode::Coalesced + } else { + let header = [first]; + client_writer.write_all(&header).await.map_err(ProxyError::Io)?; + client_writer.write_all(data).await.map_err(ProxyError::Io)?; + MeD2cWriteMode::Split + } } else if len_words < (1 << 24) { let mut first = 0x7fu8; if quickack { first |= 0x80; } let lw = (len_words as u32).to_le_bytes(); - 
frame_buf.clear(); - frame_buf.reserve(4 + data.len()); - frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]); - frame_buf.extend_from_slice(data); - client_writer - .write_all(frame_buf) - .await - .map_err(ProxyError::Io)?; + let wire_len = 4usize.saturating_add(data.len()); + if wire_len <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES { + frame_buf.clear(); + frame_buf.reserve(wire_len); + frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]); + frame_buf.extend_from_slice(data); + client_writer + .write_all(frame_buf.as_slice()) + .await + .map_err(ProxyError::Io)?; + MeD2cWriteMode::Coalesced + } else { + let header = [first, lw[0], lw[1], lw[2]]; + client_writer.write_all(&header).await.map_err(ProxyError::Io)?; + client_writer.write_all(data).await.map_err(ProxyError::Io)?; + MeD2cWriteMode::Split + } } else { return Err(ProxyError::Proxy(format!( "Abridged frame too large: {}", @@ -976,28 +1887,46 @@ where } else { 0 }; - let mut len_val = (data.len() + padding_len) as u32; - if quickack { - len_val |= 0x8000_0000; - } - let total = 4 + data.len() + padding_len; - frame_buf.clear(); - frame_buf.reserve(total); - frame_buf.extend_from_slice(&len_val.to_le_bytes()); - frame_buf.extend_from_slice(data); - if padding_len > 0 { - let start = frame_buf.len(); - frame_buf.resize(start + padding_len, 0); - rng.fill(&mut frame_buf[start..]); - } - client_writer - .write_all(frame_buf) - .await - .map_err(ProxyError::Io)?; - } - } - Ok(()) + let (len_val, total) = + compute_intermediate_secure_wire_len(data.len(), padding_len, quickack)?; + if total <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES { + frame_buf.clear(); + frame_buf.reserve(total); + frame_buf.extend_from_slice(&len_val.to_le_bytes()); + frame_buf.extend_from_slice(data); + if padding_len > 0 { + let start = frame_buf.len(); + frame_buf.resize(start + padding_len, 0); + rng.fill(&mut frame_buf[start..]); + } + client_writer + .write_all(frame_buf.as_slice()) + .await + .map_err(ProxyError::Io)?; + 
MeD2cWriteMode::Coalesced + } else { + let header = len_val.to_le_bytes(); + client_writer.write_all(&header).await.map_err(ProxyError::Io)?; + client_writer.write_all(data).await.map_err(ProxyError::Io)?; + if padding_len > 0 { + frame_buf.clear(); + if frame_buf.capacity() < padding_len { + frame_buf.reserve(padding_len); + } + frame_buf.resize(padding_len, 0); + rng.fill(frame_buf.as_mut_slice()); + client_writer + .write_all(frame_buf.as_slice()) + .await + .map_err(ProxyError::Io)?; + } + MeD2cWriteMode::Split + } + } + }; + + Ok(write_mode) } async fn write_client_ack( @@ -1020,84 +1949,33 @@ where } #[cfg(test)] -mod tests { - use super::*; - use tokio::time::{Duration as TokioDuration, timeout}; +#[path = "tests/middle_relay_security_tests.rs"] +mod security_tests; - #[test] - fn should_yield_sender_only_on_budget_with_backlog() { - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); - } +#[cfg(test)] +#[path = "tests/middle_relay_idle_policy_security_tests.rs"] +mod idle_policy_security_tests; - #[tokio::test] - async fn enqueue_c2me_command_uses_try_send_fast_path() { - let (tx, mut rx) = mpsc::channel::(2); - enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: Bytes::from_static(&[1, 2, 3]), - flags: 0, - }, - TokioDuration::from_millis(50), - ) - .await - .unwrap(); +#[cfg(test)] +#[path = "tests/middle_relay_desync_all_full_dedup_security_tests.rs"] +mod desync_all_full_dedup_security_tests; - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1, 2, 3]); - assert_eq!(flags, 0); - } - C2MeCommand::Close => panic!("unexpected close command"), - } - } +#[cfg(test)] +#[path = 
"tests/middle_relay_stub_completion_security_tests.rs"] +mod stub_completion_security_tests; - #[tokio::test] - async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: Bytes::from_static(&[9]), - flags: 9, - }) - .await - .unwrap(); +#[cfg(test)] +#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"] +mod coverage_high_risk_security_tests; - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: Bytes::from_static(&[7, 7]), - flags: 7, - }, - TokioDuration::from_millis(100), - ) - .await - .unwrap(); - }); +#[cfg(test)] +#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"] +mod quota_overflow_lock_security_tests; - let _ = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap(); - producer.await.unwrap(); +#[cfg(test)] +#[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"] +mod length_cast_hardening_security_tests; - let recv = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[7, 7]); - assert_eq!(flags, 7); - } - C2MeCommand::Close => panic!("unexpected close command"), - } - } -} +#[cfg(test)] +#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"] +mod blackhat_campaign_integration_tests; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index ab840f6..eebc188 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,13 +1,71 @@ //! Proxy Defs +// Apply strict linting to proxy production code while keeping test builds noise-tolerant. 
+#![cfg_attr(test, allow(warnings))] +#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))] +#![cfg_attr( + not(test), + deny( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::todo, + clippy::unimplemented, + clippy::correctness, + clippy::option_if_let_else, + clippy::or_fun_call, + clippy::branches_sharing_code, + clippy::single_option_map, + clippy::useless_let_if_seq, + clippy::redundant_locals, + clippy::cloned_ref_to_slice_refs, + unsafe_code, + clippy::await_holding_lock, + clippy::await_holding_refcell_ref, + clippy::debug_assert_with_mut_call, + clippy::macro_use_imports, + clippy::cast_ptr_alignment, + clippy::cast_lossless, + clippy::ptr_as_ptr, + clippy::large_stack_arrays, + clippy::same_functions_in_if_condition, + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + rust_2018_idioms + ) +)] +#![cfg_attr( + not(test), + allow( + clippy::use_self, + clippy::redundant_closure, + clippy::too_many_arguments, + clippy::doc_markdown, + clippy::missing_const_for_fn, + clippy::unnecessary_operation, + clippy::redundant_pub_crate, + clippy::derive_partial_eq_without_eq, + clippy::type_complexity, + clippy::new_ret_no_self, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::significant_drop_tightening, + clippy::significant_drop_in_scrutinee, + clippy::float_cmp, + clippy::nursery + ) +)] + pub mod adaptive_buffers; pub mod client; pub mod direct_relay; pub mod handshake; pub mod masking; pub mod middle_relay; -pub mod route_mode; pub mod relay; +pub mod route_mode; pub mod session_eviction; pub use client::ClientHandler; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 2b12d5a..2431ff4 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -51,24 +51,19 @@ //! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to` //! 
- `SharedCounters` (atomics) let the watchdog read stats without locking -use std::io; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::task::{Context, Poll}; -use std::time::Duration; -use tokio::io::{ - AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes, -}; -use tokio::time::Instant; -use tracing::{debug, trace, warn}; -use crate::error::Result; -use crate::proxy::adaptive_buffers::{ - self, AdaptiveTier, RelaySignalSample, SessionAdaptiveController, TierTransitionReason, -}; -use crate::proxy::session_eviction::SessionLease; +use crate::error::{ProxyError, Result}; use crate::stats::Stats; use crate::stream::BufferPool; +use dashmap::DashMap; +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; +use std::task::{Context, Poll}; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; +use tokio::time::Instant; +use tracing::{debug, trace, warn}; // ============= Constants ============= @@ -83,7 +78,11 @@ const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800); /// 10 seconds gives responsive timeout detection (±10s accuracy) /// without measurable overhead from atomic reads. const WATCHDOG_INTERVAL: Duration = Duration::from_secs(10); -const ADAPTIVE_TICK: Duration = Duration::from_millis(250); + +#[inline] +fn watchdog_delta(current: u64, previous: u64) -> u64 { + current.saturating_sub(previous) +} // ============= CombinedStream ============= @@ -160,16 +159,6 @@ struct SharedCounters { s2c_ops: AtomicU64, /// Milliseconds since relay epoch of last I/O activity last_activity_ms: AtomicU64, - /// Bytes requested to write to client (S→C direction). - s2c_requested_bytes: AtomicU64, - /// Total write operations for S→C direction. - s2c_write_ops: AtomicU64, - /// Number of partial writes to client. 
-    s2c_partial_writes: AtomicU64,
-    /// Number of times S→C poll_write returned Pending.
-    s2c_pending_writes: AtomicU64,
-    /// Consecutive pending writes in S→C direction.
-    s2c_consecutive_pending_writes: AtomicU64,
 }
 
 impl SharedCounters {
@@ -180,11 +169,6 @@ impl SharedCounters {
             c2s_ops: AtomicU64::new(0),
             s2c_ops: AtomicU64::new(0),
             last_activity_ms: AtomicU64::new(0),
-            s2c_requested_bytes: AtomicU64::new(0),
-            s2c_write_ops: AtomicU64::new(0),
-            s2c_partial_writes: AtomicU64::new(0),
-            s2c_pending_writes: AtomicU64::new(0),
-            s2c_consecutive_pending_writes: AtomicU64::new(0),
         }
     }
 
@@ -225,6 +209,12 @@ struct StatsIo {
     counters: Arc<SharedCounters>,
     stats: Arc<Stats>,
     user: String,
+    quota_limit: Option<u64>,
+    quota_exceeded: Arc<AtomicBool>,
+    quota_read_wake_scheduled: bool,
+    quota_write_wake_scheduled: bool,
+    quota_read_retry_active: Arc<AtomicBool>,
+    quota_write_retry_active: Arc<AtomicBool>,
     epoch: Instant,
 }
 
@@ -234,11 +224,136 @@ impl StatsIo {
         counters: Arc<SharedCounters>,
         stats: Arc<Stats>,
         user: String,
+        quota_limit: Option<u64>,
+        quota_exceeded: Arc<AtomicBool>,
         epoch: Instant,
     ) -> Self {
         // Mark initial activity so the watchdog doesn't fire before data flows
         counters.touch(Instant::now(), epoch);
-        Self { inner, counters, stats, user, epoch }
+        Self {
+            inner,
+            counters,
+            stats,
+            user,
+            quota_limit,
+            quota_exceeded,
+            quota_read_wake_scheduled: false,
+            quota_write_wake_scheduled: false,
+            quota_read_retry_active: Arc::new(AtomicBool::new(false)),
+            quota_write_retry_active: Arc::new(AtomicBool::new(false)),
+            epoch,
+        }
+    }
+}
+
+impl<T> Drop for StatsIo<T> {
+    fn drop(&mut self) {
+        self.quota_read_retry_active.store(false, Ordering::Relaxed);
+        self.quota_write_retry_active
+            .store(false, Ordering::Relaxed);
+    }
+}
+
+#[derive(Debug)]
+struct QuotaIoSentinel;
+
+impl std::fmt::Display for QuotaIoSentinel {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("user data quota exceeded")
+    }
+}
+
+impl std::error::Error for QuotaIoSentinel {}
+
+fn quota_io_error() -> io::Error {
+    io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel)
+}
+
+fn is_quota_io_error(err: &io::Error) -> bool {
+    err.kind() == io::ErrorKind::PermissionDenied
+        && err
+            .get_ref()
+            .and_then(|source| source.downcast_ref::<QuotaIoSentinel>())
+            .is_some()
+}
+
+#[cfg(test)]
+const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
+#[cfg(not(test))]
+const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
+
+fn spawn_quota_retry_waker(retry_active: Arc<AtomicBool>, waker: std::task::Waker) {
+    tokio::task::spawn(async move {
+        loop {
+            if !retry_active.load(Ordering::Relaxed) {
+                break;
+            }
+            tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await;
+            if !retry_active.load(Ordering::Relaxed) {
+                break;
+            }
+            waker.wake_by_ref();
+        }
+    });
+}
+
+static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
+static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
+
+#[cfg(test)]
+const QUOTA_USER_LOCKS_MAX: usize = 64;
+#[cfg(not(test))]
+const QUOTA_USER_LOCKS_MAX: usize = 4_096;
+#[cfg(test)]
+const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
+#[cfg(not(test))]
+const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
+
+#[cfg(test)]
+fn quota_user_lock_test_guard() -> &'static Mutex<()> {
+    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+    TEST_LOCK.get_or_init(|| Mutex::new(()))
+}
+
+#[cfg(test)]
+fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
+    quota_user_lock_test_guard()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
+    let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
+        (0..QUOTA_OVERFLOW_LOCK_STRIPES)
+            .map(|_| Arc::new(Mutex::new(())))
+            .collect()
+    });
+
+    let hash = crc32fast::hash(user.as_bytes()) as usize;
+    Arc::clone(&stripes[hash % stripes.len()])
+}
+
+fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
+    let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+    if let Some(existing) = locks.get(user) {
+        return Arc::clone(existing.value());
+    }
+
+    if locks.len() >= QUOTA_USER_LOCKS_MAX {
+        locks.retain(|_, value| Arc::strong_count(value) > 1);
+    }
+
+    if locks.len() >= QUOTA_USER_LOCKS_MAX {
+        return quota_overflow_user_lock(user);
+    }
+
+    let created = Arc::new(Mutex::new(()));
+    match locks.entry(user.to_string()) {
+        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
+        dashmap::mapref::entry::Entry::Vacant(entry) => {
+            entry.insert(Arc::clone(&created));
+            created
+        }
     }
 }
 
@@ -249,20 +364,82 @@ impl AsyncRead for StatsIo {
         buf: &mut ReadBuf<'_>,
     ) -> Poll<io::Result<()>> {
         let this = self.get_mut();
+        if this.quota_exceeded.load(Ordering::Relaxed) {
+            return Poll::Ready(Err(quota_io_error()));
+        }
+
+        let quota_lock = this
+            .quota_limit
+            .is_some()
+            .then(|| quota_user_lock(&this.user));
+        let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
+            match lock.try_lock() {
+                Ok(guard) => {
+                    this.quota_read_wake_scheduled = false;
+                    this.quota_read_retry_active.store(false, Ordering::Relaxed);
+                    Some(guard)
+                }
+                Err(_) => {
+                    if !this.quota_read_wake_scheduled {
+                        this.quota_read_wake_scheduled = true;
+                        this.quota_read_retry_active.store(true, Ordering::Relaxed);
+                        spawn_quota_retry_waker(
+                            Arc::clone(&this.quota_read_retry_active),
+                            cx.waker().clone(),
+                        );
+                    }
+                    return Poll::Pending;
+                }
+            }
+        } else {
+            None
+        };
+
+        if let Some(limit) = this.quota_limit
+            && this.stats.get_user_total_octets(&this.user) >= limit
+        {
+            this.quota_exceeded.store(true, Ordering::Relaxed);
+            return Poll::Ready(Err(quota_io_error()));
+        }
 
         let before = buf.filled().len();
         match Pin::new(&mut this.inner).poll_read(cx, buf) {
             Poll::Ready(Ok(())) => {
                 let n = buf.filled().len() - before;
                 if n > 0 {
+                    let mut reached_quota_boundary = false;
+                    if let Some(limit) = this.quota_limit {
+                        let used = this.stats.get_user_total_octets(&this.user);
+                        if used >= limit {
+                            this.quota_exceeded.store(true, Ordering::Relaxed);
+                            return Poll::Ready(Err(quota_io_error()));
+                        }
+
+                        let remaining = limit - used;
+                        if (n as u64) > remaining {
+                            // Fail closed: when a single read chunk would cross quota,
+                            // stop relay immediately without accounting beyond the cap.
+                            this.quota_exceeded.store(true, Ordering::Relaxed);
+                            return Poll::Ready(Err(quota_io_error()));
+                        }
+
+                        reached_quota_boundary = (n as u64) == remaining;
+                    }
+
                     // C→S: client sent data
-                    this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed);
+                    this.counters
+                        .c2s_bytes
+                        .fetch_add(n as u64, Ordering::Relaxed);
                     this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
                     this.counters.touch(Instant::now(), this.epoch);
                     this.stats.add_user_octets_from(&this.user, n as u64);
                     this.stats.increment_user_msgs_from(&this.user);
 
+                    if reached_quota_boundary {
+                        this.quota_exceeded.store(true, Ordering::Relaxed);
+                    }
+
                     trace!(user = %this.user, bytes = n, "C->S");
                 }
                 Poll::Ready(Ok(()))
@@ -279,43 +456,81 @@ impl AsyncWrite for StatsIo {
         buf: &[u8],
     ) -> Poll<io::Result<usize>> {
         let this = self.get_mut();
-        this.counters
-            .s2c_requested_bytes
-            .fetch_add(buf.len() as u64, Ordering::Relaxed);
+        if this.quota_exceeded.load(Ordering::Relaxed) {
+            return Poll::Ready(Err(quota_io_error()));
+        }
 
-        match Pin::new(&mut this.inner).poll_write(cx, buf) {
-            Poll::Ready(Ok(n)) => {
-                this.counters.s2c_write_ops.fetch_add(1, Ordering::Relaxed);
-                this.counters
-                    .s2c_consecutive_pending_writes
-                    .store(0, Ordering::Relaxed);
-                if n < buf.len() {
-                    this.counters
-                        .s2c_partial_writes
-                        .fetch_add(1, Ordering::Relaxed);
+        let quota_lock = this
+            .quota_limit
+            .is_some()
+            .then(|| quota_user_lock(&this.user));
+        let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
+            match lock.try_lock() {
+                Ok(guard) => {
+                    this.quota_write_wake_scheduled = false;
+                    this.quota_write_retry_active
+                        .store(false, Ordering::Relaxed);
+                    Some(guard)
                 }
+                Err(_) => {
+                    if !this.quota_write_wake_scheduled {
+                        this.quota_write_wake_scheduled = true;
+                        this.quota_write_retry_active.store(true, Ordering::Relaxed);
+                        spawn_quota_retry_waker(
+                            Arc::clone(&this.quota_write_retry_active),
+                            cx.waker().clone(),
+                        );
+                    }
+                    return Poll::Pending;
+                }
+            }
+        } else {
+            None
+        };
+
+        let write_buf = if let Some(limit) = this.quota_limit {
+            let used = this.stats.get_user_total_octets(&this.user);
+            if used >= limit {
+                this.quota_exceeded.store(true, Ordering::Relaxed);
+                return Poll::Ready(Err(quota_io_error()));
+            }
+
+            let remaining = (limit - used) as usize;
+            if buf.len() > remaining {
+                // Fail closed: do not emit partial S->C payload when remaining
+                // quota cannot accommodate the pending write request.
+                this.quota_exceeded.store(true, Ordering::Relaxed);
+                return Poll::Ready(Err(quota_io_error()));
+            }
+            buf
+        } else {
+            buf
+        };
+
+        match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
+            Poll::Ready(Ok(n)) => {
                 if n > 0 {
                     // S→C: data written to client
-                    this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed);
+                    this.counters
+                        .s2c_bytes
+                        .fetch_add(n as u64, Ordering::Relaxed);
                     this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
                     this.counters.touch(Instant::now(), this.epoch);
                     this.stats.add_user_octets_to(&this.user, n as u64);
                     this.stats.increment_user_msgs_to(&this.user);
 
+                    if let Some(limit) = this.quota_limit
+                        && this.stats.get_user_total_octets(&this.user) >= limit
+                    {
+                        this.quota_exceeded.store(true, Ordering::Relaxed);
+                        return Poll::Ready(Err(quota_io_error()));
+                    }
+
                     trace!(user = %this.user, bytes = n, "S->C");
                 }
                 Poll::Ready(Ok(n))
             }
-            Poll::Pending => {
-                this.counters
-                    .s2c_pending_writes
-                    .fetch_add(1, Ordering::Relaxed);
-                this.counters
-                    .s2c_consecutive_pending_writes
-                    .fetch_add(1, Ordering::Relaxed);
-                Poll::Pending
-            }
             other => other,
         }
     }
@@ -348,7 +563,8 @@ impl AsyncWrite for StatsIo {
 /// - Per-user stats: bytes and ops counted per direction
 /// - Periodic rate logging: every 10 seconds when active
 /// - Clean shutdown: both write sides are shut down on exit
-/// - Error propagation: I/O errors are returned as `ProxyError::Io`
+/// - Error propagation: quota exits return `ProxyError::DataQuotaExceeded`,
+///   other I/O failures are returned as `ProxyError::Io`
 pub async fn relay_bidirectional(
     client_reader: CR,
     client_writer: CW,
@@ -357,11 +573,9 @@ pub async fn relay_bidirectional(
     c2s_buf_size: usize,
     s2c_buf_size: usize,
     user: &str,
-    dc_idx: i16,
     stats: Arc<Stats>,
+    quota_limit: Option<u64>,
     _buffer_pool: Arc<BufferPool>,
-    session_lease: SessionLease,
-    seed_tier: AdaptiveTier,
 ) -> Result<()>
 where
     CR: AsyncRead + Unpin + Send + 'static,
@@ -371,6 +585,7 @@ where
 {
     let epoch = Instant::now();
     let counters = Arc::new(SharedCounters::new());
+    let quota_exceeded = Arc::new(AtomicBool::new(false));
     let user_owned = user.to_string();
 
     // ── Combine split halves into bidirectional streams ──────────────
@@ -383,43 +598,31 @@ where
         Arc::clone(&counters),
         Arc::clone(&stats),
         user_owned.clone(),
+        quota_limit,
+        Arc::clone(&quota_exceeded),
         epoch,
     );
 
     // ── Watchdog: activity timeout + periodic rate logging ──────────
     let wd_counters = Arc::clone(&counters);
     let wd_user = user_owned.clone();
-    let wd_dc = dc_idx;
-    let wd_stats = Arc::clone(&stats);
-    let wd_session = session_lease.clone();
+    let wd_quota_exceeded = Arc::clone(&quota_exceeded);
     let watchdog = async {
-        let mut prev_c2s_log: u64 = 0;
-        let mut prev_s2c_log: u64 = 0;
-        let mut prev_c2s_sample: u64 = 0;
-        let mut prev_s2c_requested_sample: u64 = 0;
-        let mut prev_s2c_written_sample: u64 = 0;
-        let mut prev_s2c_write_ops_sample: u64 = 0;
-        let mut prev_s2c_partial_sample: u64 = 0;
-        let mut accumulated_log = Duration::ZERO;
-        let mut adaptive = SessionAdaptiveController::new(seed_tier);
+        let mut prev_c2s: u64 = 0;
+        let mut prev_s2c: u64 = 0;
 
         loop {
-            tokio::time::sleep(ADAPTIVE_TICK).await;
-
-            if wd_session.is_stale() {
-                wd_stats.increment_reconnect_stale_close_total();
-                warn!(
-                    user = %wd_user,
-                    dc = wd_dc,
-                    "Session evicted by reconnect"
-                );
-                return;
-            }
+            tokio::time::sleep(WATCHDOG_INTERVAL).await;
 
             let now = Instant::now();
             let idle = wd_counters.idle_duration(now, 
epoch); + if wd_quota_exceeded.load(Ordering::Relaxed) { + warn!(user = %wd_user, "User data quota reached, closing relay"); + return; + } + // ── Activity timeout ──────────────────────────────────── if idle >= ACTIVITY_TIMEOUT { let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed); @@ -434,80 +637,11 @@ where return; // Causes select! to cancel copy_bidirectional } - let c2s_total = wd_counters.c2s_bytes.load(Ordering::Relaxed); - let s2c_requested_total = wd_counters - .s2c_requested_bytes - .load(Ordering::Relaxed); - let s2c_written_total = wd_counters.s2c_bytes.load(Ordering::Relaxed); - let s2c_write_ops_total = wd_counters - .s2c_write_ops - .load(Ordering::Relaxed); - let s2c_partial_total = wd_counters - .s2c_partial_writes - .load(Ordering::Relaxed); - let consecutive_pending = wd_counters - .s2c_consecutive_pending_writes - .load(Ordering::Relaxed) as u32; - - let sample = RelaySignalSample { - c2s_bytes: c2s_total.saturating_sub(prev_c2s_sample), - s2c_requested_bytes: s2c_requested_total - .saturating_sub(prev_s2c_requested_sample), - s2c_written_bytes: s2c_written_total - .saturating_sub(prev_s2c_written_sample), - s2c_write_ops: s2c_write_ops_total - .saturating_sub(prev_s2c_write_ops_sample), - s2c_partial_writes: s2c_partial_total - .saturating_sub(prev_s2c_partial_sample), - s2c_consecutive_pending_writes: consecutive_pending, - }; - - if let Some(transition) = adaptive.observe(sample, ADAPTIVE_TICK.as_secs_f64()) { - match transition.reason { - TierTransitionReason::SoftConfirmed => { - wd_stats.increment_relay_adaptive_promotions_total(); - } - TierTransitionReason::HardPressure => { - wd_stats.increment_relay_adaptive_promotions_total(); - wd_stats.increment_relay_adaptive_hard_promotions_total(); - } - TierTransitionReason::QuietDemotion => { - wd_stats.increment_relay_adaptive_demotions_total(); - } - } - adaptive_buffers::record_user_tier(&wd_user, adaptive.max_tier_seen()); - debug!( - user = %wd_user, - dc = wd_dc, - from_tier = 
transition.from.as_u8(), - to_tier = transition.to.as_u8(), - reason = ?transition.reason, - throughput_ema_bps = sample - .c2s_bytes - .max(sample.s2c_written_bytes) - .saturating_mul(8) - .saturating_mul(4), - "Adaptive relay tier transition" - ); - } - - prev_c2s_sample = c2s_total; - prev_s2c_requested_sample = s2c_requested_total; - prev_s2c_written_sample = s2c_written_total; - prev_s2c_write_ops_sample = s2c_write_ops_total; - prev_s2c_partial_sample = s2c_partial_total; - - accumulated_log = accumulated_log.saturating_add(ADAPTIVE_TICK); - if accumulated_log < WATCHDOG_INTERVAL { - continue; - } - accumulated_log = Duration::ZERO; - // ── Periodic rate logging ─────────────────────────────── let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed); let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed); - let c2s_delta = c2s.saturating_sub(prev_c2s_log); - let s2c_delta = s2c.saturating_sub(prev_s2c_log); + let c2s_delta = watchdog_delta(c2s, prev_c2s); + let s2c_delta = watchdog_delta(s2c, prev_s2c); if c2s_delta > 0 || s2c_delta > 0 { let secs = WATCHDOG_INTERVAL.as_secs_f64(); @@ -521,8 +655,8 @@ where ); } - prev_c2s_log = c2s; - prev_s2c_log = s2c; + prev_c2s = c2s; + prev_s2c = s2c; } }; @@ -557,7 +691,6 @@ where let c2s_ops = counters.c2s_ops.load(Ordering::Relaxed); let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed); let duration = epoch.elapsed(); - adaptive_buffers::record_user_tier(&user_owned, seed_tier); match copy_result { Some(Ok((c2s, s2c))) => { @@ -573,6 +706,22 @@ where ); Ok(()) } + Some(Err(e)) if is_quota_io_error(&e) => { + let c2s = counters.c2s_bytes.load(Ordering::Relaxed); + let s2c = counters.s2c_bytes.load(Ordering::Relaxed); + warn!( + user = %user_owned, + c2s_bytes = c2s, + s2c_bytes = s2c, + c2s_msgs = c2s_ops, + s2c_msgs = s2c_ops, + duration_secs = duration.as_secs(), + "Data quota reached, closing relay" + ); + Err(ProxyError::DataQuotaExceeded { + user: user_owned.clone(), + }) + } Some(Err(e)) => { // I/O error in 
one of the directions let c2s = counters.c2s_bytes.load(Ordering::Relaxed); @@ -606,3 +755,39 @@ where } } } + +#[cfg(test)] +#[path = "tests/relay_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/relay_adversarial_tests.rs"] +mod adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"] +mod relay_quota_lock_pressure_adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_boundary_blackhat_tests.rs"] +mod relay_quota_boundary_blackhat_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_model_adversarial_tests.rs"] +mod relay_quota_model_adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_overflow_regression_tests.rs"] +mod relay_quota_overflow_regression_tests; + +#[cfg(test)] +#[path = "tests/relay_watchdog_delta_security_tests.rs"] +mod relay_watchdog_delta_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"] +mod relay_quota_waker_storm_adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] +mod relay_quota_wake_liveness_regression_tests; diff --git a/src/proxy/route_mode.rs b/src/proxy/route_mode.rs index 306c536..5aa7e91 100644 --- a/src/proxy/route_mode.rs +++ b/src/proxy/route_mode.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::watch; -pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover"; +pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated"; #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[repr(u8)] @@ -14,17 +14,6 @@ pub(crate) enum RelayRouteMode { } impl RelayRouteMode { - pub(crate) fn as_u8(self) -> u8 { - self as u8 - } - - pub(crate) fn from_u8(value: u8) -> Self { - match value { - 1 => Self::Middle, - _ => Self::Direct, - } - } - pub(crate) fn as_str(self) -> &'static 
str {
         match self {
             Self::Direct => "direct",
@@ -41,8 +30,6 @@ pub(crate) struct RouteCutoverState {
 
 #[derive(Clone)]
 pub(crate) struct RouteRuntimeController {
-    mode: Arc<AtomicU8>,
-    generation: Arc<AtomicU64>,
     direct_since_epoch_secs: Arc<AtomicU64>,
     tx: watch::Sender<RouteCutoverState>,
 }
@@ -60,18 +47,13 @@ impl RouteRuntimeController {
             0
         };
         Self {
-            mode: Arc::new(AtomicU8::new(initial_mode.as_u8())),
-            generation: Arc::new(AtomicU64::new(0)),
            direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)),
             tx,
         }
     }
 
     pub(crate) fn snapshot(&self) -> RouteCutoverState {
-        RouteCutoverState {
-            mode: RelayRouteMode::from_u8(self.mode.load(Ordering::Relaxed)),
-            generation: self.generation.load(Ordering::Relaxed),
-        }
+        *self.tx.borrow()
     }
 
     pub(crate) fn subscribe(&self) -> watch::Receiver<RouteCutoverState> {
@@ -84,20 +66,28 @@ impl RouteRuntimeController {
     }
 
     pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option<RouteCutoverState> {
-        let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed);
-        if previous == mode.as_u8() {
+        let mut next = None;
+        let changed = self.tx.send_if_modified(|state| {
+            if state.mode == mode {
+                return false;
+            }
+            if matches!(mode, RelayRouteMode::Direct) {
+                self.direct_since_epoch_secs
+                    .store(now_epoch_secs(), Ordering::Relaxed);
+            } else {
+                self.direct_since_epoch_secs.store(0, Ordering::Relaxed);
+            }
+            state.mode = mode;
+            state.generation = state.generation.saturating_add(1);
+            next = Some(*state);
+            true
+        });
+
+        if !changed {
             return None;
         }
-        if matches!(mode, RelayRouteMode::Direct) {
-            self.direct_since_epoch_secs
-                .store(now_epoch_secs(), Ordering::Relaxed);
-        } else {
-            self.direct_since_epoch_secs.store(0, Ordering::Relaxed);
-        }
-        let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1;
-        let next = RouteCutoverState { mode, generation };
-        self.tx.send_replace(next);
-        Some(next)
+
+        next
     }
 }
 
@@ -110,10 +100,10 @@ fn now_epoch_secs() -> u64 {
 
 pub(crate) fn is_session_affected_by_cutover(
     current: RouteCutoverState,
-    _session_mode: 
RelayRouteMode, session_generation: u64, ) -> bool { - current.generation > session_generation + current.generation > session_generation && current.mode != session_mode } pub(crate) fn affected_cutover_state( @@ -129,9 +119,7 @@ pub(crate) fn affected_cutover_state( } pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duration { - let mut value = session_id - ^ generation.rotate_left(17) - ^ 0x9e37_79b9_7f4a_7c15; + let mut value = session_id ^ generation.rotate_left(17) ^ 0x9e37_79b9_7f4a_7c15; value ^= value >> 30; value = value.wrapping_mul(0xbf58_476d_1ce4_e5b9); value ^= value >> 27; @@ -140,3 +128,11 @@ pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duratio let ms = 1000 + (value % 1000); Duration::from_millis(ms) } + +#[cfg(test)] +#[path = "tests/route_mode_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/route_mode_coherence_adversarial_tests.rs"] +mod coherence_adversarial_tests; diff --git a/src/proxy/session_eviction.rs b/src/proxy/session_eviction.rs index c735cae..800e5b8 100644 --- a/src/proxy/session_eviction.rs +++ b/src/proxy/session_eviction.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + /// Session eviction is intentionally disabled in runtime. 
/// /// The initial `user+dc` single-lease model caused valid parallel client diff --git a/src/proxy/tests/client_adversarial_tests.rs b/src/proxy/tests/client_adversarial_tests.rs new file mode 100644 index 0000000..5bc90bc --- /dev/null +++ b/src/proxy/tests/client_adversarial_tests.rs @@ -0,0 +1,714 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::error::ProxyError; +use crate::ip_tracker::UserIpTracker; +use crate::stats::Stats; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +// ------------------------------------------------------------------ +// Priority 3: Massive Concurrency Stress (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn client_stress_10k_connections_limit_strict() { + let user = "stress-user"; + let limit = 512; + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), limit); + + let iterations = 1000; + let mut tasks = Vec::new(); + + for i in 0..iterations { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + let user_str = user.to_string(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, (i % 254 + 1) as u8)), + 10000 + (i % 1000) as u16, + ); + + match RunningClientHandler::acquire_user_connection_reservation_static( + &user_str, &config, stats, peer, ip_tracker, + ) + .await + { + Ok(res) => Ok(res), + Err(ProxyError::ConnectionLimitExceeded { .. 
}) => Err(()), + Err(e) => panic!("Unexpected error: {:?}", e), + } + })); + } + + let results = futures::future::join_all(tasks).await; + let mut successes = 0; + let mut failures = 0; + let mut reservations = Vec::new(); + + for res in results { + match res.unwrap() { + Ok(r) => { + successes += 1; + reservations.push(r); + } + Err(_) => failures += 1, + } + } + + assert_eq!(successes, limit, "Should allow exactly 'limit' connections"); + assert_eq!( + failures, + iterations - limit, + "Should fail the rest with LimitExceeded" + ); + assert_eq!(stats.get_user_curr_connects(user), limit as u64); + + drop(reservations); + + ip_tracker.drain_cleanup_queue().await; + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "Stats must converge to 0 after all drops" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "IP tracker must converge to 0" + ); +} + +// ------------------------------------------------------------------ +// Priority 3: IP Tracker Race Stress +// ------------------------------------------------------------------ + +#[tokio::test] +async fn client_ip_tracker_race_condition_stress() { + let user = "race-user"; + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 100).await; + + let iterations = 1000; + let mut tasks = Vec::new(); + + for i in 0..iterations { + let ip_tracker = Arc::clone(&ip_tracker); + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 254 + 1) as u8)); + + tasks.push(tokio::spawn(async move { + for _ in 0..10 { + if let Ok(()) = ip_tracker.check_and_add("race-user", ip).await { + ip_tracker.remove_ip("race-user", ip).await; + } + } + })); + } + + futures::future::join_all(tasks).await; + + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "IP count must be zero after balanced add/remove burst" + ); +} + +#[tokio::test] +async fn client_limit_burst_peak_never_exceeds_cap() { + let user = "peak-cap-user"; + let limit = 32; + let attempts = 256; + + let stats = 
Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), limit); + + let peak = Arc::new(AtomicU64::new(0)); + let mut tasks = Vec::with_capacity(attempts); + + for i in 0..attempts { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + let peak = Arc::clone(&peak); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, (i % 250 + 1) as u8)), + 20000 + i as u16, + ); + + let acquired = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker, + ) + .await; + + if let Ok(reservation) = acquired { + let now = stats.get_user_curr_connects(user); + loop { + let prev = peak.load(Ordering::Relaxed); + if now <= prev { + break; + } + if peak + .compare_exchange(prev, now, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + break; + } + } + tokio::time::sleep(Duration::from_millis(2)).await; + drop(reservation); + } + })); + } + + futures::future::join_all(tasks).await; + ip_tracker.drain_cleanup_queue().await; + + assert!( + peak.load(Ordering::Relaxed) <= limit as u64, + "peak concurrent reservations must not exceed configured cap" + ); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_quota_rejection_never_mutates_live_counters() { + let user = "quota-reject-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 0); + + let peer: SocketAddr = "198.51.100.201:31111".parse().unwrap(); + let res = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + 
ip_tracker.clone(), + ) + .await; + + assert!(matches!(res, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_expiration_rejection_never_mutates_live_counters() { + let user = "expired-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_expirations.insert( + user.to_string(), + chrono::Utc::now() - chrono::Duration::seconds(1), + ); + + let peer: SocketAddr = "198.51.100.202:31112".parse().unwrap(); + let res = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + + assert!(matches!(res, Err(ProxyError::UserExpired { .. }))); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_ip_limit_failure_rolls_back_counter_exactly() { + let user = "ip-limit-rollback-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 16); + + let first_peer: SocketAddr = "198.51.100.203:31113".parse().unwrap(); + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + first_peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + + let second_peer: SocketAddr = "198.51.100.204:31114".parse().unwrap(); + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + second_peer, + ip_tracker.clone(), + ) + .await; + + assert!(matches!( + second, + Err(ProxyError::ConnectionLimitExceeded { .. 
}) + )); + assert_eq!(stats.get_user_curr_connects(user), 1); + + drop(first); + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_parallel_limit_checks_success_path_leaves_no_residue() { + let user = "parallel-check-success-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 128).await; + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 128); + + let mut tasks = Vec::new(); + for i in 0..128u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 10, (i / 255) as u8, (i % 255 + 1) as u8)), + 32000 + i, + ); + RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker) + .await + })); + } + + for result in futures::future::join_all(tasks).await { + assert!(result.unwrap().is_ok()); + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_parallel_limit_checks_failure_path_leaves_no_residue() { + let user = "parallel-check-failure-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 0).await; + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 512); + + let mut tasks = Vec::new(); + for i in 0..64u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(172, 16, 0, (i % 250 + 1) as u8)), + 33000 + i, + ); + 
RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker) + .await + })); + } + + let mut _denied = 0usize; + for result in futures::future::join_all(tasks).await { + match result.unwrap() { + Ok(()) => {} + Err(ProxyError::ConnectionLimitExceeded { .. }) => _denied += 1, + Err(other) => panic!("unexpected error: {other}"), + } + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_churn_mixed_success_failure_converges_to_zero_state() { + let user = "mixed-churn-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 4).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); + + let mut tasks = Vec::new(); + for i in 0..200u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(192, 0, 2, (i % 16 + 1) as u8)), + 34000 + (i % 32), + ); + let maybe_res = RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await; + + if let Ok(reservation) = maybe_res { + tokio::time::sleep(Duration::from_millis((i % 3) as u64)).await; + drop(reservation); + } + })); + } + + futures::future::join_all(tasks).await; + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_same_ip_parallel_attempts_allow_at_most_one_when_limit_is_one() { + let user = "same-ip-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let mut config = ProxyConfig::default(); + 
config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let peer: SocketAddr = "203.0.113.44:35555".parse().unwrap(); + let mut tasks = Vec::new(); + + for _ in 0..64 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + tasks.push(tokio::spawn(async move { + RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await + })); + } + + let mut granted = 0usize; + let mut reservations = Vec::new(); + for result in futures::future::join_all(tasks).await { + match result.unwrap() { + Ok(reservation) => { + granted += 1; + reservations.push(reservation); + } + Err(ProxyError::ConnectionLimitExceeded { .. }) => {} + Err(other) => panic!("unexpected error: {other}"), + } + } + + assert_eq!( + granted, 1, + "only one reservation may be granted for same IP with limit=1" + ); + drop(reservations); + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_repeat_acquire_release_cycles_never_accumulate_state() { + let user = "repeat-cycle-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 32).await; + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 32); + + for i in 0..500u16 { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 18, (i / 250) as u8, (i % 250 + 1) as u8)), + 36000 + (i % 128), + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + drop(reservation); + } + + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + 
+#[tokio::test] +async fn client_multi_user_isolation_under_parallel_limit_exhaustion() { + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert("u1".to_string(), 8); + config.access.user_max_tcp_conns.insert("u2".to_string(), 8); + + let mut tasks = Vec::new(); + for i in 0..128u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + tasks.push(tokio::spawn(async move { + let user = if i % 2 == 0 { "u1" } else { "u2" }; + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(100, 64, (i / 64) as u8, (i % 64 + 1) as u8)), + 37000 + i, + ); + RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await + })); + } + + let mut u1_success = 0usize; + let mut u2_success = 0usize; + let mut reservations = Vec::new(); + for (idx, result) in futures::future::join_all(tasks) + .await + .into_iter() + .enumerate() + { + let user = if idx % 2 == 0 { "u1" } else { "u2" }; + match result.unwrap() { + Ok(reservation) => { + if user == "u1" { + u1_success += 1; + } else { + u2_success += 1; + } + reservations.push(reservation); + } + Err(ProxyError::ConnectionLimitExceeded { .. 
}) => {} + Err(other) => panic!("unexpected error: {other}"), + } + } + + assert_eq!(u1_success, 8, "u1 must get exactly its own configured cap"); + assert_eq!(u2_success, 8, "u2 must get exactly its own configured cap"); + + drop(reservations); + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects("u1"), 0); + assert_eq!(stats.get_user_curr_connects("u2"), 0); +} + +#[tokio::test] +async fn client_limit_recovery_after_full_rejection_wave() { + let user = "recover-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let first_peer: SocketAddr = "198.51.100.50:38001".parse().unwrap(); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + first_peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + + for i in 0..64u16 { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 51, 100, (i % 60 + 1) as u8)), + 38002 + i, + ); + let denied = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(matches!( + denied, + Err(ProxyError::ConnectionLimitExceeded { .. 
}) + )); + } + + drop(reservation); + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + + let recovery_peer: SocketAddr = "198.51.100.200:38999".parse().unwrap(); + let recovered = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + recovery_peer, + ip_tracker.clone(), + ) + .await; + assert!( + recovered.is_ok(), + "capacity must recover after prior holder drops" + ); +} + +#[tokio::test] +async fn client_dual_limit_cross_product_never_leaks_on_reject() { + let user = "dual-limit-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 2).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 2); + + let p1: SocketAddr = "203.0.113.10:39001".parse().unwrap(); + let p2: SocketAddr = "203.0.113.11:39002".parse().unwrap(); + let r1 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + p1, + ip_tracker.clone(), + ) + .await + .unwrap(); + let r2 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + p2, + ip_tracker.clone(), + ) + .await + .unwrap(); + + for i in 0..32u16 { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, (50 + i) as u8)), + 39010 + i, + ); + let denied = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(matches!( + denied, + Err(ProxyError::ConnectionLimitExceeded { .. 
}) + )); + } + + assert_eq!(stats.get_user_curr_connects(user), 2); + drop((r1, r2)); + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_check_user_limits_concurrent_churn_no_counter_drift() { + let user = "check-drift-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 64).await; + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 64); + + let mut tasks = Vec::new(); + for i in 0..512u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(172, 20, (i / 255) as u8, (i % 255 + 1) as u8)), + 40000 + (i % 500), + ); + let _ = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer, + &ip_tracker, + ) + .await; + })); + } + + for task in futures::future::join_all(tasks).await { + task.unwrap(); + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} diff --git a/src/proxy/tests/client_beobachten_ttl_bounds_security_tests.rs b/src/proxy/tests/client_beobachten_ttl_bounds_security_tests.rs new file mode 100644 index 0000000..80f9834 --- /dev/null +++ b/src/proxy/tests/client_beobachten_ttl_bounds_security_tests.rs @@ -0,0 +1,126 @@ +use super::*; + +const BEOBACHTEN_TTL_MAX_MINUTES: u64 = 24 * 60; + +#[test] +fn beobachten_ttl_exact_upper_bound_is_preserved() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "upper-bound TTL should 
remain unchanged" + ); +} + +#[test] +fn beobachten_ttl_above_upper_bound_is_clamped() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES + 1; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "TTL above security cap must be clamped" + ); +} + +#[test] +fn beobachten_ttl_u64_max_is_clamped_fail_safe() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = u64::MAX; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "extreme configured TTL must not become multi-century retention" + ); +} + +#[test] +fn positive_one_minute_maps_to_exact_60_seconds() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + assert_eq!(beobachten_ttl(&config), Duration::from_secs(60)); +} + +#[test] +fn adversarial_boundary_triplet_behaves_deterministically() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES - 1; + assert_eq!( + beobachten_ttl(&config), + Duration::from_secs((BEOBACHTEN_TTL_MAX_MINUTES - 1) * 60) + ); + + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES; + assert_eq!( + beobachten_ttl(&config), + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60) + ); + + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES + 1; + assert_eq!( + beobachten_ttl(&config), + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60) + ); +} + +#[test] +fn light_fuzz_random_minutes_match_fail_safe_model() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + + let mut seed = 0xD15E_A5E5_F00D_BAADu64; + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + 
config.general.beobachten_minutes = seed; + let ttl = beobachten_ttl(&config); + let expected = if seed == 0 { + Duration::from_secs(60) + } else { + Duration::from_secs(seed.min(BEOBACHTEN_TTL_MAX_MINUTES) * 60) + }; + + assert_eq!(ttl, expected, "ttl mismatch for minutes={seed}"); + assert!(ttl <= Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60)); + } +} + +#[test] +fn stress_monotonic_minutes_remain_monotonic_until_cap_then_flat() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + + let mut prev = Duration::from_secs(0); + for minutes in 0..=(BEOBACHTEN_TTL_MAX_MINUTES + 4096) { + config.general.beobachten_minutes = minutes; + let ttl = beobachten_ttl(&config); + + assert!(ttl >= prev, "ttl must be non-decreasing as minutes grow"); + assert!(ttl <= Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60)); + + if minutes > BEOBACHTEN_TTL_MAX_MINUTES { + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "ttl must stay clamped once cap is exceeded" + ); + } + prev = ttl; + } +} diff --git a/src/proxy/tests/client_masking_blackhat_campaign_tests.rs b/src/proxy/tests/client_masking_blackhat_campaign_tests.rs new file mode 100644 index 0000000..88d4a58 --- /dev/null +++ b/src/proxy/tests/client_masking_blackhat_campaign_tests.rs @@ -0,0 +1,904 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{ + HANDSHAKE_LEN, MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE, TLS_RECORD_APPLICATION, + TLS_VERSION, +}; +use crate::protocol::tls; +use std::collections::HashSet; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +struct CampaignHarness { + config: Arc<ProxyConfig>, + stats: Arc<Stats>, + upstream_manager: Arc<UpstreamManager>, + replay_checker: Arc<ReplayChecker>, + buffer_pool: Arc<BufferPool>, + rng: Arc<SecureRandom>, + route_runtime: Arc<RouteRuntimeController>, + ip_tracker: Arc<UserIpTracker>, + beobachten: Arc<BeobachtenStore>, 
+} + +fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn build_mask_harness(secret_hex: &str, mask_port: u16) -> CampaignHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + CampaignHarness { + config, + stats: stats.clone(), + upstream_manager: new_upstream_manager(stats), + replay_checker: Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let 
mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_record(record_type: u8, payload: &[u8]) -> Vec<u8> { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(record_type); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> { + wrap_tls_record(TLS_RECORD_APPLICATION, payload) +} + +async fn read_and_discard_tls_record_body<T>(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +async fn run_tls_success_mtproto_fail_capture( + harness: CampaignHarness, + peer: SocketAddr, + client_hello: Vec<u8>, + bad_mtproto_record: Vec<u8>, + trailing_records: Vec<Vec<u8>>, + expected_forward: Vec<u8>, +) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = (*harness.config).clone(); + cfg.censorship.mask_port = backend_addr.port(); + let cfg = Arc::new(cfg); + + let expected = expected_forward.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(262144); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + cfg, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + 
client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; + + client_side.write_all(&bad_mtproto_record).await.unwrap(); + for record in trailing_records { + client_side.write_all(&record).await.unwrap(); + } + + let got = tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, expected_forward); + + client_side.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} + +async fn run_invalid_tls_capture(config: Arc<ProxyConfig>, payload: Vec<u8>, expected: Vec<u8>) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = (*config).clone(); + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + let cfg = Arc::new(cfg); + + let expected_probe = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.77:45001".parse().unwrap(), + cfg, + stats, + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), 
+ false, + )); + + client_side.write_all(&payload).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let got = tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, expected); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} + +#[tokio::test] +async fn blackhat_campaign_01_tail_only_record_is_forwarded_after_tls_success_mtproto_fail() { + let secret = [0xA1u8; 16]; + let harness = build_mask_harness("a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1", 1); + let client_hello = make_valid_tls_client_hello(&secret, 11, 600, 0x41); + let bad_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let tail = wrap_tls_application_data(b"blackhat-tail-01"); + + run_tls_success_mtproto_fail_capture( + harness, + "198.51.100.1:55001".parse().unwrap(), + client_hello, + bad_record, + vec![tail.clone()], + tail, + ) + .await; +} + +#[tokio::test] +async fn blackhat_campaign_02_two_ordered_records_preserved_after_fallback() { + let secret = [0xA2u8; 16]; + let harness = build_mask_harness("a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2", 1); + let client_hello = make_valid_tls_client_hello(&secret, 12, 600, 0x42); + let bad_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let r1 = wrap_tls_application_data(b"first"); + let r2 = wrap_tls_application_data(b"second"); + let expected = [r1.clone(), r2.clone()].concat(); + + run_tls_success_mtproto_fail_capture( + harness, + "198.51.100.2:55002".parse().unwrap(), + client_hello, + bad_record, + vec![r1, r2], + expected, + ) + .await; +} + +#[tokio::test] +async fn blackhat_campaign_03_large_tls_application_record_survives_fallback() { + let secret = [0xA3u8; 16]; + let harness = build_mask_harness("a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3", 1); + let client_hello = make_valid_tls_client_hello(&secret, 13, 600, 0x43); + let bad_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let 
big_payload = vec![0x5Au8; MAX_TLS_PLAINTEXT_SIZE]; + let big_record = wrap_tls_application_data(&big_payload); + + run_tls_success_mtproto_fail_capture( + harness, + "198.51.100.3:55003".parse().unwrap(), + client_hello, + bad_record, + vec![big_record.clone()], + big_record, + ) + .await; +} + +#[tokio::test] +async fn blackhat_campaign_04_coalesced_tail_in_failed_record_is_reframed_and_forwarded() { + let secret = [0xA4u8; 16]; + let harness = build_mask_harness("a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4", 1); + let client_hello = make_valid_tls_client_hello(&secret, 14, 600, 0x44); + + let coalesced_tail = b"coalesced-tail-blackhat".to_vec(); + let mut bad_payload = vec![0u8; HANDSHAKE_LEN]; + bad_payload.extend_from_slice(&coalesced_tail); + let bad_record = wrap_tls_application_data(&bad_payload); + let expected = wrap_tls_application_data(&coalesced_tail); + + run_tls_success_mtproto_fail_capture( + harness, + "198.51.100.4:55004".parse().unwrap(), + client_hello, + bad_record, + Vec::new(), + expected, + ) + .await; +} + +#[tokio::test] +async fn blackhat_campaign_05_coalesced_tail_plus_next_record_keep_wire_order() { + let secret = [0xA5u8; 16]; + let harness = build_mask_harness("a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5", 1); + let client_hello = make_valid_tls_client_hello(&secret, 15, 600, 0x45); + + let coalesced_tail = b"inline-tail".to_vec(); + let mut bad_payload = vec![0u8; HANDSHAKE_LEN]; + bad_payload.extend_from_slice(&coalesced_tail); + let bad_record = wrap_tls_application_data(&bad_payload); + let next_record = wrap_tls_application_data(b"next-record"); + + let expected = [ + wrap_tls_application_data(&coalesced_tail), + next_record.clone(), + ] + .concat(); + + run_tls_success_mtproto_fail_capture( + harness, + "198.51.100.5:55005".parse().unwrap(), + client_hello, + bad_record, + vec![next_record], + expected, + ) + .await; +} + +#[tokio::test] +async fn blackhat_campaign_06_replayed_tls_hello_is_masked_without_serverhello() { + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let harness = build_mask_harness("a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6", backend_addr.port()); + let replay_checker = harness.replay_checker.clone(); + let client_hello = make_valid_tls_client_hello(&[0xA6; 16], 16, 600, 0x46); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let first_tail = wrap_tls_application_data(b"seed-tail"); + + let expected_hello = client_hello.clone(); + let expected_tail = first_tail.clone(); + + let accept_task = tokio::spawn(async move { + let (mut s1, _) = listener.accept().await.unwrap(); + let mut got_tail = vec![0u8; expected_tail.len()]; + s1.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + drop(s1); + + let (mut s2, _) = listener.accept().await.unwrap(); + let mut got_hello = vec![0u8; expected_hello.len()]; + s2.read_exact(&mut got_hello).await.unwrap(); + got_hello + }); + + let run_one = |checker: Arc<ReplayChecker>, send_mtproto: bool| { + let mut cfg = (*harness.config).clone(); + cfg.censorship.mask_port = backend_addr.port(); + let cfg = Arc::new(cfg); + let hello = client_hello.clone(); + let invalid_mtproto_record = invalid_mtproto_record.clone(); + let first_tail = first_tail.clone(); + let stats = harness.stats.clone(); + let upstream = harness.upstream_manager.clone(); + let pool = harness.buffer_pool.clone(); + let rng = harness.rng.clone(); + let route = harness.route_runtime.clone(); + let ipt = harness.ip_tracker.clone(); + let beob = harness.beobachten.clone(); + + async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.6:55006".parse().unwrap(), + cfg, + stats, + upstream, + checker, + pool, + rng, + None, + route, + None, + ipt, + beob, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + if send_mtproto { + let mut head = [0u8; 5]; + 
client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&first_tail).await.unwrap(); + } else { + let mut one = [0u8; 1]; + let no_server_hello = tokio::time::timeout( + Duration::from_millis(300), + client_side.read_exact(&mut one), + ) + .await; + assert!(no_server_hello.is_err() || no_server_hello.unwrap().is_err()); + } + client_side.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + } + }; + + run_one(replay_checker.clone(), true).await; + run_one(replay_checker, false).await; + + let got = tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, client_hello); +} + +#[tokio::test] +async fn blackhat_campaign_07_truncated_clienthello_exact_prefix_is_forwarded() { + let mut payload = vec![0u8; 5 + 37]; + payload[0] = 0x16; + payload[1] = 0x03; + payload[2] = 0x01; + payload[3..5].copy_from_slice(&600u16.to_be_bytes()); + payload[5..].fill(0x71); + + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), payload.clone(), payload).await; +} + +#[tokio::test] +async fn blackhat_campaign_08_out_of_bounds_len_forwards_header_only() { + let header = vec![0x16, 0x03, 0x01, 0xFF, 0xFF]; + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), header.clone(), header).await; +} + +#[tokio::test] +async fn blackhat_campaign_09_fragmented_header_then_partial_body_masks_seen_bytes_only() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.censorship.mask = true; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_unix_sock = None; + 
+ let expected = { + let mut x = vec![0u8; 5 + 11]; + x[0] = 0x16; + x[1] = 0x03; + x[2] = 0x01; + x[3..5].copy_from_slice(&600u16.to_be_bytes()); + x[5..].fill(0xCC); + x + }; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.9:55009".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&[0x16, 0x03]).await.unwrap(); + client_side.write_all(&[0x01, 0x02, 0x58]).await.unwrap(); + client_side.write_all(&vec![0xCC; 11]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let got = tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got.len(), 16); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} + +#[tokio::test] +async fn blackhat_campaign_10_zero_handshake_timeout_with_delay_still_avoids_timeout_counter() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.timeouts.client_handshake = 0; + cfg.censorship.server_hello_delay_min_ms = 700; + cfg.censorship.server_hello_delay_max_ms = 700; + + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + 
let started = Instant::now(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.10:55010".parse().unwrap(), + Arc::new(cfg), + stats.clone(), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut invalid = vec![0u8; 5 + 700]; + invalid[0] = 0x16; + invalid[1] = 0x03; + invalid[2] = 0x01; + invalid[3..5].copy_from_slice(&700u16.to_be_bytes()); + invalid[5..].fill(0x66); + + client_side.write_all(&invalid).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + assert_eq!(stats.get_handshake_timeouts(), 0); + assert!(started.elapsed() >= Duration::from_millis(650)); +} + +#[tokio::test] +async fn blackhat_campaign_11_parallel_bad_tls_probes_all_masked_without_timeouts() { + let n = 24usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.censorship.mask = true; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_port = backend_addr.port(); + + let stats = Arc::new(Stats::new()); + let accept_task = tokio::spawn(async move { + let mut seen = HashSet::new(); + for _ in 0..n { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut hdr = [0u8; 5]; + stream.read_exact(&mut hdr).await.unwrap(); + seen.insert(hdr.to_vec()); + } + seen + }); + + let mut tasks = Vec::new(); + for i in 0..n { + let mut hdr = [0u8; 5]; + hdr[0] = 0x16; + hdr[1] = 0x03; + hdr[2] = 0x01; + hdr[3] = 0xFF; + hdr[4] = i as u8; + + let cfg = 
Arc::new(cfg.clone()); + let stats = stats.clone(); + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + format!("198.51.100.11:{}", 56000 + i).parse().unwrap(), + cfg, + stats, + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + client_side.write_all(&hdr).await.unwrap(); + client_side.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + hdr.to_vec() + })); + } + + let mut expected = HashSet::new(); + for t in tasks { + expected.insert(t.await.unwrap()); + } + + let seen = tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(seen, expected); + assert_eq!(stats.get_handshake_timeouts(), 0); +} + +#[tokio::test] +async fn blackhat_campaign_12_parallel_tls_success_mtproto_fail_sessions_keep_isolation() { + let sessions = 16usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = HashSet::new(); + for i in 0..sessions { + let rec = wrap_tls_application_data(&vec![i as u8; 8 + i]); + expected.insert(rec); + } + + let accept_task = tokio::spawn(async move { + let mut got_set = HashSet::new(); + for _ in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut head = [0u8; 5]; + stream.read_exact(&mut head).await.unwrap(); + let len = u16::from_be_bytes([head[3], head[4]]) as usize; + let mut rec = vec![0u8; 5 + len]; + rec[..5].copy_from_slice(&head); + stream.read_exact(&mut rec[5..]).await.unwrap(); + 
got_set.insert(rec); + } + got_set + }); + + let mut tasks = Vec::new(); + for i in 0..sessions { + let mut harness = + build_mask_harness("abababababababababababababababab", backend_addr.port()); + let mut cfg = (*harness.config).clone(); + cfg.censorship.mask_port = backend_addr.port(); + harness.config = Arc::new(cfg); + tasks.push(tokio::spawn(async move { + let secret = [0xABu8; 16]; + let hello = + make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x40 + (i as u8 % 10)); + let bad = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let tail = wrap_tls_application_data(&vec![i as u8; 8 + i]); + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + format!("198.51.100.12:{}", 56100 + i).parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&bad).await.unwrap(); + client_side.write_all(&tail).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(5), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + tail + })); + } + + let mut produced = HashSet::new(); + for t in tasks { + produced.insert(t.await.unwrap()); + } + + let observed = tokio::time::timeout(Duration::from_secs(8), accept_task) + .await + .unwrap() + .unwrap(); + + assert_eq!(produced, expected); + assert_eq!(observed, expected); +} + +#[tokio::test] +async fn blackhat_campaign_13_backend_down_does_not_escalate_to_handshake_timeout() { + let mut cfg = ProxyConfig::default(); + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = 
None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.timeouts.client_handshake = 1; + + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.13:55013".parse().unwrap(), + Arc::new(cfg), + stats.clone(), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let bad = vec![0x16, 0x03, 0x01, 0xFF, 0x00]; + client_side.write_all(&bad).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + assert_eq!(stats.get_handshake_timeouts(), 0); +} + +#[tokio::test] +async fn blackhat_campaign_14_masking_disabled_path_finishes_cleanly() { + let mut cfg = ProxyConfig::default(); + cfg.censorship.mask = false; + cfg.timeouts.client_handshake = 1; + + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.14:55014".parse().unwrap(), + Arc::new(cfg), + stats.clone(), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let bad = vec![0x16, 0x03, 0x01, 0xFF, 0xF0]; + client_side.write_all(&bad).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = 
tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + assert_eq!(stats.get_handshake_timeouts(), 0); +} + +#[tokio::test] +async fn blackhat_campaign_15_light_fuzz_tls_lengths_and_fragmentation() { + let mut seed = 0x9E3779B97F4A7C15u64; + + for idx in 0..20u16 { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + let mut tls_len = (seed as usize) % 20000; + if idx % 3 == 0 { + tls_len = MAX_TLS_PLAINTEXT_SIZE + 1 + (tls_len % 1024); + } + + let body_to_send = + if (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) { + (seed as usize % 29).min(tls_len.saturating_sub(1)) + } else { + 0 + }; + + let mut probe = vec![0u8; 5 + body_to_send]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + for b in &mut probe[5..] { + seed = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + *b = (seed >> 24) as u8; + } + + let expected = probe.clone(); + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe, expected).await; + } +} + +#[tokio::test] +async fn blackhat_campaign_16_mixed_probe_burst_stress_finishes_without_panics() { + let cases = 18usize; + let mut tasks = Vec::new(); + + for i in 0..cases { + tasks.push(tokio::spawn(async move { + if i % 2 == 0 { + let mut probe = vec![0u8; 5 + (i % 13)]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&600u16.to_be_bytes()); + probe[5..].fill((0x90 + i as u8) ^ 0x5A); + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe.clone(), probe) + .await; + } else { + let hdr = vec![0x16, 0x03, 0x01, 0xFF, i as u8]; + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), hdr.clone(), hdr).await; + } + })); + } + + for task in tasks { + task.await.unwrap(); + } +} diff --git a/src/proxy/tests/client_masking_budget_security_tests.rs b/src/proxy/tests/client_masking_budget_security_tests.rs 
new file mode 100644 index 0000000..d98c780 --- /dev/null +++ b/src/proxy/tests/client_masking_budget_security_tests.rs @@ -0,0 +1,255 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +struct PipelineHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn build_harness(config: ProxyConfig) -> PipelineHarness { + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + PipelineHarness { + config, + stats, + upstream_manager, + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] 
= session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +#[tokio::test] +async fn masking_runs_outside_handshake_timeout_budget_with_high_reject_delay() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + config.timeouts.client_handshake = 0; + config.censorship.server_hello_delay_min_ms = 730; + config.censorship.server_hello_delay_max_ms = 730; + + let harness = build_harness(config); + let stats = harness.stats.clone(); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "198.51.100.241:56541".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + let mut invalid_hello = 
vec![0u8; 5 + 600]; + invalid_hello[0] = 0x16; + invalid_hello[1] = 0x03; + invalid_hello[2] = 0x01; + invalid_hello[3..5].copy_from_slice(&600u16.to_be_bytes()); + invalid_hello[5..].fill(0x44); + + let started = Instant::now(); + client_side.write_all(&invalid_hello).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + assert!( + result.is_ok(), + "bad-client fallback must not be canceled by handshake timeout" + ); + assert_eq!( + stats.get_handshake_timeouts(), + 0, + "masking fallback path must not increment handshake timeout counter" + ); + assert!( + started.elapsed() >= Duration::from_millis(700), + "configured reject delay should still be visible before masking" + ); +} + +#[tokio::test] +async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + config.access.ignore_time_skew = true; + config.access.users.insert( + "user".to_string(), + "d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0".to_string(), + ); + + let harness = build_harness(config); + + let secret = [0xD0u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0, 600, 0x41); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"no-clienthello-reinject"); + let expected_trailing = trailing_record.clone(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; 
expected_trailing.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!( + got, expected_trailing, + "mask backend must receive only post-handshake trailing TLS records" + ); + }); + + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.242:56542".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} diff --git a/src/proxy/tests/client_masking_diagnostics_security_tests.rs b/src/proxy/tests/client_masking_diagnostics_security_tests.rs new file mode 100644 index 0000000..0d9ca99 --- /dev/null +++ b/src/proxy/tests/client_masking_diagnostics_security_tests.rs @@ -0,0 +1,208 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + 
enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn percentile_ms(mut values: Vec, p_num: usize, p_den: usize) -> u128 { + values.sort_unstable(); + if values.is_empty() { + return 0; + } + let idx = ((values.len() - 1) * p_num) / p_den; + values[idx] +} + +async fn measure_reject_duration_ms(body_sent: usize) -> u128 { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.timeouts.client_handshake = 1; + cfg.censorship.server_hello_delay_min_ms = 700; + cfg.censorship.server_hello_delay_max_ms = 700; + + let (server_side, mut client_side) = duplex(65536); + let started = Instant::now(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.170:56170".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&600u16.to_be_bytes()); + probe[5..].fill(0xA7); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + started.elapsed().as_millis() +} + +async fn capture_forwarded_len(body_sent: usize) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = 
ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_shape_hardening = false; + cfg.timeouts.client_handshake = 1; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.171:56171".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&600u16.to_be_bytes()); + probe[5..].fill(0xB4); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn diagnostic_timing_profiles_are_within_realistic_guardrails() { + let classes = [17usize, 511usize, 1023usize, 4095usize]; + for class in classes { + let mut samples = Vec::new(); + for _ in 0..8 { + samples.push(measure_reject_duration_ms(class).await); + } + + let p50 = percentile_ms(samples.clone(), 50, 100); + let p95 = 
percentile_ms(samples.clone(), 95, 100); + let max = *samples.iter().max().unwrap(); + println!( + "diagnostic_timing class={} p50={}ms p95={}ms max={}ms", + class, p50, p95, max + ); + + assert!(p50 >= 650, "p50 too low for delayed reject class={}", class); + assert!( + p95 <= 1200, + "p95 too high for delayed reject class={}", + class + ); + assert!( + max <= 1500, + "max too high for delayed reject class={}", + class + ); + } +} + +#[tokio::test] +async fn diagnostic_forwarded_size_profiles_by_probe_class() { + let classes = [ + 0usize, 1usize, 7usize, 17usize, 63usize, 511usize, 1023usize, 2047usize, + ]; + let mut observed = Vec::new(); + + for class in classes { + let len = capture_forwarded_len(class).await; + println!("diagnostic_shape class={} forwarded_len={}", class, len); + observed.push(len as u128); + assert_eq!( + len, + 5 + class, + "unexpected forwarded len for class={}", + class + ); + } + + let p50 = percentile_ms(observed.clone(), 50, 100); + let p95 = percentile_ms(observed.clone(), 95, 100); + let max = *observed.iter().max().unwrap(); + println!( + "diagnostic_shape_summary p50={}bytes p95={}bytes max={}bytes", + p50, p95, max + ); + + assert!(p95 >= p50); + assert!(max >= p95); +} diff --git a/src/proxy/tests/client_masking_hard_adversarial_tests.rs b/src/proxy/tests/client_masking_hard_adversarial_tests.rs new file mode 100644 index 0000000..65e66d3 --- /dev/null +++ b/src/proxy/tests/client_masking_hard_adversarial_tests.rs @@ -0,0 +1,767 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +struct Harness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, 
+ beobachten: Arc, +} + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn build_harness(secret_hex: &str, mask_port: u16) -> Harness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + Harness { + config, + stats: stats.clone(), + upstream_manager: new_upstream_manager(stats), + replay_checker: Arc::new(ReplayChecker::new(512, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let 
mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(TLS_RECORD_APPLICATION); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +async fn run_tls_success_mtproto_fail_capture( + secret_hex: &str, + secret: [u8; 16], + timestamp: u32, + trailing_records: Vec>, +) -> Vec { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let expected_len = trailing_records.iter().map(Vec::len).sum::(); + let expected_concat = trailing_records.concat(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_len]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let harness = build_harness(secret_hex, backend_addr.port()); + let client_hello = make_valid_tls_client_hello(&secret, timestamp, 600, 0x42); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let (server_side, mut client_side) = duplex(262144); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.210:56010".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + 
harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + read_tls_record_body(&mut client_side, tls_response_head).await; + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + for record in trailing_records { + client_side.write_all(&record).await.unwrap(); + } + + let got = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, expected_concat); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + got +} + +#[tokio::test] +async fn masking_budget_survives_zero_handshake_timeout_with_delay() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.timeouts.client_handshake = 0; + cfg.censorship.server_hello_delay_min_ms = 720; + cfg.censorship.server_hello_delay_max_ms = 720; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; 605]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.211:56011".parse().unwrap(), + config, + stats.clone(), + new_upstream_manager(stats.clone()), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + 
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut invalid_hello = vec![0u8; 605]; + invalid_hello[0] = 0x16; + invalid_hello[1] = 0x03; + invalid_hello[2] = 0x01; + invalid_hello[3..5].copy_from_slice(&600u16.to_be_bytes()); + invalid_hello[5..].fill(0xA1); + + let started = Instant::now(); + client_side.write_all(&invalid_hello).await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + client_side.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + assert!(result.is_ok()); + assert_eq!(stats.get_handshake_timeouts(), 0); + assert!(started.elapsed() >= Duration::from_millis(680)); +} + +#[tokio::test] +async fn tls_mtproto_fail_forwards_only_trailing_record() { + let tail = wrap_tls_application_data(b"tail-only"); + let got = run_tls_success_mtproto_fail_capture( + "c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1", + [0xC1; 16], + 1, + vec![tail.clone()], + ) + .await; + assert_eq!(got, tail); +} + +#[tokio::test] +async fn replayed_tls_hello_gets_no_serverhello_and_is_masked() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let harness = build_harness("c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2", backend_addr.port()); + let secret = [0xC2u8; 16]; + let hello = make_valid_tls_client_hello(&secret, 2, 600, 0x41); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let first_tail = wrap_tls_application_data(b"seed"); + + let expected_hello = hello.clone(); + let expected_tail = first_tail.clone(); + + let accept_task = tokio::spawn(async move { + let (mut s1, _) = listener.accept().await.unwrap(); + let mut got_tail = vec![0u8; expected_tail.len()]; + s1.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, 
expected_tail); + drop(s1); + + let (mut s2, _) = listener.accept().await.unwrap(); + let mut got_hello = vec![0u8; expected_hello.len()]; + s2.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + }); + + let run_session = |send_mtproto: bool| { + let (server_side, mut client_side) = duplex(131072); + let config = harness.config.clone(); + let stats = harness.stats.clone(); + let upstream = harness.upstream_manager.clone(); + let replay = harness.replay_checker.clone(); + let pool = harness.buffer_pool.clone(); + let rng = harness.rng.clone(); + let route = harness.route_runtime.clone(); + let ipt = harness.ip_tracker.clone(); + let beob = harness.beobachten.clone(); + let hello = hello.clone(); + let invalid_mtproto_record = invalid_mtproto_record.clone(); + let first_tail = first_tail.clone(); + + async move { + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.212:56012".parse().unwrap(), + config, + stats, + upstream, + replay, + pool, + rng, + None, + route, + None, + ipt, + beob, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + if send_mtproto { + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_tls_record_body(&mut client_side, head).await; + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&first_tail).await.unwrap(); + } else { + let mut one = [0u8; 1]; + let no_server_hello = tokio::time::timeout( + Duration::from_millis(300), + client_side.read_exact(&mut one), + ) + .await; + assert!(no_server_hello.is_err() || no_server_hello.unwrap().is_err()); + } + + client_side.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + } + }; + + run_session(true).await; + run_session(false).await; + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); +} + 
+#[tokio::test] +async fn connects_bad_increments_once_per_invalid_mtproto() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let harness = build_harness("c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3", backend_addr.port()); + let stats = harness.stats.clone(); + let bad_before = stats.get_connects_bad(); + + let tail = wrap_tls_application_data(b"accounting"); + let expected_tail = tail.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_tail); + }); + + let hello = make_valid_tls_client_hello(&[0xC3; 16], 3, 600, 0x42); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.213:56013".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + read_tls_record_body(&mut client_side, head).await; + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&tail).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + client_side.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + assert_eq!(stats.get_connects_bad(), bad_before + 1); +} + +#[tokio::test] +async fn truncated_clienthello_forwards_only_seen_prefix() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 
+ let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_unix_sock = None; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let expected_prefix_len = 5 + 17; + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_prefix_len]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.214:56014".parse().unwrap(), + config, + stats, + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut hello = vec![0u8; 5 + 17]; + hello[0] = 0x16; + hello[1] = 0x03; + hello[2] = 0x01; + hello[3..5].copy_from_slice(&600u16.to_be_bytes()); + hello[5..].fill(0x55); + + client_side.write_all(&hello).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let got = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, hello); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn out_of_bounds_tls_len_forwards_header_only() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_host = 
Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_unix_sock = None; + + let config = Arc::new(cfg); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(8192); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.215:56015".parse().unwrap(), + config, + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let hdr = [0x16, 0x03, 0x01, 0x42, 0x69]; + client_side.write_all(&hdr).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let got = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, hdr); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn non_tls_with_modes_disabled_is_masked() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_unix_sock = None; + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let 
(server_side, mut client_side) = duplex(8192); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.216:56016".parse().unwrap(), + config, + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let probe = *b"HELLO"; + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let got = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, probe); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn concurrent_tls_mtproto_fail_sessions_are_isolated() { + let sessions = 12usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = std::collections::HashSet::new(); + for idx in 0..sessions { + let payload = vec![idx as u8; 32 + idx]; + expected.insert(wrap_tls_application_data(&payload)); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected; + for _ in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + stream.read_exact(&mut record[5..]).await.unwrap(); + assert!(remaining.remove(&record)); + } + assert!(remaining.is_empty()); + }); + + let mut tasks = Vec::with_capacity(sessions); + for idx in 0..sessions { + let secret_hex = "c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4"; + 
let harness = build_harness(secret_hex, backend_addr.port()); + let hello = + make_valid_tls_client_hello(&[0xC4; 16], 20 + idx as u32, 600, 0x40 + idx as u8); + let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing = wrap_tls_application_data(&vec![idx as u8; 32 + idx]); + let peer: SocketAddr = format!("198.51.100.217:{}", 56100 + idx as u16) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + read_tls_record_body(&mut client_side, head).await; + client_side.write_all(&invalid_mtproto).await.unwrap(); + client_side.write_all(&trailing).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); +} + +macro_rules! 
tail_length_case { + ($name:ident, $hex:expr, $secret:expr, $ts:expr, $len:expr) => { + #[tokio::test] + async fn $name() { + let mut payload = vec![0u8; $len]; + for (i, b) in payload.iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(17).wrapping_add(5); + } + let record = wrap_tls_application_data(&payload); + let got = + run_tls_success_mtproto_fail_capture($hex, $secret, $ts, vec![record.clone()]) + .await; + assert_eq!(got, record); + } + }; +} + +tail_length_case!( + tail_len_1_preserved, + "d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1", + [0xD1; 16], + 30, + 1 +); +tail_length_case!( + tail_len_2_preserved, + "d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2", + [0xD2; 16], + 31, + 2 +); +tail_length_case!( + tail_len_3_preserved, + "d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", + [0xD3; 16], + 32, + 3 +); +tail_length_case!( + tail_len_7_preserved, + "d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4", + [0xD4; 16], + 33, + 7 +); +tail_length_case!( + tail_len_31_preserved, + "d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5", + [0xD5; 16], + 34, + 31 +); +tail_length_case!( + tail_len_127_preserved, + "d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6", + [0xD6; 16], + 35, + 127 +); +tail_length_case!( + tail_len_511_preserved, + "d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7", + [0xD7; 16], + 36, + 511 +); +tail_length_case!( + tail_len_1023_preserved, + "d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8", + [0xD8; 16], + 37, + 1023 +); diff --git a/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs b/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs new file mode 100644 index 0000000..f7229ce --- /dev/null +++ b/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs @@ -0,0 +1,358 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::{TcpListener, TcpStream}; + +const REPLY_404: &[u8] = b"HTTP/1.1 404 Not 
Found\r\nContent-Length: 0\r\n\r\n";
+
+fn make_test_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
+    Arc::new(UpstreamManager::new(
+        vec![UpstreamConfig {
+            upstream_type: UpstreamType::Direct {
+                interface: None,
+                bind_addresses: None,
+            },
+            weight: 1,
+            enabled: true,
+            scopes: String::new(),
+            selected_scope: String::new(),
+        }],
+        1,
+        1,
+        1,
+        1,
+        false,
+        stats,
+    ))
+}
+
+fn masking_config(mask_port: u16) -> Arc<ProxyConfig> {
+    let mut cfg = ProxyConfig::default();
+    cfg.general.beobachten = false;
+    cfg.timeouts.client_handshake = 1;
+    cfg.censorship.mask = true;
+    cfg.censorship.mask_unix_sock = None;
+    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
+    cfg.censorship.mask_port = mask_port;
+    cfg.censorship.mask_proxy_protocol = 0;
+    Arc::new(cfg)
+}
+
+async fn run_generic_probe_and_capture_prefix(payload: Vec<u8>, expected_prefix: Vec<u8>) {
+    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let backend_addr = listener.local_addr().unwrap();
+
+    let reply = REPLY_404.to_vec();
+    let prefix_len = expected_prefix.len();
+
+    let accept_task = tokio::spawn(async move {
+        let (mut stream, _) = listener.accept().await.unwrap();
+        let mut got = vec![0u8; prefix_len];
+        stream.read_exact(&mut got).await.unwrap();
+        stream.write_all(&reply).await.unwrap();
+        got
+    });
+
+    let config = masking_config(backend_addr.port());
+    let stats = Arc::new(Stats::new());
+    let upstream_manager = make_test_upstream_manager(stats.clone());
+    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
+    let buffer_pool = Arc::new(BufferPool::new());
+    let rng = Arc::new(SecureRandom::new());
+    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
+    let ip_tracker = Arc::new(UserIpTracker::new());
+    let beobachten = Arc::new(BeobachtenStore::new());
+
+    let (server_side, mut client_side) = duplex(4096);
+    let peer: SocketAddr = "203.0.113.210:55110".parse().unwrap();
+
+    let handler = tokio::spawn(handle_client_stream(
+        server_side,
+        peer,
+        config,
+        stats,
+        upstream_manager,
+        replay_checker,
+        buffer_pool,
+        rng,
+        None,
+        route_runtime,
+        None,
+        ip_tracker,
+        beobachten,
+        false,
+    ));
+
+    client_side.write_all(&payload).await.unwrap();
+    client_side.shutdown().await.unwrap();
+
+    let mut observed = vec![0u8; REPLY_404.len()];
+    tokio::time::timeout(
+        Duration::from_secs(2),
+        client_side.read_exact(&mut observed),
+    )
+    .await
+    .unwrap()
+    .unwrap();
+    assert_eq!(observed, REPLY_404);
+
+    let got = tokio::time::timeout(Duration::from_secs(2), accept_task)
+        .await
+        .unwrap()
+        .unwrap();
+    assert_eq!(got, expected_prefix);
+
+    let result = tokio::time::timeout(Duration::from_secs(2), handler)
+        .await
+        .unwrap()
+        .unwrap();
+    assert!(result.is_ok());
+}
+
+async fn read_http_probe_header(stream: &mut TcpStream) -> Vec<u8> {
+    let mut out = Vec::with_capacity(96);
+    let mut one = [0u8; 1];
+
+    loop {
+        stream.read_exact(&mut one).await.unwrap();
+        out.push(one[0]);
+        if out.ends_with(b"\r\n\r\n") {
+            break;
+        }
+        assert!(
+            out.len() <= 512,
+            "probe header exceeded sane limit while waiting for terminator"
+        );
+    }
+
+    out
+}
+
+#[tokio::test]
+async fn blackhat_fragmented_plain_http_probe_masks_and_preserves_prefix() {
+    let payload = b"GET /probe-evasion HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
+    run_generic_probe_and_capture_prefix(payload.clone(), payload).await;
+}
+
+#[tokio::test]
+async fn blackhat_invalid_tls_like_probe_masks_and_preserves_header_prefix() {
+    let payload = vec![0x16, 0x03, 0x03, 0x00, 0x64, 0x01, 0x00];
+    run_generic_probe_and_capture_prefix(payload.clone(), payload).await;
+}
+
+#[tokio::test]
+async fn integration_client_handler_plain_probe_masks_and_preserves_prefix() {
+    let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let backend_addr = mask_listener.local_addr().unwrap();
+
+    let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let front_addr = front_listener.local_addr().unwrap();
+
+    let
payload = b"GET /integration-probe HTTP/1.1\r\nHost: a.example\r\n\r\n".to_vec(); + let expected_prefix = payload.clone(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_prefix.len()]; + stream.read_exact(&mut got).await.unwrap(); + stream.write_all(REPLY_404).await.unwrap(); + got + }); + + let config = masking_config(backend_addr.port()); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&payload).await.unwrap(); + client.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + 
assert_eq!(observed, REPLY_404);
+
+    let got = tokio::time::timeout(Duration::from_secs(2), accept_task)
+        .await
+        .unwrap()
+        .unwrap();
+    assert_eq!(got, payload);
+
+    let result = tokio::time::timeout(Duration::from_secs(2), server_task)
+        .await
+        .unwrap()
+        .unwrap();
+    assert!(result.is_ok());
+}
+
+#[tokio::test]
+async fn light_fuzz_small_probe_variants_always_mask_and_preserve_declared_prefix() {
+    let mut rng = StdRng::seed_from_u64(0xA11E_5EED_F0F0_CAFE);
+
+    for i in 0..24usize {
+        let mut payload = if rng.random::<bool>() {
+            b"GET /fuzz HTTP/1.1\r\nHost: fuzz.example\r\n\r\n".to_vec()
+        } else {
+            vec![0x16, 0x03, 0x03, 0x00, 0x64]
+        };
+
+        let tail_len = rng.random_range(0..=8usize);
+        for _ in 0..tail_len {
+            payload.push(rng.random::<u8>());
+        }
+
+        let expected_prefix = payload.clone();
+        run_generic_probe_and_capture_prefix(payload, expected_prefix).await;
+
+        if i % 6 == 0 {
+            tokio::task::yield_now().await;
+        }
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() {
+    let session_count = 12usize;
+    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let backend_addr = listener.local_addr().unwrap();
+
+    let mut expected = std::collections::HashSet::new();
+    for idx in 0..session_count {
+        let probe =
+            format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes();
+        expected.insert(probe);
+    }
+
+    let accept_task = tokio::spawn(async move {
+        let mut remaining = expected;
+        for _ in 0..session_count {
+            let (mut stream, _) = listener.accept().await.unwrap();
+            let head = read_http_probe_header(&mut stream).await;
+            stream.write_all(REPLY_404).await.unwrap();
+            assert!(
+                remaining.remove(&head),
+                "backend received unexpected or duplicated probe prefix"
+            );
+        }
+        assert!(
+            remaining.is_empty(),
+            "all session prefixes must be observed exactly once"
+        );
+    });
+
+    let mut tasks = Vec::with_capacity(session_count);
+    for idx
in 0..session_count { + let config = masking_config(backend_addr.port()); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let probe = + format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); + let peer: SocketAddr = format!("203.0.113.{}:{}", 30 + idx, 56000 + idx) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout( + Duration::from_secs(2), + client_side.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..50aa44c --- /dev/null +++ b/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs @@ -0,0 +1,645 @@ +use super::*; +use 
crate::config::{UpstreamConfig, UpstreamType};
+use crate::crypto::sha256_hmac;
+use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
+use crate::protocol::tls;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
+use tokio::net::TcpListener;
+use tokio::time::{Duration, Instant};
+
+struct RedTeamHarness {
+    config: Arc<ProxyConfig>,
+    stats: Arc<Stats>,
+    upstream_manager: Arc<UpstreamManager>,
+    replay_checker: Arc<ReplayChecker>,
+    buffer_pool: Arc<BufferPool>,
+    rng: Arc<SecureRandom>,
+    route_runtime: Arc<RouteRuntimeController>,
+    ip_tracker: Arc<UserIpTracker>,
+    beobachten: Arc<BeobachtenStore>,
+}
+
+fn build_harness(secret_hex: &str, mask_port: u16) -> RedTeamHarness {
+    let mut cfg = ProxyConfig::default();
+    cfg.general.beobachten = false;
+    cfg.censorship.mask = true;
+    cfg.censorship.mask_unix_sock = None;
+    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
+    cfg.censorship.mask_port = mask_port;
+    cfg.censorship.mask_proxy_protocol = 0;
+    cfg.access.ignore_time_skew = true;
+    cfg.access
+        .users
+        .insert("user".to_string(), secret_hex.to_string());
+
+    let config = Arc::new(cfg);
+    let stats = Arc::new(Stats::new());
+    let upstream_manager = Arc::new(UpstreamManager::new(
+        vec![UpstreamConfig {
+            upstream_type: UpstreamType::Direct {
+                interface: None,
+                bind_addresses: None,
+            },
+            weight: 1,
+            enabled: true,
+            scopes: String::new(),
+            selected_scope: String::new(),
+        }],
+        1,
+        1,
+        1,
+        1,
+        false,
+        stats.clone(),
+    ));
+
+    RedTeamHarness {
+        config,
+        stats,
+        upstream_manager,
+        replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
+        buffer_pool: Arc::new(BufferPool::new()),
+        rng: Arc::new(SecureRandom::new()),
+        route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
+        ip_tracker: Arc::new(UserIpTracker::new()),
+        beobachten: Arc::new(BeobachtenStore::new()),
+    }
+}
+
+fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
+    assert!(
+        tls_len <= u16::MAX as usize,
+        "TLS length must fit into record header"
+    );
+
+    let total_len = 5 + tls_len;
+    let mut handshake = vec![fill; total_len];
+    handshake[0] = 0x16;
+    handshake[1] = 0x03;
+    handshake[2] = 0x01;
+    handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
+
+    let session_id_len: usize = 32;
+    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+
+    let computed = sha256_hmac(secret, &handshake);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
+        .copy_from_slice(&digest);
+    handshake
+}
+
+fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
+    let mut record = Vec::with_capacity(5 + payload.len());
+    record.push(0x17);
+    record.extend_from_slice(&TLS_VERSION);
+    record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
+    record.extend_from_slice(payload);
+    record
+}
+
+async fn run_tls_success_mtproto_fail_session(
+    secret_hex: &str,
+    secret: [u8; 16],
+    timestamp: u32,
+    tail: Vec<u8>,
+) -> Vec<u8> {
+    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let backend_addr = listener.local_addr().unwrap();
+
+    let harness = build_harness(secret_hex, backend_addr.port());
+    let client_hello = make_valid_tls_client_hello(&secret, timestamp, 600, 0x42);
+    let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
+    let trailing_record = wrap_tls_application_data(&tail);
+
+    let accept_task = tokio::spawn(async move {
+        let (mut stream, _) = listener.accept().await.unwrap();
+        let mut got = vec![0u8; trailing_record.len()];
+        stream.read_exact(&mut got).await.unwrap();
+        got
+    });
+
+    let (server_side, mut client_side) = duplex(262144);
+    let handler = tokio::spawn(handle_client_stream(
+        server_side,
+        "198.51.100.250:56900".parse().unwrap(),
+        harness.config,
+        harness.stats,
+        harness.upstream_manager,
+        harness.replay_checker,
+        harness.buffer_pool,
+        harness.rng,
+        None,
+ harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + let body_len = u16::from_be_bytes([head[3], head[4]]) as usize; + let mut body = vec![0u8; body_len]; + client_side.read_exact(&mut body).await.unwrap(); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side + .write_all(&wrap_tls_application_data(&tail)) + .await + .unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + forwarded +} + +#[tokio::test] +#[ignore = "red-team expected-fail: demonstrates that post-TLS fallback still forwards data to backend"] +async fn redteam_01_backend_receives_no_data_after_mtproto_fail() { + let forwarded = run_tls_success_mtproto_fail_session( + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + [0xAA; 16], + 1, + b"probe-a".to_vec(), + ) + .await; + assert!( + forwarded.is_empty(), + "backend unexpectedly received fallback bytes" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict no-fallback policy hypothesis"] +async fn redteam_02_backend_must_never_receive_tls_records_after_mtproto_fail() { + let forwarded = run_tls_success_mtproto_fail_session( + "abababababababababababababababab", + [0xAB; 16], + 2, + b"probe-b".to_vec(), + ) + .await; + assert_ne!( + forwarded[0], 0x17, + "received TLS application record despite strict policy" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: impossible timing uniformity target"] +async fn redteam_03_masking_duration_must_be_less_than_1ms_when_backend_down() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_host = 
Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.access.ignore_time_skew = true; + cfg.access.users.insert( + "user".to_string(), + "acacacacacacacacacacacacacacacac".to_string(), + ); + + let harness = RedTeamHarness { + config: Arc::new(cfg), + stats: Arc::new(Stats::new()), + upstream_manager: Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + Arc::new(Stats::new()), + )), + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + }; + + let hello = make_valid_tls_client_hello(&[0xAC; 16], 3, 600, 0x42); + let (server_side, mut client_side) = duplex(131072); + + let started = Instant::now(); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.251:56901".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + client_side.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + assert!( + started.elapsed() < Duration::from_millis(1), + "fallback path took longer than 1ms" + ); +} + +macro_rules! 
redteam_tail_must_not_forward_case { + ($name:ident, $hex:expr, $secret:expr, $ts:expr, $len:expr) => { + #[tokio::test] + #[ignore = "red-team expected-fail: strict no-forwarding hypothesis"] + async fn $name() { + let mut tail = vec![0u8; $len]; + for (i, b) in tail.iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(31).wrapping_add(7); + } + let forwarded = run_tls_success_mtproto_fail_session($hex, $secret, $ts, tail).await; + assert!( + forwarded.is_empty(), + "strict model expects zero forwarded bytes, got {}", + forwarded.len() + ); + } + }; +} + +redteam_tail_must_not_forward_case!( + redteam_04_tail_len_1_not_forwarded, + "adadadadadadadadadadadadadadadad", + [0xAD; 16], + 4, + 1 +); +redteam_tail_must_not_forward_case!( + redteam_05_tail_len_2_not_forwarded, + "aeaeaeaeaeaeaeaeaeaeaeaeaeaeaeae", + [0xAE; 16], + 5, + 2 +); +redteam_tail_must_not_forward_case!( + redteam_06_tail_len_3_not_forwarded, + "afafafafafafafafafafafafafafafaf", + [0xAF; 16], + 6, + 3 +); +redteam_tail_must_not_forward_case!( + redteam_07_tail_len_7_not_forwarded, + "b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0", + [0xB0; 16], + 7, + 7 +); +redteam_tail_must_not_forward_case!( + redteam_08_tail_len_15_not_forwarded, + "b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1", + [0xB1; 16], + 8, + 15 +); +redteam_tail_must_not_forward_case!( + redteam_09_tail_len_63_not_forwarded, + "b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2", + [0xB2; 16], + 9, + 63 +); +redteam_tail_must_not_forward_case!( + redteam_10_tail_len_127_not_forwarded, + "b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3", + [0xB3; 16], + 10, + 127 +); +redteam_tail_must_not_forward_case!( + redteam_11_tail_len_255_not_forwarded, + "b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4", + [0xB4; 16], + 11, + 255 +); +redteam_tail_must_not_forward_case!( + redteam_12_tail_len_511_not_forwarded, + "b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5", + [0xB5; 16], + 12, + 511 +); +redteam_tail_must_not_forward_case!( + redteam_13_tail_len_1023_not_forwarded, + "b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6", + [0xB6; 16], + 13, 
+ 1023 +); +redteam_tail_must_not_forward_case!( + redteam_14_tail_len_2047_not_forwarded, + "b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7", + [0xB7; 16], + 14, + 2047 +); +redteam_tail_must_not_forward_case!( + redteam_15_tail_len_4095_not_forwarded, + "b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8", + [0xB8; 16], + 15, + 4095 +); + +#[tokio::test] +#[ignore = "red-team expected-fail: impossible indistinguishability envelope"] +async fn redteam_16_timing_delta_between_paths_must_be_sub_1ms_under_concurrency() { + let runs = 20usize; + let mut durations = Vec::with_capacity(runs); + + for i in 0..runs { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let secret = [0xB9u8; 16]; + let harness = build_harness("b9b9b9b9b9b9b9b9b9b9b9b9b9b9b9b9", backend_addr.port()); + let hello = make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x42); + + let accept_task = tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.252:56902".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + let started = Instant::now(); + client_side.write_all(&hello).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + durations.push(started.elapsed()); + } + + let min = durations.iter().copied().min().unwrap(); + let max = durations.iter().copied().max().unwrap(); + assert!( + max - min <= Duration::from_millis(1), + "timing spread too wide for strict anti-probing envelope" + ); +} 

// Drives one malformed TLS probe (valid record header, garbage body) against a
// freshly built handler with a fixed server-hello delay and returns the total
// wall-clock milliseconds until the handler task finishes. mask_port = 1 makes
// the fallback backend unreachable on purpose.
async fn measure_invalid_probe_duration_ms(delay_ms: u64, tls_len: u16, body_sent: usize) -> u128 {
    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = 1;
    cfg.timeouts.client_handshake = 1;
    // Pin min == max so the injected delay is deterministic for the caller.
    cfg.censorship.server_hello_delay_min_ms = delay_ms;
    cfg.censorship.server_hello_delay_max_ms = delay_ms;

    let (server_side, mut client_side) = duplex(65536);
    let handler = tokio::spawn(handle_client_stream(
        server_side,
        "198.51.100.253:56903".parse().unwrap(),
        Arc::new(cfg),
        Arc::new(Stats::new()),
        Arc::new(UpstreamManager::new(
            vec![UpstreamConfig {
                upstream_type: UpstreamType::Direct {
                    interface: None,
                    bind_addresses: None,
                },
                weight: 1,
                enabled: true,
                scopes: String::new(),
                selected_scope: String::new(),
            }],
            1,
            1,
            1,
            1,
            false,
            Arc::new(Stats::new()),
        )),
        Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
        Arc::new(BufferPool::new()),
        Arc::new(SecureRandom::new()),
        None,
        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
        None,
        Arc::new(UserIpTracker::new()),
        Arc::new(BeobachtenStore::new()),
        false,
    ));

    // TLS record header (handshake, v3.1) followed by a 0xD7 filler body; the
    // declared tls_len may disagree with body_sent to probe partial records.
    let mut probe = vec![0u8; 5 + body_sent];
    probe[0] = 0x16;
    probe[1] = 0x03;
    probe[2] = 0x01;
    probe[3..5].copy_from_slice(&tls_len.to_be_bytes());
    probe[5..].fill(0xD7);

    let started = Instant::now();
    client_side.write_all(&probe).await.unwrap();
    client_side.shutdown().await.unwrap();

    let _ = tokio::time::timeout(Duration::from_secs(4), handler)
        .await
        .unwrap()
        .unwrap();

    started.elapsed().as_millis()
}

// Sends one malformed probe with masking pointed at a live local backend and
// returns how many bytes the backend actually received (the "forwarded shape").
async fn capture_forwarded_probe_len(tls_len: u16, body_sent: usize) -> usize {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.timeouts.client_handshake = 1;

    // Backend side: drain everything forwarded (bounded by a 2s timeout) and
    // report only the byte count.
    let accept_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut got = Vec::new();
        let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
        got.len()
    });

    let (server_side, mut client_side) = duplex(65536);
    let handler = tokio::spawn(handle_client_stream(
        server_side,
        "198.51.100.254:56904".parse().unwrap(),
        Arc::new(cfg),
        Arc::new(Stats::new()),
        Arc::new(UpstreamManager::new(
            vec![UpstreamConfig {
                upstream_type: UpstreamType::Direct {
                    interface: None,
                    bind_addresses: None,
                },
                weight: 1,
                enabled: true,
                scopes: String::new(),
                selected_scope: String::new(),
            }],
            1,
            1,
            1,
            1,
            false,
            Arc::new(Stats::new()),
        )),
        Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
        Arc::new(BufferPool::new()),
        Arc::new(SecureRandom::new()),
        None,
        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
        None,
        Arc::new(UserIpTracker::new()),
        Arc::new(BeobachtenStore::new()),
        false,
    ));

    let mut probe = vec![0u8; 5 + body_sent];
    probe[0] = 0x16;
    probe[1] = 0x03;
    probe[2] = 0x01;
    probe[3..5].copy_from_slice(&tls_len.to_be_bytes());
    probe[5..].fill(0xBC);

    client_side.write_all(&probe).await.unwrap();
    client_side.shutdown().await.unwrap();

    let _ = tokio::time::timeout(Duration::from_secs(4), handler)
        .await
        .unwrap()
        .unwrap();

    tokio::time::timeout(Duration::from_secs(4), accept_task)
        .await
        .unwrap()
        .unwrap()
}

// Expected-fail case generator: asserts the reject path completes within an
// unrealistically tight `$max_ms` envelope despite a `$delay_ms` hello delay.
macro_rules! redteam_timing_envelope_case {
    ($name:ident, $delay_ms:expr, $tls_len:expr, $body_sent:expr, $max_ms:expr) => {
        #[tokio::test]
        #[ignore = "red-team expected-fail: unrealistically tight reject timing envelope"]
        async fn $name() {
            let elapsed_ms =
                measure_invalid_probe_duration_ms($delay_ms, $tls_len, $body_sent).await;
            assert!(
                elapsed_ms <= $max_ms,
                "timing envelope violated: elapsed={}ms, max={}ms",
                elapsed_ms,
                $max_ms
            );
        }
    };
}

// Expected-fail case generator: asserts every probe is forwarded to the
// backend as exactly `$expected_len` bytes (strict constant-shape hypothesis).
macro_rules! redteam_constant_shape_case {
    ($name:ident, $tls_len:expr, $body_sent:expr, $expected_len:expr) => {
        #[tokio::test]
        #[ignore = "red-team expected-fail: strict constant-shape backend fingerprint hypothesis"]
        async fn $name() {
            let got = capture_forwarded_probe_len($tls_len, $body_sent).await;
            assert_eq!(
                got, $expected_len,
                "fingerprint shape mismatch: got={} expected={} (strict constant-shape model)",
                got, $expected_len
            );
        }
    };
}

// Body sizes sweep powers-of-two-minus-one from 0 to 4095 under a 700ms delay.
redteam_timing_envelope_case!(redteam_17_timing_env_very_tight_00, 700, 600, 0, 3);
redteam_timing_envelope_case!(redteam_18_timing_env_very_tight_01, 700, 600, 1, 3);
redteam_timing_envelope_case!(redteam_19_timing_env_very_tight_02, 700, 600, 7, 3);
redteam_timing_envelope_case!(redteam_20_timing_env_very_tight_03, 700, 600, 17, 3);
redteam_timing_envelope_case!(redteam_21_timing_env_very_tight_04, 700, 600, 31, 3);
redteam_timing_envelope_case!(redteam_22_timing_env_very_tight_05, 700, 600, 63, 3);
redteam_timing_envelope_case!(redteam_23_timing_env_very_tight_06, 700, 600, 127, 3);
redteam_timing_envelope_case!(redteam_24_timing_env_very_tight_07, 700, 600, 255, 3);
redteam_timing_envelope_case!(redteam_25_timing_env_very_tight_08, 700, 600, 511, 3);
redteam_timing_envelope_case!(redteam_26_timing_env_very_tight_09, 700, 600, 1023, 3);
redteam_timing_envelope_case!(redteam_27_timing_env_very_tight_10, 700, 600, 2047, 3);
redteam_timing_envelope_case!(redteam_28_timing_env_very_tight_11, 700, 600, 4095, 3);

+redteam_constant_shape_case!(redteam_29_constant_shape_00, 600, 0, 517); +redteam_constant_shape_case!(redteam_30_constant_shape_01, 600, 1, 517); +redteam_constant_shape_case!(redteam_31_constant_shape_02, 600, 7, 517); +redteam_constant_shape_case!(redteam_32_constant_shape_03, 600, 17, 517); +redteam_constant_shape_case!(redteam_33_constant_shape_04, 600, 31, 517); +redteam_constant_shape_case!(redteam_34_constant_shape_05, 600, 63, 517); +redteam_constant_shape_case!(redteam_35_constant_shape_06, 600, 127, 517); +redteam_constant_shape_case!(redteam_36_constant_shape_07, 600, 255, 517); +redteam_constant_shape_case!(redteam_37_constant_shape_08, 600, 511, 517); +redteam_constant_shape_case!(redteam_38_constant_shape_09, 600, 1023, 517); +redteam_constant_shape_case!(redteam_39_constant_shape_10, 600, 2047, 517); +redteam_constant_shape_case!(redteam_40_constant_shape_11, 600, 4095, 517); diff --git a/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..3a01a69 --- /dev/null +++ b/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs @@ -0,0 +1,246 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_probe_capture( + body_sent: usize, + tls_len: u16, + enable_shape_hardening: bool, + floor: usize, + cap: usize, +) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + 
let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_shape_hardening = enable_shape_hardening; + cfg.censorship.mask_shape_bucket_floor_bytes = floor; + cfg.censorship.mask_shape_bucket_cap_bytes = cap; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.214:57014".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&tls_len.to_be_bytes()); + probe[5..].fill(0x66); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +fn pearson_corr(xs: &[f64], ys: &[f64]) -> f64 { + if xs.len() != ys.len() || xs.is_empty() { + return 0.0; + } + + let n = xs.len() as f64; + let mean_x = xs.iter().sum::() / n; + let mean_y = ys.iter().sum::() / n; + + let mut cov = 0.0; + let mut 
var_x = 0.0; + let mut var_y = 0.0; + + for (&x, &y) in xs.iter().zip(ys.iter()) { + let dx = x - mean_x; + let dy = y - mean_y; + cov += dx * dy; + var_x += dx * dx; + var_y += dy * dy; + } + + if var_x == 0.0 || var_y == 0.0 { + return 0.0; + } + + cov / (var_x.sqrt() * var_y.sqrt()) +} + +fn lcg_sizes(count: usize, floor: usize, cap: usize) -> Vec { + let mut x = 0x9E3779B97F4A7C15u64; + let span = cap.saturating_mul(3); + let mut out = Vec::with_capacity(count + 8); + + for _ in 0..count { + x = x + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + let v = (x as usize) % span.max(1); + out.push(v); + } + + // Inject edge and boundary-heavy probes. + out.extend_from_slice(&[ + 0, + floor.saturating_sub(1), + floor, + floor.saturating_add(1), + cap.saturating_sub(1), + cap, + cap.saturating_add(1), + cap.saturating_mul(2), + ]); + out +} + +async fn collect_distribution( + sizes: &[usize], + hardening: bool, + floor: usize, + cap: usize, +) -> Vec { + let mut out = Vec::with_capacity(sizes.len()); + for &body in sizes { + out.push(run_probe_capture(body, 1200, hardening, floor, cap).await); + } + out +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict decorrelation target for hardened output lengths"] +async fn redteam_fuzz_01_hardened_output_length_correlation_should_be_below_0_2() { + let floor = 512usize; + let cap = 4096usize; + let sizes = lcg_sizes(24, floor, cap); + + let hardened = collect_distribution(&sizes, true, floor, cap).await; + let x: Vec = sizes.iter().map(|v| *v as f64).collect(); + let y_hard: Vec = hardened.iter().map(|v| *v as f64).collect(); + + let corr_hard = pearson_corr(&x, &y_hard).abs(); + println!( + "redteam_fuzz corr_hardened={corr_hard:.4} samples={}", + sizes.len() + ); + + assert!( + corr_hard < 0.2, + "strict model expects near-zero size correlation; observed corr={corr_hard:.4}" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict class-collapse ratio target"] +async fn 
redteam_fuzz_02_hardened_unique_output_ratio_should_be_below_5pct() { + let floor = 512usize; + let cap = 4096usize; + let sizes = lcg_sizes(24, floor, cap); + + let hardened = collect_distribution(&sizes, true, floor, cap).await; + + let in_unique = { + let mut s = std::collections::BTreeSet::new(); + for v in &sizes { + s.insert(*v); + } + s.len() + }; + + let out_unique = { + let mut s = std::collections::BTreeSet::new(); + for v in &hardened { + s.insert(*v); + } + s.len() + }; + + let ratio = out_unique as f64 / in_unique as f64; + println!( + "redteam_fuzz unique_ratio_hardened={ratio:.4} out_unique={} in_unique={}", + out_unique, in_unique + ); + + assert!( + ratio <= 0.05, + "strict model expects near-total collapse; observed ratio={ratio:.4}" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict separability improvement target"] +async fn redteam_fuzz_03_hardened_signal_must_be_10x_lower_than_plain() { + let floor = 512usize; + let cap = 4096usize; + let sizes = lcg_sizes(24, floor, cap); + + let plain = collect_distribution(&sizes, false, floor, cap).await; + let hardened = collect_distribution(&sizes, true, floor, cap).await; + + let x: Vec = sizes.iter().map(|v| *v as f64).collect(); + let y_plain: Vec = plain.iter().map(|v| *v as f64).collect(); + let y_hard: Vec = hardened.iter().map(|v| *v as f64).collect(); + + let corr_plain = pearson_corr(&x, &y_plain).abs(); + let corr_hard = pearson_corr(&x, &y_hard).abs(); + + println!("redteam_fuzz corr_plain={corr_plain:.4} corr_hardened={corr_hard:.4}"); + + assert!( + corr_hard <= corr_plain * 0.1, + "strict model expects 10x suppression; plain={corr_plain:.4} hardened={corr_hard:.4}" + ); +} diff --git a/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs new file mode 100644 index 0000000..48e94a5 --- /dev/null +++ b/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs @@ -0,0 +1,179 @@ 
+use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn expected_bucket(total: usize, floor: usize, cap: usize) -> usize { + if total == 0 || floor == 0 || cap < floor { + return total; + } + + if total >= cap { + return total; + } + + let mut bucket = floor; + while bucket < total { + match bucket.checked_mul(2) { + Some(next) => bucket = next, + None => return total, + } + if bucket > cap { + return cap; + } + } + bucket +} + +async fn run_probe_capture( + body_sent: usize, + tls_len: u16, + enable_shape_hardening: bool, + floor: usize, + cap: usize, +) -> Vec { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_shape_hardening = enable_shape_hardening; + cfg.censorship.mask_shape_bucket_floor_bytes = floor; + cfg.censorship.mask_shape_bucket_cap_bytes = cap; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + 
"198.51.100.199:56999".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&tls_len.to_be_bytes()); + probe[5..].fill(0x66); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn shape_hardening_non_power_of_two_cap_collapses_probe_classes() { + let floor = 1000usize; + let cap = 1500usize; + + let low = run_probe_capture(1195, 700, true, floor, cap).await; + let high = run_probe_capture(1494, 700, true, floor, cap).await; + + assert_eq!(low.len(), 1500); + assert_eq!(high.len(), 1500); +} + +#[tokio::test] +async fn shape_hardening_disabled_keeps_non_power_of_two_cap_lengths_distinct() { + let floor = 1000usize; + let cap = 1500usize; + + let low = run_probe_capture(1195, 700, false, floor, cap).await; + let high = run_probe_capture(1494, 700, false, floor, cap).await; + + assert_eq!(low.len(), 1200); + assert_eq!(high.len(), 1499); +} + +#[tokio::test] +async fn shape_hardening_parallel_stress_collapses_sub_cap_probes() { + let floor = 1000usize; + let cap = 1500usize; + let mut tasks = Vec::new(); + + for idx in 0..24usize { + let body = 1001 + (idx * 19 % 480); + tasks.push(tokio::spawn(async move { + run_probe_capture(body, 1200, true, floor, cap).await.len() + })); + } + + for task in tasks { + 
let observed = task.await.unwrap(); + assert_eq!(observed, 1500); + } +} + +#[tokio::test] +async fn shape_hardening_light_fuzz_matches_bucket_oracle() { + let floor = 512usize; + let cap = 4096usize; + + for step in 1usize..=36usize { + let total = 1 + (((step * 313) ^ (step << 7)) % (cap + 300)); + let body = total.saturating_sub(5); + + let got = run_probe_capture(body, 650, true, floor, cap).await; + let expected = expected_bucket(total, floor, cap); + assert_eq!( + got.len(), + expected, + "step={step} total={total} expected={expected} got={} ", + got.len() + ); + } +} diff --git a/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..f91e687 --- /dev/null +++ b/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs @@ -0,0 +1,238 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_probe_capture( + body_sent: usize, + tls_len: u16, + enable_shape_hardening: bool, + floor: usize, + cap: usize, +) -> Vec { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + 
cfg.censorship.mask_shape_hardening = enable_shape_hardening; + cfg.censorship.mask_shape_bucket_floor_bytes = floor; + cfg.censorship.mask_shape_bucket_cap_bytes = cap; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.211:57011".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&tls_len.to_be_bytes()); + probe[5..].fill(0x66); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +async fn measure_reject_ms(body_sent: usize) -> u128 { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.censorship.server_hello_delay_min_ms = 700; + cfg.censorship.server_hello_delay_max_ms = 700; + + let (server_side, mut client_side) = duplex(65536); + let started = Instant::now(); + + let handler = tokio::spawn(handle_client_stream( 
+ server_side, + "198.51.100.212:57012".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&600u16.to_be_bytes()); + probe[5..].fill(0x44); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + + started.elapsed().as_millis() +} + +#[tokio::test] +#[ignore = "red-team expected-fail: above-cap exact length still leaks classifier signal"] +async fn redteam_shape_01_above_cap_flows_should_collapse_to_single_class() { + let floor = 512usize; + let cap = 4096usize; + + let a = run_probe_capture(5000, 7000, true, floor, cap).await; + let b = run_probe_capture(6000, 7000, true, floor, cap).await; + + assert_eq!( + a.len(), + b.len(), + "strict anti-classifier model expects same backend length class above cap" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: current padding bytes are deterministic zeros"] +async fn redteam_shape_02_padding_tail_must_be_non_deterministic() { + let floor = 512usize; + let cap = 4096usize; + let got = run_probe_capture(17, 600, true, floor, cap).await; + + assert!(got.len() > 22, "test requires padding tail to exist"); + + let tail = &got[22..]; + assert!( + tail.iter().any(|b| *b != 0), + "padding tail is fully zeroed and thus deterministic" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: exact-floor probes still expose boundary class"] +async fn 
redteam_shape_03_exact_floor_input_should_not_be_fixed_point() { + let floor = 512usize; + let cap = 4096usize; + let got = run_probe_capture(507, 600, true, floor, cap).await; + + assert!( + got.len() > floor, + "strict model expects extra blur even when input lands exactly on floor" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict one-bucket collapse hypothesis"] +async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() { + let floor = 512usize; + let cap = 4096usize; + let classes = [ + 17usize, 63usize, 255usize, 511usize, 1023usize, 2047usize, 3071usize, + ]; + + let mut observed = Vec::new(); + for body in classes { + observed.push(run_probe_capture(body, 1200, true, floor, cap).await.len()); + } + + let first = observed[0]; + for v in observed { + assert_eq!( + v, first, + "strict model expects one collapsed class across all sub-cap probes" + ); + } +} + +#[tokio::test] +#[ignore = "red-team expected-fail: over-strict micro-timing invariance"] +async fn redteam_shape_05_reject_timing_spread_should_be_under_2ms() { + let classes = [17usize, 511usize, 1023usize, 2047usize, 4095usize]; + let mut values = Vec::new(); + + for class in classes { + values.push(measure_reject_ms(class).await); + } + + let min = *values.iter().min().unwrap(); + let max = *values.iter().max().unwrap(); + assert!( + min == 700 && max == 700, + "strict model requires exact 700ms for every malformed class: min={min}ms max={max}ms" + ); +} + +#[test] +#[ignore = "red-team expected-fail: secure-by-default hypothesis"] +fn redteam_shape_06_shape_hardening_should_be_secure_by_default() { + let cfg = ProxyConfig::default(); + assert!( + cfg.censorship.mask_shape_hardening, + "strict model expects shape hardening enabled by default" + ); +} diff --git a/src/proxy/tests/client_masking_shape_hardening_security_tests.rs b/src/proxy/tests/client_masking_shape_hardening_security_tests.rs new file mode 100644 index 0000000..f2bec42 --- /dev/null +++ 
b/src/proxy/tests/client_masking_shape_hardening_security_tests.rs @@ -0,0 +1,122 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_probe_capture( + body_sent: usize, + tls_len: u16, + enable_shape_hardening: bool, + floor: usize, + cap: usize, +) -> Vec { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_shape_hardening = enable_shape_hardening; + cfg.censorship.mask_shape_bucket_floor_bytes = floor; + cfg.censorship.mask_shape_bucket_cap_bytes = cap; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.188:56888".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), 
+ None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&tls_len.to_be_bytes()); + probe[5..].fill(0x66); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn shape_hardening_disabled_keeps_original_probe_length() { + let got = run_probe_capture(17, 600, false, 512, 4096).await; + assert_eq!(got.len(), 22); + assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x02, 0x58]); +} + +#[tokio::test] +async fn shape_hardening_enabled_pads_small_probe_to_floor_bucket() { + let got = run_probe_capture(17, 600, true, 512, 4096).await; + assert_eq!(got.len(), 512); + assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x02, 0x58]); +} + +#[tokio::test] +async fn shape_hardening_enabled_pads_mid_probe_to_next_bucket() { + let got = run_probe_capture(511, 600, true, 512, 4096).await; + assert_eq!(got.len(), 1024); + assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x02, 0x58]); +} + +#[tokio::test] +async fn shape_hardening_respects_cap_and_avoids_padding_above_cap() { + let got = run_probe_capture(5000, 7000, true, 512, 4096).await; + assert_eq!(got.len(), 5005); + assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x1b, 0x58]); +} diff --git a/src/proxy/tests/client_masking_stress_adversarial_tests.rs b/src/proxy/tests/client_masking_stress_adversarial_tests.rs new file mode 100644 index 0000000..5c00c63 --- /dev/null +++ b/src/proxy/tests/client_masking_stress_adversarial_tests.rs @@ -0,0 +1,256 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, 
TLS_RECORD_APPLICATION, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +struct StressHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn build_harness(mask_port: u16, secret_hex: &str) -> StressHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + StressHarness { + config, + stats: stats.clone(), + upstream_manager: new_upstream_manager(stats), + replay_checker: Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + 
handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(TLS_RECORD_APPLICATION); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +async fn run_parallel_tail_fallback_case( + sessions: usize, + payload_len: usize, + write_chunk: usize, + ts_base: u32, + peer_port_base: u16, +) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = std::collections::HashSet::new(); + for idx in 0..sessions { + let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3]; + expected.insert(wrap_tls_application_data(&payload)); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected; + for _ in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + let len = u16::from_be_bytes([header[3], 
header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + stream.read_exact(&mut record[5..]).await.unwrap(); + assert!(remaining.remove(&record)); + } + assert!(remaining.is_empty()); + }); + + let mut tasks = Vec::with_capacity(sessions); + + for idx in 0..sessions { + let harness = build_harness(backend_addr.port(), "e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0"); + let hello = + make_valid_tls_client_hello(&[0xE0; 16], ts_base + idx as u32, 600, 0x40 + (idx as u8)); + + let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3]; + let trailing = wrap_tls_application_data(&payload); + // Keep source IPs unique across stress cases so global pre-auth probe state + // cannot contaminate unrelated sessions and make this test nondeterministic. + let peer_ip_third = 100 + ((ts_base as u8) / 10); + let peer_ip_fourth = (idx as u8).saturating_add(1); + let peer: SocketAddr = format!( + "198.51.{}.{}:{}", + peer_ip_third, + peer_ip_fourth, + peer_port_base + idx as u16 + ) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(262144); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut server_hello_head = [0u8; 5]; + client_side + .read_exact(&mut server_hello_head) + .await + .unwrap(); + assert_eq!(server_hello_head[0], 0x16); + read_tls_record_body(&mut client_side, server_hello_head).await; + + client_side.write_all(&invalid_mtproto).await.unwrap(); + for chunk in trailing.chunks(write_chunk.max(1)) { + client_side.write_all(chunk).await.unwrap(); + } + 
client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(8), accept_task) + .await + .unwrap() + .unwrap(); +} + +macro_rules! stress_case { + ($name:ident, $sessions:expr, $payload_len:expr, $chunk:expr, $ts:expr, $port:expr) => { + #[tokio::test] + async fn $name() { + run_parallel_tail_fallback_case($sessions, $payload_len, $chunk, $ts, $port).await; + } + }; +} + +stress_case!(stress_masking_parallel_s01, 4, 16, 1, 1000, 57000); +stress_case!(stress_masking_parallel_s02, 5, 24, 2, 1010, 57010); +stress_case!(stress_masking_parallel_s03, 6, 32, 3, 1020, 57020); +stress_case!(stress_masking_parallel_s04, 7, 40, 4, 1030, 57030); +stress_case!(stress_masking_parallel_s05, 8, 48, 5, 1040, 57040); +stress_case!(stress_masking_parallel_s06, 9, 56, 6, 1050, 57050); +stress_case!(stress_masking_parallel_s07, 10, 64, 7, 1060, 57060); +stress_case!(stress_masking_parallel_s08, 11, 72, 8, 1070, 57070); +stress_case!(stress_masking_parallel_s09, 12, 80, 9, 1080, 57080); +stress_case!(stress_masking_parallel_s10, 13, 88, 10, 1090, 57090); +stress_case!(stress_masking_parallel_s11, 6, 128, 11, 1100, 57100); +stress_case!(stress_masking_parallel_s12, 7, 160, 12, 1110, 57110); +stress_case!(stress_masking_parallel_s13, 8, 192, 13, 1120, 57120); +stress_case!(stress_masking_parallel_s14, 9, 224, 14, 1130, 57130); +stress_case!(stress_masking_parallel_s15, 10, 256, 15, 1140, 57140); +stress_case!(stress_masking_parallel_s16, 11, 288, 16, 1150, 57150); +stress_case!(stress_masking_parallel_s17, 12, 320, 17, 1160, 57160); +stress_case!(stress_masking_parallel_s18, 13, 352, 18, 1170, 57170); +stress_case!(stress_masking_parallel_s19, 14, 384, 19, 1180, 57180); +stress_case!(stress_masking_parallel_s20, 15, 416, 20, 1190, 57190); +stress_case!(stress_masking_parallel_s21, 16, 448, 21, 1200, 57200); 
+stress_case!(stress_masking_parallel_s22, 17, 480, 22, 1210, 57210); diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs new file mode 100644 index 0000000..6338e23 --- /dev/null +++ b/src/proxy/tests/client_security_tests.rs @@ -0,0 +1,4443 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::AesCtr; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::ProtoTag; +use crate::protocol::tls; +use crate::proxy::handshake::HandshakeSuccess; +use crate::stream::{CryptoReader, CryptoWriter}; +use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use std::net::Ipv4Addr; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::{TcpListener, TcpStream}; + +#[test] +fn synthetic_local_addr_uses_configured_port_for_zero() { + let addr = synthetic_local_addr(0); + assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0])); + assert_eq!(addr.port(), 0); +} + +#[test] +fn synthetic_local_addr_uses_configured_port_for_max() { + let addr = synthetic_local_addr(u16::MAX); + assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0])); + assert_eq!(addr.port(), u16::MAX); +} + +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() { + let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new()); + let stats = Arc::new(crate::stats::Stats::new()); + let user = "sync-drop-user".to_string(); + let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap(); + + ip_tracker.set_user_limit(&user, 1).await; + ip_tracker.check_and_add(&user, ip).await.unwrap(); + 
stats.increment_user_curr_connects(&user); + + assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1); + assert_eq!(stats.get_user_curr_connects(&user), 1); + + let reservation = + UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip); + + // Drop the reservation synchronously without any tokio::spawn/await yielding! + drop(reservation); + + // The IP is now inside the cleanup_queue, check that the queue has length 1 + let queue_len = ip_tracker.cleanup_queue_len_for_tests(); + assert_eq!( + queue_len, 1, + "Reservation drop must push directly to synchronized IP queue" + ); + + assert_eq!( + stats.get_user_curr_connects(&user), + 0, + "Stats must decrement immediately" + ); + + ip_tracker.drain_cleanup_queue().await; + assert_eq!(ip_tracker.get_active_ip_count(&user).await, 0); +} + +#[tokio::test] +async fn relay_task_abort_releases_user_gate_and_ip_reservation() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let user = "abort-user"; + let peer_addr: SocketAddr = "198.51.100.230:50000".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let 
rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime, + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("relay must reserve user slot and IP before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted relay task must return join error"); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "task abort must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "task abort must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn relay_cutover_releases_user_gate_and_ip_reservation() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = 
tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let user = "cutover-user"; + let peer_addr: SocketAddr = "198.51.100.231:50001".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime.clone(), + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + 
tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("relay must reserve user slot and IP before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Middle).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task) + .await + .expect("relay must terminate after cutover") + .expect("relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover must terminate direct relay session" + ); + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "cutover exit must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "cutover exit must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn integration_route_cutover_and_quota_overlap_fails_closed_and_releases_state() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (mut stream, _) = tg_listener.accept().await.unwrap(); + stream.write_all(&[0x41, 0x42]).await.unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; + }); + + let user = "cutover-quota-overlap-user"; + let peer_addr: SocketAddr = "198.51.100.240:50010".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.access.user_data_quota.insert(user.to_string(), 1); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: 
String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime.clone(), + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + let observed_progress = tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) >= 1 + || ip_tracker.get_active_ip_count(user).await >= 1 + || relay_task.is_finished() + { + return true; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap_or(false); + assert!( + observed_progress, + "overlap race test precondition must observe activation or bounded early termination" + ); + + tokio::time::sleep(Duration::from_millis(5)).await; + let _ = route_runtime.set_mode(RelayRouteMode::Middle); + + let relay_result = tokio::time::timeout(Duration::from_secs(3), relay_task) + .await + .expect("overlap race relay must terminate") + .expect("overlap race relay task must not panic"); + + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
})) + || matches!(relay_result, Err(ProxyError::Proxy(ref msg)) if msg == crate::proxy::route_mode::ROUTE_SWITCH_ERROR_MSG), + "overlap race must fail closed via quota enforcement or generic cutover termination" + ); + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "overlap race exit must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "overlap race exit must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn stress_drop_without_release_converges_to_zero_user_and_ip_state() { + let user = "gap-t05-drop-stress-user"; + let mut config = crate::config::ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4096); + + let stats = std::sync::Arc::new(crate::stats::Stats::new()); + let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); + + let mut reservations = Vec::new(); + for idx in 0..512u16 { + let peer = std::net::SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new( + 198, + 51, + (idx >> 8) as u8, + (idx & 0xff) as u8, + )), + 30_000 + idx, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed in stress precondition"); + reservations.push(reservation); + } + + assert_eq!(stats.get_user_curr_connects(user), 512); + + for reservation in reservations { + std::thread::spawn(move || drop(reservation)) + .join() + .expect("drop thread must not panic"); + } + + tokio::time::timeout(std::time::Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(5)).await; + } + }) + .await + .expect("drop-only path must 
eventually release all user/IP reservations"); +} + +#[tokio::test] +async fn proxy_protocol_header_is_rejected_when_trust_list_is_empty() { + let mut cfg = crate::config::ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.server.proxy_protocol_trusted_cidrs.clear(); + + let config = std::sync::Arc::new(cfg); + let stats = std::sync::Arc::new(crate::stats::Stats::new()); + let upstream_manager = std::sync::Arc::new(crate::transport::UpstreamManager::new( + vec![crate::config::UpstreamConfig { + upstream_type: crate::config::UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new( + 128, + std::time::Duration::from_secs(60), + )); + let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); + let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); + let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new( + crate::proxy::route_mode::RelayRouteMode::Direct, + )); + let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); + let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(2048); + let peer: std::net::SocketAddr = "198.51.100.80:55000".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.9:32000".parse().unwrap(), + "192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = 
tokio::time::timeout(std::time::Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); +} + +#[tokio::test] +async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load() { + let mut cfg = crate::config::ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.server.proxy_protocol_trusted_cidrs = vec!["10.0.0.0/8".parse().unwrap()]; + + let config = std::sync::Arc::new(cfg); + + for idx in 0..32u16 { + let stats = std::sync::Arc::new(crate::stats::Stats::new()); + let upstream_manager = std::sync::Arc::new(crate::transport::UpstreamManager::new( + vec![crate::config::UpstreamConfig { + upstream_type: crate::config::UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new( + 64, + std::time::Duration::from_secs(60), + )); + let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); + let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); + let route_runtime = + std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new( + crate::proxy::route_mode::RelayRouteMode::Direct, + )); + let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); + let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(1024); + let peer = std::net::SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (idx + 1) as u8)), + 55_000 + idx, + ); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config.clone(), + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let 
proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.10:32000".parse().unwrap(), + "192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = tokio::time::timeout(std::time::Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!( + matches!(result, Err(ProxyError::InvalidProxyProtocol)), + "burst idx {idx}: untrusted source must be rejected" + ); + } +} + +#[tokio::test] +async fn reservation_limit_failure_does_not_leak_curr_connects_counter() { + let user = "leak-check-user"; + let mut config = crate::config::ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Arc::new(crate::stats::Stats::new()); + let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; + + let first_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 1)), 50001); + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + first_peer, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + let second_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 2)), 50002); + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + second_peer, + ip_tracker.clone(), + ) + .await; + + assert!( + matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user), + "second reservation must be rejected at the configured tcp-conns limit" + ); + assert_eq!( + stats.get_user_curr_connects(user), + 1, + "failed acquisition must not leak a counter increment" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "failed acquisition must not mutate IP 
tracker state" + ); + + first.release().await; + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn short_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10]; + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); 
+ let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn tls12_record_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x03, 0x00, 0x10]; + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, 
+ stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.78:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn handle_client_stream_increments_connects_all_exactly_once() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10]; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let before = 
stats.get_connects_all(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.177:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + drop(client_side); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + assert_eq!( + stats.get_connects_all(), + before + 1, + "handle_client_stream must increment connects_all exactly once" + ); +} + +#[tokio::test] +async fn running_client_handler_increments_connects_all_exactly_once() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe = [0x16, 0x03, 0x01, 0x00, 0x10]; + + let mask_accept_task = tokio::spawn(async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 
5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let before = stats.get_connects_all(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = 
TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + assert_eq!( + stats.get_connects_all(), + before + 1, + "ClientHandler::run must increment connects_all exactly once" + ); +} + +#[tokio::test] +async fn partial_tls_header_stall_triggers_handshake_timeout() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "198.51.100.170:55201".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]) + .await + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout))); +} + +fn 
make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec<u8> { + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); + + let total_len = 5 + tls_len; + let mut handshake = vec![0x42u8; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec<u8> { + make_valid_tls_client_hello_with_len(secret, timestamp, 600) +} + +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec<u8> { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + 
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(0x16); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + record +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&[0x03, 0x03]); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +#[tokio::test] +async fn valid_tls_path_does_not_fall_back_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x11u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access.users.insert( + "user".to_string(), + "11111111111111111111111111111111".to_string(), + ); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = 
Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.80:55002".parse().unwrap(); + let stats_for_assert = stats.clone(); + let bad_before = stats_for_assert.get_connects_bad(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut record_header = [0u8; 5]; + client_side.read_exact(&mut record_header).await.unwrap(); + assert_eq!(record_header[0], 0x16); + + drop(client_side); + let handler_result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(handler_result.is_err()); + + let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), listener.accept()).await; + assert!( + no_mask_connect.is_err(), + "Mask backend must not be contacted on authenticated TLS path" + ); + + let bad_after = stats_for_assert.get_connects_bad(); + assert_eq!( + bad_before, bad_after, + "Authenticated TLS path must not increment connects_bad" + ); +} + +#[tokio::test] +async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + 
let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x33u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; + let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_tls_payload = b"still-tls-after-fallback".to_vec(); + let trailing_tls_record = wrap_tls_application_data(&trailing_tls_payload); + + let expected_trailing_tls_record = trailing_tls_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut trailing = vec![0u8; expected_trailing_tls_record.len()]; + stream.read_exact(&mut trailing).await.unwrap(); + assert_eq!(trailing, expected_trailing_tls_record); + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access.users.insert( + "user".to_string(), + "33333333333333333333333333333333".to_string(), + ); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut 
client_side) = duplex(32768); + let peer: SocketAddr = "198.51.100.90:55111".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&tls_app_record).await.unwrap(); + client_side.write_all(&trailing_tls_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let secret = [0x44u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; + let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_tls_payload = b"second-tls-record".to_vec(); + let trailing_tls_record = wrap_tls_application_data(&trailing_tls_payload); + + let expected_trailing_tls_record = trailing_tls_record.clone(); + let mask_accept_task = tokio::spawn(async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut trailing = vec![0u8; expected_trailing_tls_record.len()]; + stream.read_exact(&mut trailing).await.unwrap(); + assert_eq!(trailing, expected_trailing_tls_record); + }); + + let mut cfg = 
ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access.users.insert( + "user".to_string(), + "44444444444444444444444444444444".to_string(), + ); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = 
TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client.write_all(&tls_app_record).await.unwrap(); + client.write_all(&trailing_tls_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn alpn_mismatch_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x66u8; 16]; + let probe = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.censorship.alpn_enforce = true; + cfg.access.ignore_time_skew = true; + cfg.access.users.insert( + "user".to_string(), + "66666666666666666666666666666666".to_string(), + ); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { 
+ interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = "198.51.100.66:55211".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn invalid_hmac_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x77u8; 16]; + let mut probe = make_valid_tls_client_hello(&secret, 0); + probe[tls::TLS_DIGEST_POS] ^= 0x01; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = 
None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access.users.insert( + "user".to_string(), + "77777777777777777777777777777777".to_string(), + ); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = "198.51.100.77:55212".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn burst_invalid_tls_probes_are_masked_verbatim() { + const N: usize = 12; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x88u8; 16]; + let mut probe = make_valid_tls_client_hello(&secret, 0); + 
probe[tls::TLS_DIGEST_POS + 1] ^= 0x01; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + for _ in 0..N { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + } + }); + + let mut handlers = Vec::with_capacity(N); + for i in 0..N { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access.users.insert( + "user".to_string(), + "88888888888888888888888888888888".to_string(), + ); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = format!("198.51.100.{}:{}", 100 + i, 56000 + i) + .parse() + .unwrap(); + let probe_bytes = probe.clone(); + + let h = tokio::spawn(async move { + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + 
beobachten, + false, + )); + + client_side.write_all(&probe_bytes).await.unwrap(); + drop(client_side); + + tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap() + .unwrap(); + }); + handlers.push(h); + } + + for h in handlers { + tokio::time::timeout(Duration::from_secs(5), h) + .await + .unwrap() + .unwrap(); + } + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[test] +fn unexpected_eof_is_classified_without_string_matching() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let eof = ProxyError::Io(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)); + let peer_ip: IpAddr = "198.51.100.200".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &eof); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!( + snapshot.contains("[expected_64_got_0]"), + "UnexpectedEof must be classified as expected_64_got_0" + ); + assert!( + snapshot.contains("198.51.100.200-1"), + "Classified record must include source IP" + ); +} + +#[test] +fn non_eof_error_is_classified_as_other() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let non_eof = ProxyError::Io(std::io::Error::other("different error")); + let peer_ip: IpAddr = "203.0.113.201".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &non_eof); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!( + snapshot.contains("[other]"), + "Non-EOF errors must map to other" + ); + assert!( + snapshot.contains("203.0.113.201-1"), + "Classified record must include source IP" + ); + assert!( + !snapshot.contains("[expected_64_got_0]"), + "Non-EOF errors must not be misclassified as 
expected_64_got_0" + ); +} + +#[test] +fn beobachten_ttl_zero_minutes_is_floored_to_one_minute() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 0; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(60), + "beobachten_minutes=0 must be fail-closed to a one-minute minimum TTL" + ); +} + +#[test] +fn beobachten_ttl_positive_minutes_remain_unchanged() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 7; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(7 * 60), + "configured positive beobachten TTL must be preserved" + ); +} + +#[tokio::test] +async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 1); + + let stats = Stats::new(); + stats.increment_user_curr_connects("user"); + + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.210:50000".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + assert_eq!( + ip_tracker.get_active_ip_count("user").await, + 0, + "Rejected client must not reserve IP slot" + ); + assert_eq!( + stats.get_ip_reservation_rollback_tcp_limit_total(), + 0, + "No rollback should occur when reservation is not taken" + ); +} + +#[tokio::test] +async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 0); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.211:50001".parse().unwrap(); + + let result 
= RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + assert_eq!(stats.get_user_curr_connects("user"), 0); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + +#[tokio::test] +async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() { + let user = "check-helper-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.212:50002".parse().unwrap(); + + let first = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!( + first.is_ok(), + "first check-only limit validation must succeed" + ); + + let second = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!( + second.is_ok(), + "second check-only validation must not fail from leaked state" + ); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn stress_check_user_limits_static_success_never_leaks_state() { + let user = "check-helper-stress-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + + for i in 0..4096u16 { + let peer_addr = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 110, (i % 250) as u8 + 1)), + 40000 + (i % 1024), + ); + + let result = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!( + result.is_ok(), + "check-only helper must remain leak-free under 
stress" + ); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "stress success loop must not leak user connection counters" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "stress success loop must not leak active IP reservations" + ); +} + +#[tokio::test] +async fn concurrent_distinct_ip_rejections_rollback_user_counter_without_leak() { + let user = "rollback-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 128); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let keeper_peer: SocketAddr = "198.51.100.212:50002".parse().unwrap(); + let keeper = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + keeper_peer, + ip_tracker.clone(), + ) + .await + .expect("keeper reservation must succeed"); + + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..64u8 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 101, i.saturating_add(1))), + 41000 + i as u16, + ); + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await; + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "rollback-storm-user" + )); + }); + } + + while let Some(joined) = tasks.join_next().await { + joined.unwrap(); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 1, + "failed distinct-IP attempts must rollback acquired user slots" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "failed distinct-IP attempts must not leave extra active IPs" + ); + + keeper.release().await; + 
assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn explicit_reservation_release_cleans_user_and_ip_immediately() { + let user = "release-user"; + let peer_addr: SocketAddr = "198.51.100.240:50002".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + reservation.release().await; + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "explicit release must synchronously free user connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "explicit release must synchronously remove reserved user IP" + ); +} + +#[tokio::test] +async fn explicit_reservation_release_does_not_double_decrement_on_drop() { + let user = "release-once-user"; + let peer_addr: SocketAddr = "198.51.100.241:50003".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker, + ) + .await + .expect("reservation acquisition must succeed"); + + reservation.release().await; + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "release must disarm drop and prevent double decrement" + ); +} + +#[tokio::test] +async fn 
drop_fallback_eventually_cleans_user_and_ip_reservation() { + let user = "drop-fallback-user"; + let peer_addr: SocketAddr = "198.51.100.242:50004".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + drop(reservation); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("drop fallback must eventually clean both user slot and active IP"); +} + +#[tokio::test] +async fn explicit_release_allows_immediate_cross_ip_reacquire_under_limit() { + let user = "cross-ip-user"; + let peer1: SocketAddr = "198.51.100.243:50005".parse().unwrap(); + let peer2: SocketAddr = "198.51.100.244:50006".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer1, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + first.release().await; + + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer2, + ip_tracker.clone(), + ) + .await + .expect("second reservation must 
succeed immediately after explicit release"); + second.release().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn release_abort_storm_does_not_leak_user_or_ip_reservations() { + const ATTEMPTS: usize = 256; + + let user = "release-abort-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), ATTEMPTS + 16); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + for idx in 0..ATTEMPTS { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 114, (idx % 250 + 1) as u8)), + 52000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed in abort storm"); + + let release_task = tokio::spawn(async move { + reservation.release().await; + }); + release_task.abort(); + let _ = release_task.await; + } + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("release abort storm must not leak user slots or active IP entries"); +} + +#[tokio::test] +async fn release_abort_loop_preserves_immediate_same_ip_reacquire() { + const ITERATIONS: usize = 128; + + let user = "release-abort-reacquire-user"; + let peer: SocketAddr = "198.51.100.246:53001".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + for _ in 0..ITERATIONS { + let reservation = 
RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("baseline acquisition must succeed"); + + let release_task = tokio::spawn(async move { + reservation.release().await; + }); + release_task.abort(); + let _ = release_task.await; + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("aborted release must still converge to zero footprint"); + } + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("same-ip reacquire must succeed after repeated abort-release churn"); + reservation.release().await; +} + +#[tokio::test] +async fn adversarial_mixed_release_drop_abort_wave_converges_to_zero() { + const RESERVATIONS: usize = 192; + + let user = "mixed-wave-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), RESERVATIONS + 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut reservations = Vec::with_capacity(RESERVATIONS); + for idx in 0..RESERVATIONS { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 115, (idx % 250 + 1) as u8)), + 54000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("mixed-wave acquisition must succeed"); + reservations.push(reservation); + } + + let mut seed: u64 = 0xDEAD_BEEF_CAFE_BA5E; + let mut join_set = tokio::task::JoinSet::new(); + for reservation in reservations { + seed ^= seed << 7; + seed ^= 
seed >> 9; + seed ^= seed << 8; + match seed % 3 { + 0 => { + join_set.spawn(async move { + reservation.release().await; + }); + } + 1 => { + drop(reservation); + } + _ => { + let task = tokio::spawn(async move { + reservation.release().await; + }); + task.abort(); + let _ = task.await; + } + } + } + + while let Some(result) = join_set.join_next().await { + result.expect("release subtask must not panic"); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("mixed release/drop/abort wave must converge to zero footprint"); +} + +#[tokio::test] +async fn parallel_users_abort_release_isolation_preserves_independent_cleanup() { + let user_a = "abort-isolation-a"; + let user_b = "abort-isolation-b"; + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user_a.to_string(), 64); + config + .access + .user_max_tcp_conns + .insert(user_b.to_string(), 64); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut tasks = tokio::task::JoinSet::new(); + for idx in 0..64usize { + let user = if idx % 2 == 0 { user_a } else { user_b }; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 18, 0, (idx % 250 + 1) as u8)), + 55000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("parallel-user acquisition must succeed"); + + tasks.spawn(async move { + let t = tokio::spawn(async move { + reservation.release().await; + }); + t.abort(); + let _ = t.await; + }); + } + + while let Some(result) = tasks.join_next().await { + result.expect("parallel-user abort task must not panic"); + } + + 
tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user_a) == 0 + && stats.get_user_curr_connects(user_b) == 0 + && ip_tracker.get_active_ip_count(user_a).await == 0 + && ip_tracker.get_active_ip_count(user_b).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("parallel users must cleanup independently under abort churn"); +} + +#[tokio::test] +async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() { + const RESERVATIONS: usize = 64; + + let user = "release-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), RESERVATIONS + 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut reservations = Vec::with_capacity(RESERVATIONS); + for idx in 0..RESERVATIONS { + let ip = std::net::Ipv4Addr::new(203, 0, 113, (idx + 1) as u8); + let peer = SocketAddr::new(IpAddr::V4(ip), 51000 + idx as u16); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition in storm must succeed"); + reservations.push(reservation); + } + + assert_eq!(stats.get_user_curr_connects(user), RESERVATIONS as u64); + assert_eq!(ip_tracker.get_active_ip_count(user).await, RESERVATIONS); + + let mut tasks = tokio::task::JoinSet::new(); + for reservation in reservations { + tasks.spawn(async move { + reservation.release().await; + }); + } + + while let Some(result) = tasks.join_next().await { + result.expect("release task must not panic"); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "release storm must drain user current-connection counter to zero" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "release storm must clear all active IP entries" + ); +} 
+ +#[tokio::test] +async fn relay_connect_error_releases_user_and_ip_before_return() { + let user = "relay-error-user"; + let peer_addr: SocketAddr = "198.51.100.245:50007".parse().unwrap(); + + let dead_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dead_port = dead_listener.local_addr().unwrap().port(); + drop(dead_listener); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + config + .dc_overrides + .insert("2".to_string(), vec![format!("127.0.0.1:{dead_port}")]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, _client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let result = RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime, + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + ) + .await; + + assert!( + result.is_err(), + "relay must fail when upstream DC is 
unreachable" + ); + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "error return must release user slot before returning" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "error return must release user IP reservation before returning" + ); +} + +#[tokio::test] +async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() { + let user = "same-ip-mixed-user"; + let peer_addr: SocketAddr = "198.51.100.246:50008".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + let reservation_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("second reservation must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 2); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + reservation_a.release().await; + assert_eq!( + stats.get_user_curr_connects(user), + 1, + "explicit release must decrement only one active reservation" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "same IP must remain active while second reservation exists" + ); + + drop(reservation_b); + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("drop fallback must clear final same-IP reservation"); +} + +#[tokio::test] +async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() { + let user = "same-ip-drop-one-user"; + let 
peer_addr: SocketAddr = "198.51.100.247:50009".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + let reservation_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("second reservation must succeed"); + + drop(reservation_a); + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("dropping one reservation must keep same-IP activity for remaining reservation"); + + reservation_b.release().await; + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("final release must converge to zero footprint after async fallback cleanup"); +} + +#[tokio::test] +async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { + let mut config = ProxyConfig::default(); + config + .access + .user_data_quota + .insert("user".to_string(), 1024); + + let stats = Stats::new(); + stats.add_user_octets_from("user", 1024); + + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + 
assert!(matches!( + result, + Err(ProxyError::DataQuotaExceeded { user }) if user == "user" + )); + assert_eq!( + ip_tracker.get_active_ip_count("user").await, + 0, + "Quota-rejected client must not reserve IP slot" + ); + assert_eq!( + stats.get_ip_reservation_rollback_quota_limit_total(), + 0, + "No rollback should occur when reservation is not taken" + ); +} + +#[tokio::test] +async fn expired_user_rejection_does_not_reserve_ip_or_increment_curr_connects() { + let mut config = ProxyConfig::default(); + config.access.user_expirations.insert( + "user".to_string(), + chrono::Utc::now() - chrono::Duration::seconds(1), + ); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "203.0.113.212:50002".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::UserExpired { user }) if user == "user" + )); + assert_eq!(stats.get_user_curr_connects("user"), 0); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + +#[tokio::test] +async fn same_ip_second_reservation_succeeds_under_unique_ip_limit_one() { + let user = "same-ip-unique-limit-user"; + let peer_addr: SocketAddr = "198.51.100.248:50010".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("second reservation from same IP 
must succeed under unique-ip limit=1"); + + assert_eq!(stats.get_user_curr_connects(user), 2); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + first.release().await; + second.release().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn second_distinct_ip_is_rejected_under_unique_ip_limit_one() { + let user = "distinct-ip-unique-limit-user"; + let peer1: SocketAddr = "198.51.100.249:50011".parse().unwrap(); + let peer2: SocketAddr = "198.51.100.250:50012".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer1, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer2, + ip_tracker.clone(), + ) + .await; + + assert!(matches!( + second, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "distinct-ip-unique-limit-user" + )); + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + first.release().await; +} + +#[tokio::test] +async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() { + let user = "cross-thread-drop-user"; + let peer_addr: SocketAddr = "198.51.100.251:50013".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + 
stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("drop thread must not panic"); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("cross-thread drop must still converge to zero user and IP footprint"); +} + +#[tokio::test] +async fn immediate_reacquire_after_cross_thread_drop_succeeds() { + let user = "cross-thread-reacquire-user"; + let peer_addr: SocketAddr = "198.51.100.252:50014".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("initial reservation must succeed"); + + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("drop thread must not panic"); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("cross-thread cleanup must settle before reacquire check"); + + let reacquire = RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer_addr, ip_tracker, + ) + .await; + assert!( + reacquire.is_ok(), + "reacquire must succeed after cross-thread drop cleanup" + ); +} + +#[tokio::test] 
+async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() { + const PARALLEL_IPS: usize = 64; + const ATTEMPTS_PER_IP: usize = 8; + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 1); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + stats.increment_user_curr_connects("user"); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..PARALLEL_IPS { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + + tasks.spawn(async move { + let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 100, (i + 1) as u8)); + for _ in 0..ATTEMPTS_PER_IP { + let peer_addr = SocketAddr::new(ip, 40000 + i as u16); + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + } + }); + } + + while let Some(joined) = tasks.join_next().await { + joined.unwrap(); + } + + assert_eq!( + ip_tracker.get_active_ip_count("user").await, + 0, + "Concurrent rejected attempts must not leave active IP reservations" + ); + + let recent = ip_tracker + .get_recent_ips_for_users(&["user".to_string()]) + .await; + assert!( + recent.get("user").map(|ips| ips.is_empty()).unwrap_or(true), + "Concurrent rejected attempts must not leave recent IP footprint" + ); + + assert_eq!( + stats.get_ip_reservation_rollback_tcp_limit_total(), + 0, + "No rollback should occur under concurrent rejection storms" + ); +} + +#[tokio::test] +async fn atomic_limit_gate_allows_only_one_concurrent_acquire() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 1); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = 
Arc::new(UserIpTracker::new()); + + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..64u16 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (i + 1) as u8)), + 30000 + i, + ); + RunningClientHandler::acquire_user_connection_reservation_static( + "user", &config, stats, peer, ip_tracker, + ) + .await + .ok() + }); + } + + let mut successes = 0u64; + let mut held_reservations = Vec::new(); + while let Some(joined) = tasks.join_next().await { + if let Some(reservation) = joined.unwrap() { + successes += 1; + held_reservations.push(reservation); + } + } + + assert_eq!( + successes, 1, + "exactly one concurrent acquire must pass for a limit=1 user" + ); + assert_eq!(stats.get_user_curr_connects("user"), 1); + + drop(held_reservations); +} + +#[tokio::test] +async fn untrusted_proxy_header_source_is_rejected() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.server.proxy_protocol_trusted_cidrs = vec!["10.10.0.0/16".parse().unwrap()]; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(2048); + let peer: SocketAddr = 
"198.51.100.44:55000".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.9:32000".parse().unwrap(), + "192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); +} + +#[tokio::test] +async fn empty_proxy_trusted_cidrs_rejects_proxy_header_by_default() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.server.proxy_protocol_trusted_cidrs.clear(); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(2048); + let peer: SocketAddr = "198.51.100.45:55000".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let 
proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.9:32000".parse().unwrap(), + "192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); +} + +#[tokio::test] +async fn oversized_tls_record_is_masked_in_generic_stream_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = [ + 0x16, + 0x03, + 0x01, + (((MAX_TLS_PLAINTEXT_SIZE + 1) >> 8) & 0xff) as u8, + ((MAX_TLS_PLAINTEXT_SIZE + 1) & 0xff) as u8, + ]; + let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = 
Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.123:55123".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); + + assert_eq!( + stats.get_connects_bad(), + bad_before + 1, + "Oversized TLS probe must be classified as bad" + ); +} + +#[tokio::test] +async fn oversized_tls_record_is_masked_in_client_handler_pipeline() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe = [ + 0x16, + 0x03, + 0x01, + (((MAX_TLS_PLAINTEXT_SIZE + 1) >> 8) & 0xff) as u8, + ((MAX_TLS_PLAINTEXT_SIZE + 1) & 0xff) as u8, + ]; + let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + 
+ let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + + let mut observed = vec![0u8; backend_reply.len()]; + 
client.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_record_len_min_minus_1_is_rejected_in_generic_stream_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = [ + 0x16, + 0x03, + 0x01, + (((MIN_TLS_CLIENT_HELLO_SIZE - 1) >> 8) & 0xff) as u8, + ((MIN_TLS_CLIENT_HELLO_SIZE - 1) & 0xff) as u8, + ]; + let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let 
route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.130:55130".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); + + assert_eq!( + stats.get_connects_bad(), + bad_before + 1, + "TLS record length below minimum structural ClientHello size must be rejected" + ); +} + +#[tokio::test] +async fn tls_record_len_min_minus_1_is_rejected_in_client_handler_pipeline() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe = [ + 0x16, + 0x03, + 0x01, + (((MIN_TLS_CLIENT_HELLO_SIZE - 1) >> 8) & 0xff) as u8, + ((MIN_TLS_CLIENT_HELLO_SIZE - 1) & 0xff) as u8, + ]; + let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = 
ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + + let mut observed = vec![0u8; backend_reply.len()]; + client.read_exact(&mut 
observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x55u8; 16]; + let client_hello = make_valid_tls_client_hello_with_len(&secret, 0, MAX_TLS_PLAINTEXT_SIZE); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access.users.insert( + "user".to_string(), + "55555555555555555555555555555555".to_string(), + ); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.55:56055".parse().unwrap(); + + let handler 
= tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut record_header = [0u8; 5]; + client_side.read_exact(&mut record_header).await.unwrap(); + assert_eq!( + record_header[0], 0x16, + "Valid max-length ClientHello must be accepted" + ); + + drop(client_side); + let handler_result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(handler_result.is_err()); + + let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), listener.accept()).await; + assert!( + no_mask_connect.is_err(), + "Valid max-length ClientHello must not trigger mask fallback" + ); + + assert_eq!( + bad_before, + stats.get_connects_bad(), + "Valid max-length ClientHello must not increment bad counter" + ); +} + +#[tokio::test] +async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let secret = [0x66u8; 16]; + let client_hello = make_valid_tls_client_hello_with_len(&secret, 0, MAX_TLS_PLAINTEXT_SIZE); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access.users.insert( + "user".to_string(), + "66666666666666666666666666666666".to_string(), + ); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = 
stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&client_hello).await.unwrap(); + + let mut record_header = [0u8; 5]; + client.read_exact(&mut record_header).await.unwrap(); + assert_eq!( + record_header[0], 0x16, + "Valid max-length ClientHello must be accepted" + ); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); + + let no_mask_connect = + tokio::time::timeout(Duration::from_millis(250), 
/// Advance a deterministic 64-bit linear congruential generator and return
/// the new state (multiplier 6364136223846793005, increment 1, wrapping
/// arithmetic).
///
/// Drives reproducible pseudo-random churn schedules in the reservation
/// tests; determinism across runs matters here, statistical quality does not.
fn lcg_next(state: &mut u64) -> u64 {
    let advanced = state
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1);
    *state = advanced;
    advanced
}
deterministic_mixed_reservation_churn_preserves_counter_and_eventual_cleanup() { + let user = "deterministic-churn-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 12); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 4).await; + + let mut seed = 0xD1F2_A4C8_991B_77E1u64; + let mut reservations: Vec> = Vec::new(); + + for step in 0..220u64 { + let op = (lcg_next(&mut seed) % 100) as u8; + let active = reservations.iter().filter(|entry| entry.is_some()).count(); + + if active == 0 || op < 55 { + let ip_octet = (lcg_next(&mut seed) % 16 + 1) as u8; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 120, ip_octet)), + 52000 + (step % 4000) as u16, + ); + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + + if let Ok(reservation) = result { + reservations.push(Some(reservation)); + } else { + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "deterministic-churn-user" + )); + } + } else { + let selected = reservations + .iter() + .enumerate() + .filter(|(_, entry)| entry.is_some()) + .map(|(idx, _)| idx) + .nth((lcg_next(&mut seed) as usize) % active) + .unwrap(); + + let reservation = reservations[selected].take().unwrap(); + if op < 80 { + reservation.release().await; + } else { + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("cross-thread drop must not panic"); + } + } + + let live_slots = reservations.iter().filter(|entry| entry.is_some()).count() as u64; + assert_eq!( + stats.get_user_curr_connects(user), + live_slots, + "current-connects counter must match number of live reservations" + ); + assert!( + stats.get_user_curr_connects(user) <= 12, + "current-connects must stay within 
configured TCP limit" + ); + assert!( + ip_tracker.get_active_ip_count(user).await <= 4, + "active unique IPs must stay within configured per-user IP limit" + ); + } + + for reservation in reservations.into_iter().flatten() { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} + +#[tokio::test] +async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() { + let user = "drop-storm-reacquire-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 64); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; + + let mut initial = Vec::new(); + for i in 0..32u16 { + let ip_octet = (i % 8 + 1) as u8; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 120, ip_octet)), + 53000 + i, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("initial reservation must succeed"); + initial.push(reservation); + } + + let mut second_half = initial.split_off(16); + + let mut releases = Vec::new(); + for reservation in initial { + releases.push(tokio::spawn(async move { + reservation.release().await; + })); + } + for release_task in releases { + release_task.await.expect("release task must not panic"); + } + + let mut drop_threads = Vec::new(); + for reservation in second_half.drain(..) 
{ + drop_threads.push(std::thread::spawn(move || { + drop(reservation); + })); + } + for drop_thread in drop_threads { + drop_thread + .join() + .expect("cross-thread drop worker must not panic"); + } + + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; + + let mut reacquire_tasks = tokio::task::JoinSet::new(); + for i in 0..16u16 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + reacquire_tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 121, (i + 1) as u8)), + 54000 + i, + ); + RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await + }); + } + + let mut acquired = Vec::new(); + while let Some(joined) = reacquire_tasks.join_next().await { + match joined.expect("reacquire task must not panic") { + Ok(reservation) => acquired.push(reservation), + Err(err) => { + assert!(matches!( + err, + ProxyError::ConnectionLimitExceeded { user } + if user == "drop-storm-reacquire-user" + )); + } + } + } + + assert!( + acquired.len() <= 8, + "parallel distinct-IP reacquire wave must not exceed per-user unique IP limit" + ); + for reservation in acquired { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} + +#[tokio::test] +async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() { + let user: &'static str = "scheduled-attack-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 6); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 2).await; + + let mut base = Vec::new(); + for i in 0..5u16 { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)), + 56000 + i, + ); + let reservation = 
RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("near-limit warmup reservation must succeed"); + base.push(reservation); + } + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + let (wave1_success, wave1_fail) = burst_acquire_distinct_ips( + user, + config.clone(), + stats.clone(), + ip_tracker.clone(), + 131, + 32, + ) + .await; + assert_eq!(wave1_success.len(), 1); + assert_eq!(wave1_fail, 31); + assert_eq!(stats.get_user_curr_connects(user), 6); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 2); + + let released = base.pop().expect("must have releasable reservation"); + released.release().await; + for reservation in wave1_success { + reservation.release().await; + } + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 4 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("window cleanup must settle to expected occupancy"); + + let (wave2_success, wave2_fail) = + burst_acquire_distinct_ips(user, config, stats.clone(), ip_tracker.clone(), 132, 32).await; + assert_eq!(wave2_success.len(), 1); + assert_eq!(wave2_fail, 31); + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 2); + + let tail = base.split_off(2); + + let mut drop_threads = Vec::new(); + for reservation in base { + drop_threads.push(std::thread::spawn(move || { + drop(reservation); + })); + } + for drop_thread in drop_threads { + drop_thread + .join() + .expect("cross-thread scheduled cleanup must not panic"); + } + + for reservation in tail { + reservation.release().await; + } + for reservation in wave2_success { + reservation.release().await; + } + + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} 
+ +#[tokio::test] +async fn scheduled_mode_switch_burst_churn_preserves_limits_and_cleanup() { + let user: &'static str = "scheduled-mode-switch-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 3).await; + + let base_peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 140, 1)), 57000); + let mut base = Vec::new(); + for i in 0..7u16 { + let peer = SocketAddr::new(base_peer.ip(), base_peer.port().saturating_add(i)); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("base occupancy reservation must succeed"); + base.push(reservation); + } + + assert_eq!(stats.get_user_curr_connects(user), 7); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + for round in 0..8u8 { + let (wave_success, wave_fail) = burst_acquire_distinct_ips( + user, + config.clone(), + stats.clone(), + ip_tracker.clone(), + 141u8.saturating_add(round), + 24, + ) + .await; + + assert!( + wave_success.len() <= 2, + "burst must not exceed available unique-IP headroom under limit=3" + ); + assert_eq!(wave_success.len() + wave_fail, 24); + assert_eq!( + stats.get_user_curr_connects(user), + 7 + wave_success.len() as u64, + "slot counter must reflect base occupancy plus successful burst leases" + ); + assert!(ip_tracker.get_active_ip_count(user).await <= 3); + + if round % 2 == 0 { + for reservation in wave_success { + reservation.release().await; + } + let rotated = base.pop().expect("base rotation reservation must exist"); + rotated.release().await; + } else { + for reservation in wave_success { + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("drop-heavy burst cleanup thread must not panic"); + } 
+ let rotated = base.pop().expect("base rotation reservation must exist"); + std::thread::spawn(move || { + drop(rotated); + }) + .join() + .expect("drop-heavy base cleanup thread must not panic"); + } + + let replacement = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + base_peer, + ip_tracker.clone(), + ) + .await + .expect("base replacement reservation must succeed after each round"); + base.push(replacement); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 7 + && ip_tracker.get_active_ip_count(user).await <= 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("round cleanup must converge to steady base occupancy"); + } + + for reservation in base { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} diff --git a/src/proxy/tests/client_timing_profile_adversarial_tests.rs b/src/proxy/tests/client_timing_profile_adversarial_tests.rs new file mode 100644 index 0000000..69a9ff4 --- /dev/null +++ b/src/proxy/tests/client_timing_profile_adversarial_tests.rs @@ -0,0 +1,370 @@ +//! Differential timing-profile adversarial tests. +//! Compare malformed in-range TLS truncation probes with plain web baselines, +//! ensuring masking behavior stays in similar latency buckets. 
/// Summarize latency samples (milliseconds) as `(mean, min, p95, max)`.
///
/// p95 is the element at index `floor(len * 0.95)` of the ascending sort,
/// clamped to the last element — a coarse estimator that is good enough for
/// the bucketed timing comparisons these tests perform.
///
/// Panics on an empty slice; callers always collect at least one sample.
fn summarize(samples_ms: &[u128]) -> (f64, u128, u128, u128) {
    let mut ordered: Vec<u128> = samples_ms.to_vec();
    ordered.sort_unstable();

    let count = ordered.len();
    let total: u128 = ordered.iter().sum();
    let mean = total as f64 / count as f64;

    // Clamp so a short sample set still yields a valid index.
    let p95_idx = (((count as f64) * 0.95).floor() as usize).min(count - 1);

    (mean, ordered[0], ordered[p95_idx], ordered[count - 1])
}
stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + if matches!(class, ProbeClass::PlainWebBaseline) { + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + } + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.210:55110".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + let probe = match class { + ProbeClass::MalformedTlsTruncation => malformed_tls_probe(), + ProbeClass::PlainWebBaseline => plain_web_probe(), + }; + + let started = Instant::now(); + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout( + Duration::from_secs(2), + client_side.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = 
tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + + started.elapsed().as_millis() +} + +async fn run_client_handler_once(class: ProbeClass) -> u128 { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let backend_reply = REPLY_404.to_vec(); + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut buf = [0u8; 5]; + stream.read_exact(&mut buf).await.unwrap(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + if matches!(class, ProbeClass::PlainWebBaseline) { + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + } + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = 
rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let probe = match class { + ProbeClass::MalformedTlsTruncation => malformed_tls_probe(), + ProbeClass::PlainWebBaseline => plain_web_probe(), + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + let started = Instant::now(); + client.write_all(&probe).await.unwrap(); + client.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + tokio::time::timeout(Duration::from_secs(2), mask_accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), server_task) + .await + .unwrap() + .unwrap(); + + started.elapsed().as_millis() +} + +#[tokio::test] +async fn differential_timing_generic_malformed_tls_vs_plain_web_mask_profile_similar() { + const ITER: usize = 24; + const BUCKET_MS: u128 = 20; + + let mut malformed = Vec::with_capacity(ITER); + let mut plain = Vec::with_capacity(ITER); + + for _ in 0..ITER { + malformed.push(run_generic_once(ProbeClass::MalformedTlsTruncation).await); + plain.push(run_generic_once(ProbeClass::PlainWebBaseline).await); + } + + let (m_mean, m_min, m_p95, m_max) = summarize(&malformed); + let (p_mean, p_min, p_p95, p_max) = summarize(&plain); + + println!( + "TIMING_DIFF generic class=malformed mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}", + m_mean, + 
m_min, + m_p95, + m_max, + (m_mean as u128) / BUCKET_MS, + m_p95 / BUCKET_MS + ); + println!( + "TIMING_DIFF generic class=plain_web mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}", + p_mean, + p_min, + p_p95, + p_max, + (p_mean as u128) / BUCKET_MS, + p_p95 / BUCKET_MS + ); + + let mean_bucket_delta = ((m_mean as i128) - (p_mean as i128)).abs() / (BUCKET_MS as i128); + let p95_bucket_delta = ((m_p95 as i128) - (p_p95 as i128)).abs() / (BUCKET_MS as i128); + + assert!( + mean_bucket_delta <= 1, + "generic timing mean diverged: malformed_mean_ms={:.2}, plain_mean_ms={:.2}", + m_mean, + p_mean + ); + assert!( + p95_bucket_delta <= 2, + "generic timing p95 diverged: malformed_p95_ms={}, plain_p95_ms={}", + m_p95, + p_p95 + ); +} + +#[tokio::test] +async fn differential_timing_client_handler_malformed_tls_vs_plain_web_mask_profile_similar() { + const ITER: usize = 16; + const BUCKET_MS: u128 = 20; + + let mut malformed = Vec::with_capacity(ITER); + let mut plain = Vec::with_capacity(ITER); + + for _ in 0..ITER { + malformed.push(run_client_handler_once(ProbeClass::MalformedTlsTruncation).await); + plain.push(run_client_handler_once(ProbeClass::PlainWebBaseline).await); + } + + let (m_mean, m_min, m_p95, m_max) = summarize(&malformed); + let (p_mean, p_min, p_p95, p_max) = summarize(&plain); + + println!( + "TIMING_DIFF handler class=malformed mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}", + m_mean, + m_min, + m_p95, + m_max, + (m_mean as u128) / BUCKET_MS, + m_p95 / BUCKET_MS + ); + println!( + "TIMING_DIFF handler class=plain_web mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}", + p_mean, + p_min, + p_p95, + p_max, + (p_mean as u128) / BUCKET_MS, + p_p95 / BUCKET_MS + ); + + let mean_bucket_delta = ((m_mean as i128) - (p_mean as i128)).abs() / (BUCKET_MS as i128); + let p95_bucket_delta = ((m_p95 as i128) - (p_p95 as i128)).abs() / (BUCKET_MS as i128); + + assert!( + mean_bucket_delta <= 1, 
/// Build a 5-byte TLS record header: ContentType = handshake (0x16),
/// legacy version 3.3, and a length field carrying the low 16 bits of `len`.
/// Lengths above u16::MAX wrap silently, which is exactly what an
/// adversarial probe generator wants.
fn test_probe_for_len(len: usize) -> [u8; 5] {
    let hi = ((len >> 8) & 0xff) as u8;
    let lo = (len & 0xff) as u8;
    [0x16, 0x03, 0x03, hi, lo]
}
5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe, "mask backend must receive original probe bytes"); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.123:55123".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!( + observed, backend_reply, + "invalid TLS path must be masked as a real site" + ); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); + + let expected_bad = if expect_bad_increment { + bad_before + 1 + } else { + bad_before + }; + assert_eq!( + stats.get_connects_bad(), + expected_bad, + "unexpected connects_bad classification for 
#[test]
fn tls_client_hello_len_bounds_light_fuzz_deterministic_lcg() {
    // Deterministic LCG (Numerical Recipes constants) so any failure
    // reproduces exactly without pulling in an RNG dependency.
    let mut state: u32 = 0xA5A5_5A5A;
    for _ in 0..2_048 {
        state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
        // Bias the stream toward the interesting boundary values;
        // otherwise use an arbitrary length in 0..=0x3fff.
        let arbitrary = (state as usize) & 0x3fff;
        let len = match state & 0x7 {
            0 => MIN_TLS_CLIENT_HELLO_SIZE - 1,
            1 => MIN_TLS_CLIENT_HELLO_SIZE,
            2 => MIN_TLS_CLIENT_HELLO_SIZE + 1,
            3 => MAX_TLS_PLAINTEXT_SIZE - 1,
            4 => MAX_TLS_PLAINTEXT_SIZE,
            5 => MAX_TLS_PLAINTEXT_SIZE + 1,
            _ => arbitrary,
        };
        // The oracle is the inclusive range itself; the function under test
        // must agree with it for every drawn length.
        let in_bounds = (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&len);
        assert_eq!(
            tls_clienthello_len_in_bounds(len),
            in_bounds,
            "deterministic fuzz mismatch for tls_len={len}"
        );
    }
}
assert!(!tls_clienthello_len_in_bounds( + MIN_TLS_CLIENT_HELLO_SIZE - 1 + )); + assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); + } +} + +#[tokio::test] +async fn tls_client_hello_masking_integration_repeated_small_probes() { + for _ in 0..25 { + run_probe_and_assert_masking(MIN_TLS_CLIENT_HELLO_SIZE - 1, true).await; + } +} diff --git a/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs b/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs new file mode 100644 index 0000000..79a8640 --- /dev/null +++ b/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs @@ -0,0 +1,572 @@ +//! Black-hat adversarial tests for truncated in-range TLS ClientHello probes. +//! These tests encode a strict anti-probing expectation: malformed TLS traffic +//! should still be masked as a legitimate website response. + +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::time::sleep; + +fn in_range_probe_header() -> [u8; 5] { + [ + 0x16, + 0x03, + 0x03, + ((MIN_TLS_CLIENT_HELLO_SIZE >> 8) & 0xff) as u8, + (MIN_TLS_CLIENT_HELLO_SIZE & 0xff) as u8, + ] +} + +fn make_test_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn truncated_in_range_record(actual_body_len: usize) -> Vec { + let mut out = in_range_probe_header().to_vec(); + out.extend(std::iter::repeat_n(0x41, actual_body_len)); + out +} + +async fn write_fragmented( + writer: &mut W, + bytes: &[u8], + chunks: &[usize], + delay_ms: u64, +) { + let mut 
offset = 0usize; + for &chunk in chunks { + if offset >= bytes.len() { + break; + } + let end = (offset + chunk).min(bytes.len()); + writer.write_all(&bytes[offset..end]).await.unwrap(); + offset = end; + if delay_ms > 0 { + sleep(Duration::from_millis(delay_ms)).await; + } + } + if offset < bytes.len() { + writer.write_all(&bytes[offset..]).await.unwrap(); + } +} + +async fn run_blackhat_generic_fragmented_probe_should_mask( + payload: Vec, + chunks: &[usize], + delay_ms: u64, + backend_reply: Vec, +) { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let mask_addr = mask_listener.local_addr().unwrap(); + let probe_header = in_range_probe_header(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe_header); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.202:55002".parse().unwrap(); + + let handler = 
tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + write_fragmented(&mut client_side, &payload, chunks, delay_ms).await; + client_side.shutdown().await.unwrap(); + + let mut observed = vec![0u8; backend_reply.len()]; + tokio::time::timeout( + Duration::from_secs(2), + client_side.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(2), mask_accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); +} + +async fn run_blackhat_client_handler_fragmented_probe_should_mask( + payload: Vec, + chunks: &[usize], + delay_ms: u64, + backend_reply: Vec, +) { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let mask_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe_header = in_range_probe_header(); + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe_header); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = 
make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + write_fragmented(&mut client, &payload, chunks, delay_ms).await; + client.shutdown().await.unwrap(); + + let mut observed = vec![0u8; backend_reply.len()]; + tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(2), mask_accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let mask_addr = mask_listener.local_addr().unwrap(); + let 
backend_reply = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(); + let probe = in_range_probe_header(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.201:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + // Security expectation: even malformed in-range TLS should be masked. + // This invariant must hold to avoid probe-distinguishable EOF/timeout behavior. 
+ let mut observed = vec![0u8; backend_reply.len()]; + tokio::time::timeout( + Duration::from_secs(2), + client_side.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(2), mask_accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let mask_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(); + let probe = in_range_probe_header(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = 
Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + client.shutdown().await.unwrap(); + + // Security expectation: malformed in-range TLS should still be masked. 
+ let mut observed = vec![0u8; backend_reply.len()]; + tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(2), mask_accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_generic_truncated_min_body_1_should_mask() { + run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(1), + &[6], + 0, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_generic_truncated_min_body_8_should_mask() { + run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(8), + &[13], + 0, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_generic_truncated_min_body_99_should_mask() { + run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1), + &[5, MIN_TLS_CLIENT_HELLO_SIZE - 1], + 0, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_generic_fragmented_header_then_close_should_mask() { + run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(0), + &[1, 1, 1, 1, 1], + 0, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_generic_fragmented_header_plus_partial_body_should_mask() { + run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(5), + &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + 0, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_generic_slowloris_fragmented_min_probe_should_mask_but_times_out() { + 
run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(1), + &[1, 1, 1, 1, 1, 1], + 250, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_truncated_min_body_1_should_mask() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(1), + &[6], + 0, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_truncated_min_body_8_should_mask() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(8), + &[13], + 0, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_truncated_min_body_99_should_mask() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1), + &[5, MIN_TLS_CLIENT_HELLO_SIZE - 1], + 0, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_fragmented_header_then_close_should_mask() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(0), + &[1, 1, 1, 1, 1], + 0, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_fragmented_header_plus_partial_body_should_mask() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(5), + &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + 0, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_slowloris_fragmented_min_probe_should_mask_but_times_out() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(1), + &[1, 1, 1, 1, 1, 1], + 250, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 
0\r\n\r\n".to_vec(), + ) + .await; +} diff --git a/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs new file mode 100644 index 0000000..95e49f7 --- /dev/null +++ b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs @@ -0,0 +1,2982 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{ + HANDSHAKE_LEN, MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_ALERT, TLS_RECORD_APPLICATION, + TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION, +}; +use crate::protocol::tls; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; + +struct PipelineHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + PipelineHarness { + config, + stats, + upstream_manager, + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: 
Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +fn wrap_tls_record(record_type: u8, payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(record_type); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +fn wrap_invalid_mtproto_with_coalesced_tail(tail: &[u8]) -> Vec { + let mut payload = vec![0u8; HANDSHAKE_LEN]; + payload.extend_from_slice(tail); + wrap_tls_application_data(&payload) +} + +async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: 
tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_preserves_wire_and_backend_response() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x81u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0, 600, 0x42); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_payload = b"masked-trailing-record".to_vec(); + let trailing_record = wrap_tls_application_data(&trailing_payload); + let backend_response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + let expected_trailing_record = trailing_record.clone(); + let expected_response = backend_response.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + + stream.write_all(&expected_response).await.unwrap(); + }); + + let harness = build_harness("81818181818181818181818181818181", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.181:56001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + 
assert_eq!(tls_response_head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x82u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 1, 600, 0x43); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_record = wrap_tls_application_data(b"x"); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("82828282828282828282828282828282", backend_addr.port()); + let bad_before = harness.stats.get_connects_bad(); + + let (server_side, mut client_side) = duplex(65536); + let peer: SocketAddr = "198.51.100.182:56002".parse().unwrap(); + let stats_for_assert = harness.stats.clone(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut 
tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + let bad_after = stats_for_assert.get_connects_bad(); + assert_eq!( + bad_after, + bad_before + 1, + "connects_bad must increase exactly once for invalid MTProto after valid TLS" + ); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_forwards_zero_length_tls_record_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x83u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 2, 600, 0x44); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_record = wrap_tls_application_data(&[]); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("83838383838383838383838383838383", backend_addr.port()); + let (server_side, mut client_side) = duplex(65536); + let peer: SocketAddr = "198.51.100.183:56003".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + 
harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_forwards_max_tls_record_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x84u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 3, 600, 0x45); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_payload = vec![0xAB; MAX_TLS_CIPHERTEXT_SIZE]; + let trailing_record = wrap_tls_application_data(&trailing_payload); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("84848484848484848484848484848484", backend_addr.port()); + let (server_side, mut client_side) = duplex(2 * 1024 * 1024); + let peer: SocketAddr = "198.51.100.184:56004".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, 
+ None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() { + let lengths = [0usize, 1, 2, 3, 7, 15, 31, 63, 127, 255, 1024, 4096]; + + for (idx, payload_len) in lengths.iter().copied().enumerate() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x85u8; 16]; + let client_hello = + make_valid_tls_client_hello(&secret, idx as u32 + 4, 600, 0x46 + idx as u8); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + + let mut payload = vec![0u8; payload_len]; + for (i, b) in payload.iter_mut().enumerate() { + *b = ((idx as u8).wrapping_mul(29)).wrapping_add(i as u8); + } + let trailing_record = wrap_tls_application_data(&payload); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("85858585858585858585858585858585", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = 
format!("198.51.100.185:{}", 56010 + idx as u16) + .parse() + .unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + } +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() { + let sessions = 24usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected_records = std::collections::HashSet::new(); + let secret = [0x86u8; 16]; + for idx in 0..sessions { + let _hello = make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); + let payload = vec![idx as u8; 64 + idx]; + let trailing = wrap_tls_application_data(&payload); + expected_records.insert(trailing); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected_records; + for idx in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + + let _ = idx; + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); 
+ stream.read_exact(&mut record[5..]).await.unwrap(); + + assert!( + remaining.remove(&record), + "unexpected trailing TLS record in concurrent isolation test" + ); + } + + assert!( + remaining.is_empty(), + "all expected client sessions must be matched exactly once" + ); + }); + + let mut client_tasks = Vec::with_capacity(sessions); + + for idx in 0..sessions { + let harness = build_harness("86868686868686868686868686868686", backend_addr.port()); + let secret = [0x86u8; 16]; + let client_hello = + make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_payload = vec![idx as u8; 64 + idx]; + let trailing_record = wrap_tls_application_data(&trailing_payload); + + let peer: SocketAddr = format!("198.51.100.186:{}", 57000 + idx as u16) + .parse() + .unwrap(); + + client_tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(262144); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in client_tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async 
fn tls_bad_mtproto_fallback_forwards_fragmented_client_writes_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x87u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 9, 600, 0x57); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let payload = b"fragmented-writes-to-test-stream-boundary-robustness".to_vec(); + let trailing_record = wrap_tls_application_data(&payload); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("87878787878787878787878787878787", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.187:56087".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + + for chunk in trailing_record.chunks(3) { + client_side.write_all(chunk).await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + 
.await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_header_fragmentation_bytewise_is_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x88u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 10, 600, 0x58); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"bytewise-header"); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + }); + + let harness = build_harness("88888888888888888888888888888888", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.188:56088".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + for b in trailing_record.iter().copied() { + client_side.write_all(&[b]).await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn 
tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x89u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 11, 600, 0x59); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let mut payload = vec![0u8; 2048]; + for (i, b) in payload.iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(17).wrapping_add(3); + } + let trailing_record = wrap_tls_application_data(&payload); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + }); + + let harness = build_harness("89898989898989898989898989898989", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.189:56089".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + + let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17]; + let mut pos = 0usize; + let mut idx = 0usize; + while pos < trailing_record.len() { + let step = chaos[idx % chaos.len()]; + let end = (pos + step).min(trailing_record.len()); + client_side + .write_all(&trailing_record[pos..end]) 
+ .await + .unwrap(); + pos = end; + idx += 1; + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_multiple_tls_records_are_forwarded_in_order() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Au8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 12, 600, 0x5A); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let r1 = wrap_tls_application_data(b"alpha"); + let r2 = wrap_tls_application_data(b"beta-beta"); + let r3 = wrap_tls_application_data(b"gamma-gamma-gamma"); + let expected = [r1.clone(), r2.clone(), r3.clone()].concat(); + let expected_concat = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_concat.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_concat); + }); + + let harness = build_harness("8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.190:56090".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + 
client_side.write_all(&r1).await.unwrap(); + client_side.write_all(&r2).await.unwrap(); + client_side.write_all(&r3).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Bu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 13, 600, 0x5B); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"half-close-probe"); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + + let mut tail = [0u8; 1]; + let n = stream.read(&mut tail).await.unwrap(); + assert_eq!( + n, 0, + "backend must observe EOF after client write half-close" + ); + }); + + let harness = build_harness("8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.191:56091".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + 
.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_backend_half_close_after_response_is_tolerated() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Cu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 14, 600, 0x5C); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"backend-half-close"); + let backend_response = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + let expected_trailing = trailing_record.clone(); + let response = backend_response.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + + stream.write_all(&response).await.unwrap(); + stream.shutdown().await.unwrap(); + }); + + let harness = build_harness("8c8c8c8c8c8c8c8c8c8c8c8c8c8c8c8c", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.192:56092".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + 
client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_backend_reset_after_clienthello_is_handled() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Du8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 15, 600, 0x5D); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"backend-reset"); + let accept_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let harness = build_harness("8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.193:56093".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + 
.write_all(&invalid_mtproto_record)
+ .await
+ .unwrap();
+ // NOTE(review): the assertion below is a tautology — `is_ok() || is_err()` is
+ // always true, so it can never fail. The stated intent ("either outcome is
+ // acceptable when the backend resets the connection") would be better expressed
+ // as `let _ = write_res;` with a comment, or by asserting a concrete property.
+ // Flagged for follow-up; code left unchanged here.
+ let write_res = client_side.write_all(&trailing_record).await;
+ assert!(
+ write_res.is_ok() || write_res.is_err(),
+ "write completion is environment dependent under backend reset"
+ );
+
+ tokio::time::timeout(Duration::from_secs(3), accept_task)
+ .await
+ .unwrap()
+ .unwrap();
+
+ drop(client_side);
+ let _ = tokio::time::timeout(Duration::from_secs(3), handler)
+ .await
+ .unwrap()
+ .unwrap();
+ }
+
+ // Verifies the fallback splice preserves byte identity under backend
+ // backpressure: the accept task drains the forwarded record in irregular
+ // chunks (step size derived from `offset % 97`) with a 1 ms pause per step,
+ // then asserts the bytes received equal the record the client sent — the
+ // relay must neither reorder, drop, nor pad data while the reader is slow.
+ #[tokio::test]
+ async fn tls_bad_mtproto_fallback_backend_slow_reader_preserves_byte_identity() {
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let backend_addr = listener.local_addr().unwrap();
+
+ let secret = [0x8Eu8; 16];
+ let client_hello = make_valid_tls_client_hello(&secret, 16, 600, 0x5E);
+ let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
+
+ // 8 KiB payload forces multiple relay read/write cycles through the splice.
+ let payload = vec![0xEC; 8192];
+ let trailing_record = wrap_tls_application_data(&payload);
+ let expected_trailing = trailing_record.clone();
+ let accept_task = tokio::spawn(async move {
+ let (mut stream, _) = listener.accept().await.unwrap();
+
+ let mut got_trailing = vec![0u8; expected_trailing.len()];
+ let mut offset = 0usize;
+ while offset < got_trailing.len() {
+ // Irregular chunk size in [1, 96], clamped to the bytes remaining.
+ let step = (offset % 97).max(1).min(got_trailing.len() - offset);
+ stream
+ .read_exact(&mut got_trailing[offset..offset + step])
+ .await
+ .unwrap();
+ offset += step;
+ tokio::time::sleep(Duration::from_millis(1)).await;
+ }
+ assert_eq!(got_trailing, expected_trailing);
+ });
+
+ let harness = build_harness("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e", backend_addr.port());
+ let (server_side, mut client_side) = duplex(262144);
+ let peer: SocketAddr = "198.51.100.194:56094".parse().unwrap();
+
+ let handler = tokio::spawn(handle_client_stream(
+ server_side,
+ peer,
+ harness.config,
+ harness.stats,
+ harness.upstream_manager,
+ harness.replay_checker,
+ harness.buffer_pool,
+ harness.rng,
+ None,
+ harness.route_runtime,
+ None,
+
harness.ip_tracker,
+ harness.beobachten,
+ false,
+ ));
+
+ client_side.write_all(&client_hello).await.unwrap();
+ let mut tls_response_head = [0u8; 5];
+ client_side
+ .read_exact(&mut tls_response_head)
+ .await
+ .unwrap();
+ assert_eq!(tls_response_head[0], 0x16);
+
+ client_side
+ .write_all(&invalid_mtproto_record)
+ .await
+ .unwrap();
+ client_side.write_all(&trailing_record).await.unwrap();
+
+ tokio::time::timeout(Duration::from_secs(5), accept_task)
+ .await
+ .unwrap()
+ .unwrap();
+
+ drop(client_side);
+ let _ = tokio::time::timeout(Duration::from_secs(3), handler)
+ .await
+ .unwrap()
+ .unwrap();
+ }
+
+ #[tokio::test]
+ async fn tls_bad_mtproto_fallback_replay_pressure_masks_replay_without_serverhello() {
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let backend_addr = listener.local_addr().unwrap();
+
+ let secret = [0x8Fu8; 16];
+ let replayed_hello = make_valid_tls_client_hello(&secret, 17, 600, 0x5F);
+ let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
+ let trailing_record = wrap_tls_application_data(b"first-session");
+
+ let expected_second = replayed_hello.clone();
+ let expected_trailing = trailing_record.clone();
+
+ let accept_task = tokio::spawn(async move {
+ let (mut s1, _) = listener.accept().await.unwrap();
+ let mut got1_tail = vec![0u8; expected_trailing.len()];
+ s1.read_exact(&mut got1_tail).await.unwrap();
+ assert_eq!(got1_tail, expected_trailing);
+ drop(s1);
+
+ let (mut s2, _) = listener.accept().await.unwrap();
+ let mut got2 = vec![0u8; expected_second.len()];
+ s2.read_exact(&mut got2).await.unwrap();
+ assert_eq!(got2, expected_second);
+ });
+
+ let harness = build_harness("8f8f8f8f8f8f8f8f8f8f8f8f8f8f8f8f", backend_addr.port());
+ let stats_for_assert = harness.stats.clone();
+ let bad_before = stats_for_assert.get_connects_bad();
+
+ let run_session = |hello: Vec<u8>, send_mtproto: bool| {
+ let (server_side, mut client_side) = duplex(131072);
+ let config =
harness.config.clone(); + let stats = harness.stats.clone(); + let upstream = harness.upstream_manager.clone(); + let replay = harness.replay_checker.clone(); + let pool = harness.buffer_pool.clone(); + let rng = harness.rng.clone(); + let route = harness.route_runtime.clone(); + let ipt = harness.ip_tracker.clone(); + let beob = harness.beobachten.clone(); + let invalid_mtproto_record = invalid_mtproto_record.clone(); + let trailing_record = trailing_record.clone(); + async move { + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.195:56095".parse().unwrap(), + config, + stats, + upstream, + replay, + pool, + rng, + None, + route, + None, + ipt, + beob, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + if send_mtproto { + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + } else { + let mut one = [0u8; 1]; + let no_server_hello = tokio::time::timeout( + Duration::from_millis(300), + client_side.read_exact(&mut one), + ) + .await; + assert!( + no_server_hello.is_err() || no_server_hello.unwrap().is_err(), + "replayed TLS hello must not receive authenticated TLS ServerHello" + ); + } + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + } + }; + + run_session(replayed_hello.clone(), true).await; + run_session(replayed_hello.clone(), false).await; + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); + + let bad_after = stats_for_assert.get_connects_bad(); + assert!( + bad_after >= bad_before + 2, + "both invalid-mtproto and replayed-tls paths must increment bad connection accounting" + ); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_large_multi_record_chaos_under_backpressure() { + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x90u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 18, 600, 0x60); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let a = wrap_tls_application_data(&vec![0xA1; 2048]); + let b = wrap_tls_application_data(&vec![0xB2; 3072]); + let c = wrap_tls_application_data(&vec![0xC3; 1536]); + let expected = [a.clone(), b.clone(), c.clone()].concat(); + let expected_payload = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_payload.len()]; + let mut pos = 0usize; + while pos < got.len() { + let step = (pos % 257).max(1).min(got.len() - pos); + stream.read_exact(&mut got[pos..pos + step]).await.unwrap(); + pos += step; + tokio::time::sleep(Duration::from_millis(1)).await; + } + assert_eq!(got, expected_payload); + }); + + let harness = build_harness("90909090909090909090909090909090", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.196:56096".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + + let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31]; + for record in [&a, &b, &c] { + let mut pos = 0usize; + let mut idx = 0usize; + while pos < record.len() { + let step = chaos[idx 
% chaos.len()]; + let end = (pos + step).min(record.len()); + client_side.write_all(&record[pos..end]).await.unwrap(); + pos = end; + idx += 1; + } + } + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_interleaved_control_and_application_records_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x91u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 19, 600, 0x61); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let ccs = wrap_tls_record(0x14, &[0x01]); + let app = wrap_tls_application_data(b"opaque"); + let alert = wrap_tls_record(0x15, &[0x01, 0x00]); + let expected = [ccs.clone(), app.clone(), alert.clone()].concat(); + let expected_records = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_records.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_records); + }); + + let harness = build_harness("91919191919191919191919191919191", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.197:56097".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + 
assert_eq!(tls_response_head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&ccs).await.unwrap(); + client_side.write_all(&app).await.unwrap(); + client_side.write_all(&alert).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak() { + let sessions = 40usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected_records = std::collections::HashSet::new(); + let secret = [0x92u8; 16]; + for idx in 0..sessions { + let _hello = make_valid_tls_client_hello(&secret, idx as u32 + 200, 600, 0x70 + idx as u8); + let payload = vec![idx as u8; 33 + (idx % 17)]; + let record = wrap_tls_application_data(&payload); + expected_records.insert(record); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected_records; + for idx in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + + let _ = idx; + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + stream.read_exact(&mut record[5..]).await.unwrap(); + + assert!( + remaining.remove(&record), + "unexpected trailing TLS record in short-session chaos test" + ); + } + + assert!( + remaining.is_empty(), + "all expected sessions must be consumed exactly once" + ); + }); + + let mut tasks = Vec::with_capacity(sessions); + for idx in 0..sessions { + let harness = build_harness("92929292929292929292929292929292", backend_addr.port()); + let secret = 
[0x92u8; 16]; + let client_hello = + make_valid_tls_client_hello(&secret, idx as u32 + 200, 600, 0x70 + idx as u8); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let payload = vec![idx as u8; 33 + (idx % 17)]; + let record = wrap_tls_application_data(&payload); + + let peer: SocketAddr = format!("198.51.100.198:{}", 58000 + idx as u16) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + for chunk in record.chunks((idx % 9) + 1) { + client_side.write_all(chunk).await.unwrap(); + } + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_coalesced_tail_small_is_forwarded_as_tls_record() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xA1u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 300, 600, 0x31); + let coalesced_tail = b"coalesced-tail-small".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&coalesced_tail); + let expected_tail_record = wrap_tls_application_data(&coalesced_tail); + let expected_tail = 
expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.210:56110".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_coalesced_tail_large_is_forwarded_as_tls_record() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xA2u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 301, 600, 0x32); + let coalesced_tail = vec![0xAB; 4096]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&coalesced_tail); + let expected_tail_record = wrap_tls_application_data(&coalesced_tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; 
expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.211:56111".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_coalesced_tail_keeps_order_before_following_record() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xA3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 302, 600, 0x33); + let coalesced_tail = b"coalesced-first".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&coalesced_tail); + let expected_tail_record = wrap_tls_application_data(&coalesced_tail); + let following_record = wrap_tls_application_data(b"following-record"); + let expected_concat = [expected_tail_record.clone(), following_record.clone()].concat(); + let expected_records = expected_concat.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_records = 
vec![0u8; expected_records.len()]; + stream.read_exact(&mut got_records).await.unwrap(); + assert_eq!(got_records, expected_records); + }); + + let harness = build_harness("a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.212:56112".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + client_side.write_all(&following_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_coalesced_tail_fragmented_client_write_is_forwarded() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xA4u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 303, 600, 0x34); + let coalesced_tail = vec![0xCD; 1536]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&coalesced_tail); + let expected_tail_record = wrap_tls_application_data(&coalesced_tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + 
assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.213:56113".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + let steps = [7usize, 3, 13, 5, 11, 2, 17, 19]; + let mut offset = 0usize; + let mut i = 0usize; + while offset < coalesced_record.len() { + let step = steps[i % steps.len()]; + let end = (offset + step).min(coalesced_record.len()); + client_side + .write_all(&coalesced_record[offset..end]) + .await + .unwrap(); + offset = end; + i += 1; + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_coalesced_tail_max_payload_is_forwarded() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xA5u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 304, 600, 0x35); + let coalesced_tail = vec![0xEF; MAX_TLS_CIPHERTEXT_SIZE - HANDSHAKE_LEN]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&coalesced_tail); + let expected_tail_record = wrap_tls_application_data(&coalesced_tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) 
= listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.214:56114".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_identical_following_record_must_not_duplicate_or_reorder() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xB1u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 400, 600, 0x21); + let tail = b"same-payload-record".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let tail_record = wrap_tls_application_data(&tail); + let expected = [tail_record.clone(), tail_record.clone()].concat(); + let expected_payload = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_payload.len()]; + stream.read_exact(&mut 
got).await.unwrap(); + assert_eq!(got, expected_payload); + + let mut tail = [0u8; 1]; + let n = stream.read(&mut tail).await.unwrap(); + assert_eq!(n, 0, "fallback stream must not emit extra bytes"); + }); + + let harness = build_harness("b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.220:56120".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + client_side.write_all(&tail_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_tls_header_looking_bytes_must_stay_payload() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xB2u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 401, 600, 0x22); + let mut tail = vec![0x16, 0x03, 0x03, 0x00, 0x10]; + tail.extend_from_slice(b"not-a-real-record-boundary"); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail_record = wrap_tls_application_data(&tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = 
listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.221:56121".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_client_half_close_must_not_truncate_prepended_record() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xB3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 402, 600, 0x23); + let tail = vec![0xAA; 3072]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail_record = wrap_tls_application_data(&tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + + let mut one 
= [0u8; 1]; + let n = stream.read(&mut one).await.unwrap(); + assert_eq!(n, 0, "backend must observe EOF after client half-close"); + }); + + let harness = build_harness("b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.222:56122".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_multi_session_no_cross_bleed_under_churn() { + let sessions = 16usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = std::collections::HashSet::new(); + let secret = [0xB4u8; 16]; + for idx in 0..sessions { + let _hello = make_valid_tls_client_hello(&secret, 450 + idx as u32, 600, 0x40 + idx as u8); + let tail = vec![idx as u8; 17 + idx]; + expected.insert(wrap_tls_application_data(&tail)); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected; + for _ in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + + let 
len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + stream.read_exact(&mut record[5..]).await.unwrap(); + + assert!( + remaining.remove(&record), + "unexpected record or duplicated session routing" + ); + } + assert!(remaining.is_empty(), "all sessions must map one-to-one"); + }); + + let mut tasks = Vec::with_capacity(sessions); + for idx in 0..sessions { + let harness = build_harness("b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4", backend_addr.port()); + let hello = make_valid_tls_client_hello(&secret, 450 + idx as u32, 600, 0x40 + idx as u8); + let tail = vec![idx as u8; 17 + idx]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let peer: SocketAddr = format!("198.51.100.223:{}", 56200 + idx as u16) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + for chunk in coalesced_record.chunks((idx % 7) + 1) { + client_side.write_all(chunk).await.unwrap(); + } + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_single_byte_tail_is_preserved() { + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC1u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 500, 600, 0x11); + let tail = vec![0x7F]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1", backend_addr.port()); + let (server_side, mut client_side) = duplex(65536); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.230:56130".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_exact_tls_header_size_payload_is_preserved() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC2u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 501, 600, 0x12); + let tail = vec![0xAA, 0xBB, 0xCC, 0xDD, 0xEE]; + let 
coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2", backend_addr.port()); + let (server_side, mut client_side) = duplex(65536); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.231:56131".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_all_zero_payload_is_preserved() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 502, 600, 0x13); + let tail = vec![0u8; 2048]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; 
expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.232:56132".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_following_control_records_are_not_mutated() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC4u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 503, 600, 0x14); + let tail = b"tail-before-controls".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let tail_record = wrap_tls_application_data(&tail); + let ccs = wrap_tls_record(0x14, &[0x01]); + let alert = wrap_tls_record(0x15, &[0x01, 0x00]); + let app = wrap_tls_application_data(b"control-final-app"); + let expected = [tail_record, ccs.clone(), alert.clone(), app.clone()].concat(); + let expected_payload = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + 
let mut got_payload = vec![0u8; expected_payload.len()]; + stream.read_exact(&mut got_payload).await.unwrap(); + assert_eq!(got_payload, expected_payload); + }); + + let harness = build_harness("c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.233:56133".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + client_side.write_all(&ccs).await.unwrap(); + client_side.write_all(&alert).await.unwrap(); + client_side.write_all(&app).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_then_following_records_fragmented_chaos_stays_ordered() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC5u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 504, 600, 0x15); + let tail = vec![0xAC; 900]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let tail_record = wrap_tls_application_data(&tail); + let r1 = wrap_tls_application_data(b"r1"); + let r2 = wrap_tls_application_data(&vec![0xDD; 257]); + let expected = [tail_record, r1.clone(), r2.clone()].concat(); + let expected_payload = expected.clone(); 
+ let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_payload = vec![0u8; expected_payload.len()]; + stream.read_exact(&mut got_payload).await.unwrap(); + assert_eq!(got_payload, expected_payload); + }); + + let harness = build_harness("c5c5c5c5c5c5c5c5c5c5c5c5c5c5c5c5", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.234:56134".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + let pattern = [3usize, 1, 5, 2, 7, 11, 13, 17, 19]; + let mut idx = 0usize; + for data in [&coalesced_record, &r1, &r2] { + let mut pos = 0usize; + while pos < data.len() { + let step = pattern[idx % pattern.len()]; + idx += 1; + let end = (pos + step).min(data.len()); + client_side.write_all(&data[pos..end]).await.unwrap(); + pos = end; + } + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_backend_response_integrity_after_fallback() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC6u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 505, 600, 0x16); + let tail = b"coalesced-request-body".to_vec(); + let coalesced_record = 
wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let backend_response = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + let expected_resp = backend_response.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + + stream.write_all(&expected_resp).await.unwrap(); + }); + + let harness = build_harness("c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.235:56135".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + let mut observed = Vec::new(); + let mut buf = [0u8; 512]; + let mut found = false; + for _ in 0..32 { + let n = tokio::time::timeout(Duration::from_millis(200), client_side.read(&mut buf)) + .await + .unwrap() + .unwrap(); + if n == 0 { + break; + } + observed.extend_from_slice(&buf[..n]); + if observed + .windows(backend_response.len()) + .any(|w| w == backend_response.as_slice()) + { + found = true; + break; + } + } + assert!( + found, + "backend plaintext response must be observable on client stream after fallback" + ); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = 
tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_connects_bad_increments_exactly_once() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC7u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 506, 600, 0x17); + let tail = b"count-bad-once".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + + let harness = build_harness("c7c7c7c7c7c7c7c7c7c7c7c7c7c7c7c7", backend_addr.port()); + let stats = harness.stats.clone(); + let bad_before = stats.get_connects_bad(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.236:56136".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + let bad_after = stats.get_connects_bad(); + assert_eq!( + bad_after, + bad_before + 1, + 
"invalid MTProto after valid TLS must increment connects_bad exactly once" + ); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_parallel_32_sessions_no_cross_bleed() { + let sessions = 32usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = std::collections::HashSet::new(); + let secret = [0xC8u8; 16]; + for idx in 0..sessions { + let _hello = make_valid_tls_client_hello(&secret, 550 + idx as u32, 600, 0x20 + idx as u8); + let tail = vec![idx as u8; 48 + (idx % 11)]; + expected.insert(wrap_tls_application_data(&tail)); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected; + for _ in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + stream.read_exact(&mut record[5..]).await.unwrap(); + + assert!( + remaining.remove(&record), + "session mixup detected in parallel-32 blackhat test" + ); + } + assert!( + remaining.is_empty(), + "all expected sessions must be consumed" + ); + }); + + let mut tasks = Vec::with_capacity(sessions); + for idx in 0..sessions { + let harness = build_harness("c8c8c8c8c8c8c8c8c8c8c8c8c8c8c8c8", backend_addr.port()); + let hello = make_valid_tls_client_hello(&secret, 550 + idx as u32, 600, 0x20 + idx as u8); + let tail = vec![idx as u8; 48 + (idx % 11)]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let peer: SocketAddr = format!("198.51.100.237:{}", 56300 + idx as u16) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, 
+ harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + let chunk = (idx % 13) + 1; + for part in coalesced_record.chunks(chunk) { + client_side.write_all(part).await.unwrap(); + } + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_repeated_tls_like_prefixes_are_preserved() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC9u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 507, 600, 0x18); + let mut tail = Vec::new(); + for _ in 0..64 { + tail.extend_from_slice(&[0x16, 0x03, 0x03, 0x00, 0x20]); + } + tail.extend_from_slice(b"suffix-data"); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("c9c9c9c9c9c9c9c9c9c9c9c9c9c9c9c9", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.238:56138".parse().unwrap(), + harness.config, + harness.stats, + 
harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_drop_after_write_still_delivers_prepended_record() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xCAu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 508, 600, 0x19); + let tail = vec![0xBE; 1024]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("cacacacacacacacacacacacacacacaca", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.239:56139".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head 
= [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + drop(client_side); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_zero_following_record_after_coalesced_is_not_invented() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xCBu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 509, 600, 0x1A); + let tail = b"terminal-tail".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + + let mut one = [0u8; 1]; + let n = stream.read(&mut one).await.unwrap(); + assert_eq!(n, 0, "no synthetic extra record must appear"); + }); + + let harness = build_harness("cbcbcbcbcbcbcbcbcbcbcbcbcbcbcbcb", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.240:56140".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + 
read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_light_fuzz_mixed_followup_records_stay_byte_exact() { + let mut seed = 0xA11C_E2E5_F00D_BAADu64; + + for case in 0..24u32 { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let tail_len = (seed as usize % 1536) + 1; + let mut tail = vec![0u8; tail_len]; + for (i, b) in tail.iter_mut().enumerate() { + *b = (seed as u8).wrapping_add(i as u8).wrapping_mul(13); + } + + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let follow_type = match seed & 0x3 { + 0 => TLS_RECORD_APPLICATION, + 1 => TLS_RECORD_ALERT, + 2 => TLS_RECORD_CHANGE_CIPHER, + _ => TLS_RECORD_HANDSHAKE, + }; + let follow_len = (seed as usize % 96) + (case as usize % 3); + let mut follow_payload = vec![0u8; follow_len]; + for (i, b) in follow_payload.iter_mut().enumerate() { + *b = (case as u8).wrapping_mul(29).wrapping_add(i as u8); + } + + let secret = [0xD1u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 600 + case, 600, 0x33); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let follow_record = wrap_tls_record(follow_type, &follow_payload); + let expected_wire = [expected_tail.clone(), follow_record.clone()].concat(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_wire.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, 
expected_wire);
        });

        let harness = build_harness("d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1", backend_addr.port());
        let (server_side, mut client_side) = duplex(262144);
        let peer: SocketAddr = format!("198.51.100.250:{}", 57000 + case as u16)
            .parse()
            .unwrap();

        let handler = tokio::spawn(handle_client_stream(
            server_side,
            peer,
            harness.config,
            harness.stats,
            harness.upstream_manager,
            harness.replay_checker,
            harness.buffer_pool,
            harness.rng,
            None,
            harness.route_runtime,
            None,
            harness.ip_tracker,
            harness.beobachten,
            false,
        ));

        client_side.write_all(&client_hello).await.unwrap();
        let mut head = [0u8; 5];
        client_side.read_exact(&mut head).await.unwrap();
        assert_eq!(head[0], 0x16);
        read_and_discard_tls_record_body(&mut client_side, head).await;

        // Feed both records in deterministic, irregularly sized chunks so the
        // server-side reassembly path is exercised across arbitrary TCP splits.
        let mut local_seed = seed ^ 0x55AA_55AA_1234_5678;
        for data in [&coalesced_record, &follow_record] {
            let mut pos = 0usize;
            while pos < data.len() {
                local_seed ^= local_seed << 7;
                local_seed ^= local_seed >> 9;
                local_seed ^= local_seed << 8;
                let step = ((local_seed as usize % 17) + 1).min(data.len() - pos);
                let end = pos + step;
                client_side.write_all(&data[pos..end]).await.unwrap();
                pos = end;
            }
        }

        tokio::time::timeout(Duration::from_secs(3), accept_task)
            .await
            .unwrap()
            .unwrap();

        drop(client_side);
        let _ = tokio::time::timeout(Duration::from_secs(3), handler)
            .await
            .unwrap()
            .unwrap();
    }
}
diff --git a/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs
new file mode 100644
index 0000000..08f52d1
--- /dev/null
+++ b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs
@@ -0,0 +1,37 @@
use super::*;

/// An empty payload must still produce a syntactically complete TLS record:
/// 5-byte header (type, version, length) with a declared length of zero.
#[test]
fn wrap_tls_application_record_empty_payload_emits_zero_length_record() {
    let record = wrap_tls_application_record(&[]);
    assert_eq!(record.len(), 5);
    assert_eq!(record[0], TLS_RECORD_APPLICATION);
    assert_eq!(&record[1..3], &TLS_VERSION);
    assert_eq!(&record[3..5], &0u16.to_be_bytes());
}

/// A payload larger than the 16-bit TLS record length field must be split
/// across multiple records without losing or duplicating a single byte.
#[test]
fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation() {
    // u16::MAX + 37 guarantees the payload cannot fit one record.
    let total = (u16::MAX as usize) + 37;
    let payload = vec![0xA5u8; total];
    let record = wrap_tls_application_record(&payload);

    // Re-parse the produced byte stream as a sequence of TLS records.
    let mut offset = 0usize;
    let mut recovered = Vec::with_capacity(total);
    let mut frames = 0usize;

    while offset + 5 <= record.len() {
        assert_eq!(record[offset], TLS_RECORD_APPLICATION);
        assert_eq!(&record[offset + 1..offset + 3], &TLS_VERSION);
        let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize;
        let body_start = offset + 5;
        let body_end = body_start + len;
        assert!(body_end <= record.len(), "declared TLS record length must be in-bounds");
        recovered.extend_from_slice(&record[body_start..body_end]);
        offset = body_end;
        frames += 1;
    }

    assert_eq!(offset, record.len(), "record parser must consume exact output size");
    assert_eq!(frames, 2, "oversized payload should split into exactly two records");
    assert_eq!(recovered, payload, "chunked records must preserve full payload");
}
diff --git a/src/proxy/tests/direct_relay_business_logic_tests.rs b/src/proxy/tests/direct_relay_business_logic_tests.rs
new file mode 100644
index 0000000..37f9897
--- /dev/null
+++ b/src/proxy/tests/direct_relay_business_logic_tests.rs
@@ -0,0 +1,56 @@
use super::*;
use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4, TG_DATACENTERS_V6};
use std::net::SocketAddr;

/// A hint of exactly MAX_SCOPE_HINT_LEN characters must be accepted.
///
/// FIX: the expected value is now derived from MAX_SCOPE_HINT_LEN itself
/// instead of a hard-coded 64-character literal, so the assertion cannot
/// silently diverge if the constant is ever retuned.
#[test]
fn business_scope_hint_accepts_exact_boundary_length() {
    let hint = "a".repeat(MAX_SCOPE_HINT_LEN);
    let value = format!("scope_{hint}");
    assert_eq!(validated_scope_hint(&value), Some(hint.as_str()));
}

/// A hint made only of allowed characters is still rejected without the
/// mandatory "scope_" prefix.
#[test]
fn business_scope_hint_rejects_missing_prefix_even_when_charset_is_valid() {
    assert_eq!(validated_scope_hint("alpha-01"), None);
}

#[test]
fn
business_known_dc_uses_ipv4_table_by_default() { + let cfg = ProxyConfig::default(); + let resolved = get_dc_addr_static(2, &cfg).expect("known dc must resolve"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn business_negative_dc_maps_by_absolute_value() { + let cfg = ProxyConfig::default(); + let resolved = + get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn business_known_dc_uses_ipv6_table_when_preferred_and_enabled() { + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + + let resolved = get_dc_addr_static(1, &cfg).expect("known dc must resolve on ipv6 path"); + let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn business_unknown_dc_uses_configured_default_dc_when_in_range() { + let mut cfg = ProxyConfig::default(); + cfg.default_dc = Some(4); + + let resolved = + get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} diff --git a/src/proxy/tests/direct_relay_common_mistakes_tests.rs b/src/proxy/tests/direct_relay_common_mistakes_tests.rs new file mode 100644 index 0000000..8429449 --- /dev/null +++ b/src/proxy/tests/direct_relay_common_mistakes_tests.rs @@ -0,0 +1,100 @@ +use super::*; +use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4}; +use std::collections::HashSet; +use std::net::SocketAddr; +use std::sync::Mutex; + +#[test] +fn common_invalid_override_entries_fallback_to_static_table() { + let mut cfg = ProxyConfig::default(); + cfg.dc_overrides.insert( + "2".to_string(), + vec!["bad-address".to_string(), 
"still-bad".to_string()], + ); + + let resolved = + get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn common_prefer_v6_with_only_ipv4_override_uses_override_instead_of_ignoring_it() { + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + cfg.dc_overrides + .insert("3".to_string(), vec!["203.0.113.203:443".to_string()]); + + let resolved = + get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists"); + assert_eq!(resolved, "203.0.113.203:443".parse::().unwrap()); +} + +#[test] +fn common_scope_hint_rejects_unicode_lookalike_characters() { + assert_eq!(validated_scope_hint("scope_аlpha"), None); + assert_eq!(validated_scope_hint("scope_Αlpha"), None); +} + +#[cfg(unix)] +#[test] +fn common_anchored_open_rejects_nul_filename() { + use std::os::unix::ffi::OsStringExt; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-direct-relay-nul-{}", std::process::id())); + std::fs::create_dir_all(&parent).expect("parent directory must be creatable"); + + let path = SanitizedUnknownDcLogPath { + resolved_path: parent.join("placeholder.log"), + allowed_parent: parent, + file_name: std::ffi::OsString::from_vec(vec![b'a', 0, b'b']), + }; + + let err = open_unknown_dc_log_append_anchored(&path) + .expect_err("anchored open must fail on NUL in filename"); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput); +} + +#[cfg(unix)] +#[test] +fn common_anchored_open_creates_owner_only_file_permissions() { + use std::os::unix::fs::PermissionsExt; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-direct-relay-perm-{}", std::process::id())); + std::fs::create_dir_all(&parent).expect("parent directory must be 
creatable"); + + let sanitized = SanitizedUnknownDcLogPath { + resolved_path: parent.join("unknown-dc.log"), + allowed_parent: parent.clone(), + file_name: std::ffi::OsString::from("unknown-dc.log"), + }; + + let mut file = open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored open must create regular file"); + use std::io::Write; + writeln!(file, "dc_idx=1").expect("write must succeed"); + + let mode = std::fs::metadata(parent.join("unknown-dc.log")) + .expect("metadata must be readable") + .permissions() + .mode() + & 0o777; + assert_eq!(mode, 0o600); +} + +#[test] +fn common_duplicate_dc_attempts_do_not_consume_unique_slots() { + let set = Mutex::new(HashSet::new()); + + assert!(should_log_unknown_dc_with_set(&set, 100)); + assert!(!should_log_unknown_dc_with_set(&set, 100)); + assert!(should_log_unknown_dc_with_set(&set, 101)); + assert_eq!(set.lock().expect("set lock must be available").len(), 2); +} diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs new file mode 100644 index 0000000..16fe8da --- /dev/null +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -0,0 +1,1910 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::protocol::constants::ProtoTag; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::transport::UpstreamManager; +use std::fs; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; +use tokio::io::AsyncReadExt; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Duration as TokioDuration, timeout}; + +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, 
AesCtr::new(&key, iv))
}

/// Builds a `CryptoWriter` over `writer` using an all-zero key/IV and an
/// 8 KiB internal buffer; deterministic crypto keeps test traffic reproducible.
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    let key = [0u8; 32];
    let iv = 0u128;
    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}

/// Counts the lines of `text` that contain at least one non-whitespace byte.
fn nonempty_line_count(text: &str) -> usize {
    text.lines().filter(|line| !line.trim().is_empty()).count()
}

/// The same unknown dc_idx must be logged only once; a different idx is
/// still admitted.
#[test]
fn unknown_dc_log_is_deduplicated_per_dc_idx() {
    let _guard = unknown_dc_test_lock()
        .lock()
        .expect("unknown dc test lock must be available");
    clear_unknown_dc_log_cache_for_testing();

    assert!(should_log_unknown_dc(777));
    assert!(
        !should_log_unknown_dc(777),
        "same unknown dc_idx must not be logged repeatedly"
    );
    assert!(
        should_log_unknown_dc(778),
        "different unknown dc_idx must still be loggable"
    );
}

/// Once the distinct-entry cap is reached, new unknown dc_idx values are
/// refused so an attacker cannot grow the dedup set without bound.
#[test]
fn unknown_dc_log_respects_distinct_limit() {
    let _guard = unknown_dc_test_lock()
        .lock()
        .expect("unknown dc test lock must be available");
    clear_unknown_dc_log_cache_for_testing();

    for dc in 1..=UNKNOWN_DC_LOG_DISTINCT_LIMIT {
        assert!(
            should_log_unknown_dc(dc as i16),
            "expected first-time unknown dc_idx to be loggable"
        );
    }

    assert!(
        !should_log_unknown_dc(i16::MAX),
        "distinct unknown dc_idx entries above limit must not be logged"
    );
}

/// A poisoned dedup mutex must fail closed: nothing may be admitted for
/// logging once the lock state is no longer trustworthy.
#[test]
fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() {
    let poisoned = Arc::new(std::sync::Mutex::new(
        std::collections::HashSet::<i16>::new(),
    ));
    let poisoned_for_thread = poisoned.clone();

    // Deliberately panic while holding the lock so it becomes poisoned.
    let _ = std::thread::spawn(move || {
        let _guard = poisoned_for_thread
            .lock()
            .expect("poison setup lock must be available");
        panic!("intentional poison for fail-closed regression");
    })
    .join();

    assert!(
        !should_log_unknown_dc_with_set(poisoned.as_ref(), 4242),
        "poisoned unknown-DC dedup lock must fail closed"
    );
}

/// A rejected (unsafe) log path must not burn a dedup slot for the dc_idx:
/// the same idx stays loggable once a safe path is configured.
#[test]
fn unsafe_unknown_dc_log_path_does_not_consume_dedup_slot() {
    let _guard = unknown_dc_test_lock()
        .lock()
        .expect("unknown dc test lock must be available");
    clear_unknown_dc_log_cache_for_testing();

    let dc_idx: i16 = 31_123;
    let mut cfg = ProxyConfig::default();
    cfg.general.unknown_dc_file_log_enabled = true;
    cfg.general.unknown_dc_log_path = Some("../telemt-unknown-dc-unsafe.log".to_string());

    let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work");

    assert!(
        should_log_unknown_dc(dc_idx),
        "rejected unsafe log path must not consume unknown-dc dedup entry"
    );
}

/// Concurrency stress: 16 threads churn 8192 unique dc_idx values; admissions
/// must land exactly on the configured cap, never above it.
#[test]
fn stress_unknown_dc_log_concurrent_unique_churn_respects_cap() {
    let _guard = unknown_dc_test_lock()
        .lock()
        .expect("unknown dc test lock must be available");
    clear_unknown_dc_log_cache_for_testing();

    let accepted = Arc::new(AtomicUsize::new(0));
    let mut workers = Vec::new();

    // Adversarial model: many concurrent peers rotate dc_idx values rapidly.
    for worker in 0..16usize {
        let accepted = Arc::clone(&accepted);
        workers.push(std::thread::spawn(move || {
            // Disjoint 2048-wide ranges per worker keep every probed idx unique.
            let base = (worker * 2048) as i32;
            for offset in 0..512i32 {
                let raw = base + offset;
                let dc = (raw % i16::MAX as i32) as i16;
                if should_log_unknown_dc(dc) {
                    accepted.fetch_add(1, Ordering::Relaxed);
                }
            }
        }));
    }

    for worker in workers {
        worker.join().expect("worker thread must not panic");
    }

    assert_eq!(
        accepted.load(Ordering::Relaxed),
        UNKNOWN_DC_LOG_DISTINCT_LIMIT,
        "concurrent unique churn must never admit more than the configured distinct cap"
    );
}

#[test]
fn light_fuzz_unknown_dc_log_mixed_duplicates_never_exceeds_cap() {
    let _guard = unknown_dc_test_lock()
        .lock()
        .expect("unknown dc test lock must be available");
    clear_unknown_dc_log_cache_for_testing();

    // Deterministic xorshift sequence for reproducible mixed duplicate fuzzing.
+ let mut s: u64 = 0xA5A5_5A5A_C3C3_3C3C; + let mut admitted = 0usize; + + for _ in 0..20_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let dc = (s as i16).wrapping_sub(i16::MAX / 2); + if should_log_unknown_dc(dc) { + admitted += 1; + } + } + + assert!( + admitted <= UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "mixed-duplicate fuzzed inputs must not admit more than cap" + ); +} + +#[test] +fn scope_hint_accepts_ascii_alnum_and_dash_within_limit() { + assert_eq!(validated_scope_hint("scope_alpha-1"), Some("alpha-1")); + assert_eq!(validated_scope_hint("scope_AZ09"), Some("AZ09")); +} + +#[test] +fn scope_hint_rejects_invalid_or_oversized_values() { + assert_eq!(validated_scope_hint("plain_user"), None); + assert_eq!(validated_scope_hint("scope_"), None); + assert_eq!(validated_scope_hint("scope_a/b"), None); + assert_eq!(validated_scope_hint("scope_bad space"), None); + assert_eq!(validated_scope_hint("scope_bad.dot"), None); + + let oversized = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN + 1)); + assert_eq!(validated_scope_hint(&oversized), None); +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_parent_traversal_inputs() { + assert!( + sanitize_unknown_dc_log_path("../unknown-dc.txt").is_none(), + "parent traversal paths must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path("logs/../unknown-dc.txt").is_none(), + "embedded parent traversal must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path("./../unknown-dc.txt").is_none(), + "relative parent traversal must be rejected" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_absolute_paths_with_existing_parent() { + let absolute = std::env::temp_dir().join("unknown-dc.txt"); + let absolute_str = absolute + .to_str() + .expect("temp absolute path must be valid UTF-8"); + + let sanitized = sanitize_unknown_dc_log_path(absolute_str) + .expect("absolute paths with existing parent must be accepted"); + assert_eq!(sanitized.resolved_path, absolute); +} + +#[test] +fn 
unknown_dc_log_path_sanitizer_rejects_absolute_parent_traversal() {
    // Absolute paths may not smuggle "..": canonical escape is still escape.
    assert!(
        sanitize_unknown_dc_log_path("/tmp/../etc/passwd").is_none(),
        "absolute parent traversal must be rejected"
    );
}

/// A plain relative path under the workspace, whose parent already exists,
/// is accepted and resolved against the current directory.
#[test]
fn unknown_dc_log_path_sanitizer_accepts_safe_relative_path() {
    let base = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-unknown-dc-log-{}", std::process::id()));
    fs::create_dir_all(&base).expect("temp test directory must be creatable");

    let candidate = base.join("unknown-dc.txt");
    let candidate_relative = format!(
        "target/telemt-unknown-dc-log-{}/unknown-dc.txt",
        std::process::id()
    );

    let sanitized = sanitize_unknown_dc_log_path(&candidate_relative)
        .expect("safe relative path with existing parent must be accepted");
    assert_eq!(sanitized.resolved_path, candidate);
}

/// Degenerate inputs with no usable filename are refused outright.
#[test]
fn unknown_dc_log_path_sanitizer_rejects_empty_or_dot_only_inputs() {
    assert!(
        sanitize_unknown_dc_log_path("").is_none(),
        "empty path must be rejected"
    );
    assert!(
        sanitize_unknown_dc_log_path(".").is_none(),
        "dot-only path without filename must be rejected"
    );
}

/// Documents current sanitizer behavior: a trailing-slash directory input is
/// projected into a filename rather than rejected.
/// NOTE(review): this pins existing behavior, not necessarily desired behavior.
#[test]
fn unknown_dc_log_path_sanitizer_accepts_directory_only_as_filename_projection() {
    let sanitized = sanitize_unknown_dc_log_path("target/")
        .expect("directory-only input is interpreted as filename projection in current sanitizer");
    assert!(
        sanitized.resolved_path.ends_with("target"),
        "directory-only input should resolve to canonical parent plus filename projection"
    );
}

/// A leading "./" on an otherwise-safe relative path does not change the
/// canonical resolution.
#[test]
fn unknown_dc_log_path_sanitizer_accepts_dot_prefixed_relative_path() {
    let rel_dir = format!("target/telemt-unknown-dc-dot-{}", std::process::id());
    let abs_dir = std::env::current_dir()
        .expect("cwd must be available")
        .join(&rel_dir);
    fs::create_dir_all(&abs_dir).expect("dot-prefixed test directory must be creatable");

    let rel_candidate = format!("./{rel_dir}/unknown-dc.log");
    let expected = abs_dir.join("unknown-dc.log");
    let sanitized = sanitize_unknown_dc_log_path(&rel_candidate)
        .expect("dot-prefixed safe path must be accepted");
    assert_eq!(sanitized.resolved_path, expected);
}

/// Deterministic fuzz: any candidate containing a parent-dir component must
/// be rejected regardless of surrounding segments.
#[test]
fn light_fuzz_unknown_dc_path_parentdir_inputs_always_rejected() {
    let mut s: u64 = 0xD00D_BAAD_1234_5678;
    for _ in 0..4096 {
        // xorshift keeps the sequence reproducible across runs.
        s ^= s << 7;
        s ^= s >> 9;
        s ^= s << 8;
        let a = (s as usize) % 32;
        let b = ((s >> 8) as usize) % 32;
        let candidate = format!("target/{a}/../{b}/unknown-dc.log");
        assert!(
            sanitize_unknown_dc_log_path(&candidate).is_none(),
            "parent-dir candidate must be rejected: {candidate}"
        );
    }
}

/// A missing parent directory is a rejection, not an invitation to create
/// directories implicitly on the logging path.
#[test]
fn unknown_dc_log_path_sanitizer_rejects_nonexistent_parent_directory() {
    let rel_candidate = format!(
        "target/telemt-unknown-dc-missing-{}/nested/unknown-dc.txt",
        std::process::id()
    );

    assert!(
        sanitize_unknown_dc_log_path(&rel_candidate).is_none(),
        "path with missing parent must be rejected to avoid implicit directory creation"
    );
}

/// A symlinked parent that canonicalizes to a location still inside the
/// workspace is acceptable; the sanitized path points at the real parent.
#[cfg(unix)]
#[test]
fn unknown_dc_log_path_sanitizer_accepts_symlinked_parent_inside_workspace() {
    use std::os::unix::fs::symlink;

    let base = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!(
            "telemt-unknown-dc-log-symlink-internal-{}",
            std::process::id()
        ));
    let real_parent = base.join("real_parent");
    fs::create_dir_all(&real_parent).expect("real parent dir must be creatable");

    let symlink_parent = base.join("internal_link");
    let _ = fs::remove_file(&symlink_parent);
    symlink(&real_parent, &symlink_parent).expect("internal symlink must be creatable");

    let rel_candidate = format!(
        "target/telemt-unknown-dc-log-symlink-internal-{}/internal_link/unknown-dc.txt",
        std::process::id()
    );

    let sanitized = sanitize_unknown_dc_log_path(&rel_candidate)
        .expect("symlinked parent that resolves inside workspace must be accepted");
    assert!(
        sanitized.resolved_path.starts_with(&real_parent),
"sanitized path must resolve to canonical internal parent" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_sanitizer_accepts_symlink_parent_escape_as_canonical_path() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-log-symlink-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("symlink test directory must be creatable"); + + let symlink_parent = base.join("escape_link"); + let _ = fs::remove_file(&symlink_parent); + symlink("/tmp", &symlink_parent).expect("symlink parent must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-log-symlink-{}/escape_link/unknown-dc.txt", + std::process::id() + ); + + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("symlinked parent must canonicalize to target path"); + assert!( + sanitized.resolved_path.starts_with(Path::new("/tmp")), + "sanitized path must resolve to canonical symlink target" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_revalidation_rejects_symlinked_target_escape() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-target-link-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("target-link base must be creatable"); + + let outside = std::env::temp_dir().join(format!("telemt-outside-{}", std::process::id())); + let _ = fs::remove_file(&outside); + fs::write(&outside, "outside").expect("outside file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("target symlink must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-target-link-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + 
.expect("candidate should sanitize before final revalidation");

    assert!(
        !unknown_dc_log_path_is_still_safe(&sanitized),
        "final revalidation must reject symlinked target escape"
    );
}

/// Opening a symlink target with O_NOFOLLOW must fail with ELOOP, so the log
/// writer can never be redirected through a link planted at the target path.
#[cfg(unix)]
#[test]
fn unknown_dc_open_append_rejects_symlink_target_with_nofollow() {
    use std::os::unix::fs::symlink;

    let base = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-unknown-dc-nofollow-{}", std::process::id()));
    fs::create_dir_all(&base).expect("nofollow base must be creatable");

    // Victim file outside the workspace that the symlink points at.
    let outside = std::env::temp_dir().join(format!(
        "telemt-unknown-dc-nofollow-outside-{}.log",
        std::process::id()
    ));
    let _ = fs::remove_file(&outside);
    fs::write(&outside, "outside\n").expect("outside file must be writable");

    let linked_target = base.join("unknown-dc.log");
    let _ = fs::remove_file(&linked_target);
    symlink(&outside, &linked_target).expect("symlink target must be creatable");

    let err = open_unknown_dc_log_append(&linked_target)
        .expect_err("O_NOFOLLOW open must fail for symlink target");
    assert_eq!(
        err.raw_os_error(),
        Some(libc::ELOOP),
        "symlink target must be rejected with ELOOP when O_NOFOLLOW is applied"
    );
}

/// A dangling symlink is just as dangerous as a live one: O_NOFOLLOW must
/// still refuse it with ELOOP rather than creating through the link.
#[cfg(unix)]
#[test]
fn unknown_dc_open_append_rejects_broken_symlink_target_with_nofollow() {
    use std::os::unix::fs::symlink;

    let base = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!(
            "telemt-unknown-dc-broken-link-{}",
            std::process::id()
        ));
    fs::create_dir_all(&base).expect("broken-link base must be creatable");

    let linked_target = base.join("unknown-dc.log");
    let _ = fs::remove_file(&linked_target);
    symlink(base.join("missing-target.log"), &linked_target)
        .expect("broken symlink target must be creatable");

    let err = open_unknown_dc_log_append(&linked_target)
        .expect_err("O_NOFOLLOW open must fail for broken symlink target");
    assert_eq!(
        err.raw_os_error(),
        Some(libc::ELOOP),
        "broken symlink target must be rejected with ELOOP when O_NOFOLLOW is applied"
    );
}

/// TOCTOU-style churn: the target alternates between a symlink to an outside
/// file and absence. Whatever races occur, the outside file must stay
/// byte-identical — only real regular files may ever receive appends.
#[cfg(unix)]
#[test]
fn adversarial_unknown_dc_open_append_symlink_flip_never_writes_outside_file() {
    use std::os::unix::fs::symlink;

    let base = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!(
            "telemt-unknown-dc-symlink-flip-{}",
            std::process::id()
        ));
    fs::create_dir_all(&base).expect("symlink-flip base must be creatable");

    let outside = std::env::temp_dir().join(format!(
        "telemt-unknown-dc-symlink-flip-outside-{}.log",
        std::process::id()
    ));
    fs::write(&outside, "outside-baseline\n").expect("outside baseline file must be writable");
    let outside_before = fs::read_to_string(&outside).expect("outside baseline must be readable");

    let target = base.join("unknown-dc.log");
    let _ = fs::remove_file(&target);

    for step in 0..1024usize {
        let _ = fs::remove_file(&target);
        // Even steps plant the malicious symlink; odd steps leave the path free.
        if step % 2 == 0 {
            symlink(&outside, &target).expect("symlink creation in flip loop must succeed");
        }
        if let Ok(mut file) = open_unknown_dc_log_append(&target) {
            writeln!(file, "dc_idx={step}").expect("append on regular file must succeed");
        }
    }

    let outside_after = fs::read_to_string(&outside).expect("outside file must remain readable");
    assert_eq!(
        outside_after, outside_before,
        "outside file must never be modified under symlink-flip adversarial churn"
    );
}

/// Happy path: a fresh regular file is created by the append-open helper and
/// is a plain file, not a symlink artifact.
#[test]
fn unknown_dc_open_append_creates_regular_file() {
    let base = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-unknown-dc-open-{}", std::process::id()));
    fs::create_dir_all(&base).expect("open test base must be creatable");

    let target = base.join("unknown-dc.log");
    let _ = fs::remove_file(&target);

    {
        let mut file = open_unknown_dc_log_append(&target)
            .expect("regular target must be creatable with append open");
        writeln!(file,
"dc_idx=1234").expect("append write must succeed"); + } + + let meta = fs::symlink_metadata(&target).expect("created target metadata must be readable"); + assert!(meta.file_type().is_file(), "target must be a regular file"); + assert!( + !meta.file_type().is_symlink(), + "regular target open path must not produce symlink artifacts" + ); +} + +#[test] +fn stress_unknown_dc_open_append_regular_file_preserves_line_integrity() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-open-stress-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("stress open base must be creatable"); + + let target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&target); + + let writes = 2048usize; + for idx in 0..writes { + let mut file = open_unknown_dc_log_append(&target) + .expect("stress append open on regular file must succeed"); + writeln!(file, "dc_idx={idx}").expect("stress append write must succeed"); + } + + let content = fs::read_to_string(&target).expect("stress output file must be readable"); + assert_eq!( + nonempty_line_count(&content), + writes, + "regular-file append stress must preserve one logical line per write" + ); +} + +#[test] +fn unknown_dc_log_path_revalidation_accepts_regular_existing_target() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-safe-target-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("safe target base must be creatable"); + + let target = base.join("unknown-dc.log"); + fs::write(&target, "seed\n").expect("safe target seed write must succeed"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-safe-target-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("safe candidate must sanitize"); + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must 
allow safe existing regular files" + ); +} + +#[test] +fn unknown_dc_log_path_revalidation_rejects_deleted_parent_after_sanitize() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-vanish-parent-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("vanish-parent base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-vanish-parent-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize before parent deletion"); + + fs::remove_dir_all(&base).expect("test parent directory must be removable"); + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must fail when sanitized parent disappears before write" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() { + use std::os::unix::fs::symlink; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-parent-swap-{}", + std::process::id() + )); + fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-parent-swap-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize before parent swap"); + + let moved = parent.with_extension("bak"); + let _ = fs::remove_dir_all(&moved); + fs::rename(&parent, &moved).expect("parent must be movable for swap simulation"); + symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable"); + + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must fail when canonical parent is swapped to a symlinked target" + ); +} + +#[cfg(unix)] +#[test] +fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() { + use 
std::os::unix::fs::symlink; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-check-open-race-{}", + std::process::id() + )); + fs::create_dir_all(&parent).expect("check-open-race parent must be creatable"); + + let target = parent.join("unknown-dc.log"); + fs::write(&target, "seed\n").expect("seed target file must be writable"); + let rel_candidate = format!( + "target/telemt-unknown-dc-check-open-race-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "precondition: target should initially pass revalidation" + ); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-check-open-race-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "outside\n").expect("outside file must be writable"); + fs::remove_file(&target).expect("target removal before flip must succeed"); + symlink(&outside, &target).expect("target symlink flip must be creatable"); + + let err = open_unknown_dc_log_append(&sanitized.resolved_path) + .expect_err("nofollow open must fail after symlink flip between check and open"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "symlink flip in check/open window must be neutralized by O_NOFOLLOW" + ); +} + +#[cfg(unix)] +#[test] +fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-parent-swap-openat-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-parent-swap-openat-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + 
.expect("candidate must sanitize before parent swap"); + fs::write(&sanitized.resolved_path, "seed\n").expect("seed target file must be writable"); + + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "precondition: target should initially pass revalidation" + ); + + let outside_parent = std::env::temp_dir().join(format!( + "telemt-unknown-dc-parent-swap-openat-outside-{}", + std::process::id() + )); + fs::create_dir_all(&outside_parent).expect("outside parent directory must be creatable"); + let outside_target = outside_parent.join("unknown-dc.log"); + let _ = fs::remove_file(&outside_target); + + let moved = base.with_extension("bak"); + let _ = fs::remove_dir_all(&moved); + fs::rename(&base, &moved).expect("base parent must be movable for swap simulation"); + symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable"); + + let err = open_unknown_dc_log_append_anchored(&sanitized) + .expect_err("anchored open must fail when parent is swapped to symlink"); + let raw = err.raw_os_error(); + assert!( + matches!( + raw, + Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT) + ), + "anchored open must fail closed on parent swap race, got raw_os_error={raw:?}" + ); + assert!( + !outside_target.exists(), + "anchored open must never create a log file in swapped outside parent" + ); +} + +#[cfg(unix)] +#[test] +fn anchored_open_nix_path_writes_expected_lines() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-anchored-open-ok-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("anchored-open-ok base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let _ = fs::remove_file(&sanitized.resolved_path); + + let mut first = 
open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored open must create log file in allowed parent"); + append_unknown_dc_line(&mut first, 31_200).expect("first append must succeed"); + + let mut second = open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored reopen must succeed for existing regular file"); + append_unknown_dc_line(&mut second, 31_201).expect("second append must succeed"); + + let content = + fs::read_to_string(&sanitized.resolved_path).expect("anchored log file must be readable"); + let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + assert_eq!(lines.len(), 2, "expected one line per anchored append call"); + assert!( + lines.contains(&"dc_idx=31200") && lines.contains(&"dc_idx=31201"), + "anchored append output must contain both expected dc_idx lines" + ); +} + +#[cfg(unix)] +#[test] +fn anchored_open_parallel_appends_preserve_line_integrity() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-anchored-open-parallel-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("anchored-open-parallel base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let _ = fs::remove_file(&sanitized.resolved_path); + + let mut workers = Vec::new(); + for idx in 0..64i16 { + let sanitized = sanitized.clone(); + workers.push(std::thread::spawn(move || { + let mut file = open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored open must succeed in worker"); + append_unknown_dc_line(&mut file, 32_000 + idx).expect("worker append must succeed"); + })); + } + + for worker in workers { + worker.join().expect("worker must not panic"); + } + + let content = + 
fs::read_to_string(&sanitized.resolved_path).expect("parallel log file must be readable"); + let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + assert_eq!(lines.len(), 64, "expected one complete line per worker append"); + for line in lines { + assert!( + line.starts_with("dc_idx="), + "line must keep dc_idx prefix and not be interleaved: {line}" + ); + let value = line + .strip_prefix("dc_idx=") + .expect("prefix checked above") + .parse::<i16>(); + assert!( + value.is_ok(), + "line payload must remain parseable i16 and not be corrupted: {line}" + ); + } +} + +#[cfg(unix)] +#[test] +fn anchored_open_creates_private_0600_file_permissions() { + use std::os::unix::fs::PermissionsExt; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-anchored-perms-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("anchored-perms base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let _ = fs::remove_file(&sanitized.resolved_path); + + let mut file = open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored open must create file with restricted mode"); + append_unknown_dc_line(&mut file, 31_210).expect("initial append must succeed"); + drop(file); + + let mode = fs::metadata(&sanitized.resolved_path) + .expect("created log file metadata must be readable") + .permissions() + .mode() + & 0o777; + assert_eq!( + mode, 0o600, + "anchored open must create unknown-dc log file with owner-only rw permissions" + ); +} + +#[cfg(unix)] +#[test] +fn anchored_open_rejects_existing_symlink_target() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + 
"telemt-unknown-dc-anchored-symlink-target-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("anchored-symlink-target base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-anchored-symlink-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "outside\n").expect("outside baseline file must be writable"); + + let _ = fs::remove_file(&sanitized.resolved_path); + symlink(&outside, &sanitized.resolved_path) + .expect("target symlink for anchored-open rejection test must be creatable"); + + let err = open_unknown_dc_log_append_anchored(&sanitized) + .expect_err("anchored open must reject symlinked filename target"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "anchored open should fail closed with ELOOP on symlinked target" + ); +} + +#[cfg(unix)] +#[test] +fn anchored_open_high_contention_multi_write_preserves_complete_lines() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-anchored-contention-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("anchored-contention base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let _ = fs::remove_file(&sanitized.resolved_path); + + let workers = 24usize; + let rounds = 40usize; + let mut threads = Vec::new(); + + for worker in 0..workers { + let sanitized = sanitized.clone(); + threads.push(std::thread::spawn(move || { + for round in 0..rounds { + let mut file = open_unknown_dc_log_append_anchored(&sanitized) + 
.expect("anchored open must succeed under contention"); + let dc_idx = 20_000i16.wrapping_add((worker * rounds + round) as i16); + append_unknown_dc_line(&mut file, dc_idx) + .expect("each contention append must complete"); + } + })); + } + + for thread in threads { + thread.join().expect("contention worker must not panic"); + } + + let content = fs::read_to_string(&sanitized.resolved_path) + .expect("contention output file must be readable"); + let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + assert_eq!( + lines.len(), + workers * rounds, + "every contention append must produce exactly one line" + ); + + let mut unique = std::collections::HashSet::new(); + for line in lines { + assert!( + line.starts_with("dc_idx="), + "line must preserve expected prefix under heavy contention: {line}" + ); + let value = line + .strip_prefix("dc_idx=") + .expect("prefix validated") + .parse::<i16>() + .expect("line payload must remain parseable i16 under contention"); + unique.insert(value); + } + + assert_eq!( + unique.len(), + workers * rounds, + "contention output must not lose or duplicate logical writes" + ); +} + +#[cfg(unix)] +#[test] +fn append_unknown_dc_line_returns_error_for_read_only_descriptor() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-append-ro-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("append-ro base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-append-ro-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + fs::write(&sanitized.resolved_path, "seed\n").expect("seed file must be writable"); + + let mut readonly = std::fs::OpenOptions::new() + .read(true) + .open(&sanitized.resolved_path) + .expect("readonly file open must succeed"); + + append_unknown_dc_line(&mut readonly, 31_222) + 
.expect_err("append on readonly descriptor must fail closed"); + + let content_after = + fs::read_to_string(&sanitized.resolved_path).expect("seed file must remain readable"); + assert_eq!( + nonempty_line_count(&content_after), + 1, + "failed readonly append must not modify persisted unknown-dc log content" + ); +} + +#[tokio::test] +async fn unknown_dc_absolute_log_path_writes_one_entry() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_001; + let file_path = std::env::temp_dir().join(format!( + "telemt-unknown-dc-abs-{}-{}.log", + std::process::id(), + dc_idx + )); + let _ = fs::remove_file(&file_path); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some( + file_path + .to_str() + .expect("temp file path must be valid UTF-8") + .to_string(), + ); + + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + + let mut content = None; + for _ in 0..20 { + if let Ok(text) = fs::read_to_string(&file_path) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(15)).await; + } + + let text = content.expect("absolute unknown-DC log path must produce exactly one log write"); + assert!( + text.contains(&format!("dc_idx={dc_idx}")), + "absolute unknown-DC integration log must contain requested dc_idx" + ); +} + +#[tokio::test] +async fn unknown_dc_safe_relative_log_path_writes_one_entry() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_002; + let rel_dir = format!("target/telemt-unknown-dc-int-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + 
fs::create_dir_all(&abs_dir).expect("integration test log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + + let mut content = None; + for _ in 0..20 { + if let Ok(text) = fs::read_to_string(&abs_file) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(15)).await; + } + + let text = content.expect("safe relative path must produce exactly one log write"); + assert!( + text.contains(&format!("dc_idx={dc_idx}")), + "unknown-DC integration log must contain requested dc_idx" + ); +} + +#[tokio::test] +async fn unknown_dc_same_index_burst_writes_only_once() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_010; + let rel_dir = format!("target/telemt-unknown-dc-same-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir().unwrap().join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("same-index log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + for _ in 0..64 { + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + } + + let mut content = None; + for _ in 0..30 { + if let Ok(text) = fs::read_to_string(&abs_file) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let text = content.expect("same-index burst must produce at least one log write"); + assert_eq!( + 
nonempty_line_count(&text), + 1, + "same unknown dc index must be deduplicated to one file line" + ); +} + +#[tokio::test] +async fn unknown_dc_distinct_burst_is_hard_capped_on_file_writes() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let rel_dir = format!("target/telemt-unknown-dc-cap-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir().unwrap().join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("cap log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + for i in 0..(UNKNOWN_DC_LOG_DISTINCT_LIMIT + 128) { + let dc_idx = 20_000i16.wrapping_add(i as i16); + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + } + + let mut final_text = String::new(); + for _ in 0..80 { + if let Ok(text) = fs::read_to_string(&abs_file) { + final_text = text; + if nonempty_line_count(&final_text) >= UNKNOWN_DC_LOG_DISTINCT_LIMIT { + break; + } + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let line_count = nonempty_line_count(&final_text); + assert!( + line_count > 0, + "distinct unknown-dc burst must write at least one line" + ); + assert!( + line_count <= UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "distinct unknown-dc writes must stay within dedup hard cap" + ); +} + +#[cfg(unix)] +#[tokio::test] +async fn unknown_dc_symlinked_target_escape_is_not_written_integration() { + use std::os::unix::fs::symlink; + + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( 
+ "telemt-unknown-dc-no-write-link-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("integration symlink base must be creatable"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "baseline\n").expect("outside baseline file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("symlink target must be creatable"); + + let rel_file = format!( + "target/telemt-unknown-dc-no-write-link-{}/unknown-dc.log", + std::process::id() + ); + let dc_idx: i16 = 31_050; + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let before = fs::read_to_string(&outside).expect("must read baseline outside file"); + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + tokio::time::sleep(Duration::from_millis(80)).await; + let after = fs::read_to_string(&outside).expect("must read outside file after attempt"); + + assert_eq!( + after, before, + "symlink target escape must not be written by unknown-DC logging" + ); +} + +#[test] +fn fallback_dc_never_panics_with_single_dc_list() { + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + cfg.default_dc = Some(42); + + let addr = get_dc_addr_static(999, &cfg).expect("fallback dc must resolve safely"); + let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT); + assert_eq!(addr, expected); +} + +#[tokio::test] +async fn direct_relay_abort_midflight_releases_route_gauge() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + 
tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "abort-direct-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50000".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_snapshot, + 0xabad1dea, + )); + + let started = tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await; + assert!( + started.is_ok(), + "direct relay must increment route gauge before abort" + ); + + relay_task.abort(); + let joined = relay_task.await; + assert!( + joined.is_err(), + "aborted direct relay task must return join error" + ); + + 
tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_direct(), + 0, + "route gauge must be released when direct relay task is aborted mid-flight" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn direct_relay_cutover_midflight_releases_route_gauge() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "cutover-direct-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50002".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_direct( + client_reader, + 
client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_snapshot, + 0xface_cafe, + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("direct relay must increment route gauge before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Middle).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task) + .await + .expect("direct relay must terminate after cutover") + .expect("direct relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover should terminate direct relay session" + ); + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "client-visible cutover error must stay generic and avoid route-internal metadata" + ); + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "route gauge must be released when direct relay exits on cutover" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { + let session_count = 6usize; + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let mut held_streams = Vec::with_capacity(session_count); + for _ in 0..session_count { + let (stream, _) = tg_listener.accept().await.unwrap(); + held_streams.push(stream); + } + tokio::time::sleep(Duration::from_secs(60)).await; + drop(held_streams); + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + 
let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let mut relay_tasks = Vec::with_capacity(session_count); + let mut client_sides = Vec::with_capacity(session_count); + + for idx in 0..session_count { + let (server_side, client_side) = duplex(64 * 1024); + client_sides.push(client_side); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: format!("cutover-storm-direct-user-{idx}"), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + 51000 + idx as u16, + ), + is_tls: false, + }; + + relay_tasks.push(tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + rng.clone(), + route_runtime.subscribe(), + route_snapshot, + 0xA000_0000 + idx as u64, + ))); + } + + tokio::time::timeout(Duration::from_secs(4), async { + loop { + if stats.get_current_connections_direct() == session_count as u64 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all direct sessions must become active before cutover storm"); + + let route_runtime_flipper = route_runtime.clone(); + let flipper = 
tokio::spawn(async move { + for step in 0..64u32 { + let mode = if (step & 1) == 0 { + RelayRouteMode::Middle + } else { + RelayRouteMode::Direct + }; + let _ = route_runtime_flipper.set_mode(mode); + tokio::time::sleep(Duration::from_millis(15)).await; + } + }); + + for relay_task in relay_tasks { + let relay_result = tokio::time::timeout(Duration::from_secs(10), relay_task) + .await + .expect("direct relay task must finish under cutover storm") + .expect("direct relay task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "storm-cutover termination must remain generic for all direct sessions" + ); + } + + flipper.abort(); + let _ = flipper.await; + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route gauge must return to zero after cutover storm" + ); + + drop(client_sides); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[test] +fn prefer_v6_override_matrix_prefers_matching_family_then_degrades_safely() { + let dc_idx: i16 = 2; + + let mut cfg_a = ProxyConfig::default(); + cfg_a.network.prefer = 6; + cfg_a.network.ipv6 = Some(true); + cfg_a.dc_overrides.insert( + dc_idx.to_string(), + vec![ + "203.0.113.90:443".to_string(), + "[2001:db8::90]:443".to_string(), + ], + ); + let a = get_dc_addr_static(dc_idx, &cfg_a).expect("v6+v4 override set must resolve"); + assert!( + a.is_ipv6(), + "prefer_v6 should choose v6 override when present" + ); + + let mut cfg_b = ProxyConfig::default(); + cfg_b.network.prefer = 6; + cfg_b.network.ipv6 = Some(true); + cfg_b + .dc_overrides + .insert(dc_idx.to_string(), vec!["203.0.113.91:443".to_string()]); + let b = get_dc_addr_static(dc_idx, &cfg_b).expect("v4-only override must still resolve"); + assert!( + b.is_ipv4(), + "when no v6 override exists, v4 override must be used" + ); + + let mut cfg_c = ProxyConfig::default(); + cfg_c.network.prefer = 6; + cfg_c.network.ipv6 = Some(true); + let c = 
get_dc_addr_static(dc_idx, &cfg_c).expect("table fallback must resolve");
+    assert_eq!(
+        c,
+        SocketAddr::new(TG_DATACENTERS_V6[(dc_idx as usize) - 1], TG_DATACENTER_PORT),
+        "without overrides, prefer_v6 path must resolve from static v6 datacenter table"
+    );
+}
+
+#[test]
+fn prefer_v6_override_matrix_ignores_invalid_entries_and_keeps_fail_closed_fallback() {
+    let dc_idx: i16 = 3;
+
+    let mut cfg = ProxyConfig::default();
+    cfg.network.prefer = 6;
+    cfg.network.ipv6 = Some(true);
+    cfg.dc_overrides.insert(
+        dc_idx.to_string(),
+        vec![
+            "not-an-addr".to_string(),
+            "also:bad".to_string(),
+            "203.0.113.55:443".to_string(),
+        ],
+    );
+
+    let addr = get_dc_addr_static(dc_idx, &cfg)
+        .expect("at least one valid override must keep resolution alive");
+    assert_eq!(addr, "203.0.113.55:443".parse::<SocketAddr>().unwrap());
+}
+
+#[test]
+fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() {
+    for idx in 1..=5i16 {
+        let mut cfg = ProxyConfig::default();
+        cfg.network.prefer = 6;
+        cfg.network.ipv6 = Some(true);
+        cfg.dc_overrides.insert(
+            idx.to_string(),
+            vec![
+                format!("203.0.113.{}:443", 100 + idx),
+                format!("[2001:db8::{}]:443", 100 + idx),
+            ],
+        );
+
+        let first = get_dc_addr_static(idx, &cfg).expect("first lookup must resolve");
+        let second = get_dc_addr_static(idx, &cfg).expect("second lookup must resolve");
+        assert_eq!(
+            first, second,
+            "override resolution must stay deterministic for dc {idx}"
+        );
+        assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred");
+    }
+}
+
+#[tokio::test]
+async fn negative_direct_relay_dc_connection_refused_fails_fast() {
+    let (client_reader_side, _client_writer_side) = duplex(1024);
+    let (_client_reader_relay, client_writer_side) = duplex(1024);
+
+    let key = [0u8; 32];
+    let iv = 0u128;
+    let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
+    let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024);
+
+    let stats = 
Arc::new(Stats::new()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + + // Reserve an ephemeral port and immediately release it to deterministically + // exercise the direct-connect failure path without long-lived hangs. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dc_addr = listener.local_addr().unwrap(); + drop(listener); + + let mut config_with_override = ProxyConfig::default(); + config_with_override + .dc_overrides + .insert("1".to_string(), vec![dc_addr.to_string()]); + let config = Arc::new(config_with_override); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + enabled: true, + weight: 1, + scopes: String::new(), + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + selected_scope: String::new(), + }], + 1, + 100, + 5000, + 3, + false, + stats.clone(), + )); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let result = timeout( + TokioDuration::from_secs(2), + handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_runtime.snapshot(), + 0xABCD_1234, + ), + ) + .await + .expect("direct relay must fail fast on connection-refused upstream"); + + assert!( + result.is_err(), + "connection-refused upstream must fail closed" + ); +} + +#[tokio::test] +async fn adversarial_direct_relay_cutover_integrity() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_client_reader_relay, client_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let client_reader = 
CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + + // Mock upstream server. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dc_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + // Read handshake nonce. + let mut nonce = [0u8; 64]; + let _ = stream.read_exact(&mut nonce).await; + // Keep connection open. + tokio::time::sleep(TokioDuration::from_secs(5)).await; + }); + + let mut config_with_override = ProxyConfig::default(); + config_with_override + .dc_overrides + .insert("1".to_string(), vec![dc_addr.to_string()]); + let config = Arc::new(config_with_override); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + enabled: true, + weight: 1, + scopes: String::new(), + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + selected_scope: String::new(), + }], + 1, + 100, + 5000, + 3, + false, + stats.clone(), + )); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let stats_for_task = stats.clone(); + let runtime_clone = route_runtime.clone(); + let session_task = tokio::spawn(async move { + handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats_for_task, + config, + buffer_pool, + rng, + runtime_clone.subscribe(), + runtime_clone.snapshot(), + 0xABCD_1234, + ) + .await + }); + + timeout(TokioDuration::from_secs(2), async { + loop { + if 
stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("direct relay session must start before cutover"); + + // Trigger cutover. + route_runtime.set_mode(RelayRouteMode::Middle).unwrap(); + + // The session should terminate after the staggered delay (1000-2000ms). + let result = timeout(TokioDuration::from_secs(5), session_task) + .await + .expect("Session must terminate after cutover") + .expect("Session must not panic"); + + assert!( + matches!( + result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "Session must terminate with route switch error on cutover" + ); +} diff --git a/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs b/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs new file mode 100644 index 0000000..325cffd --- /dev/null +++ b/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs @@ -0,0 +1,200 @@ +use super::*; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +fn nonempty_line_count(text: &str) -> usize { + text.lines().filter(|line| !line.trim().is_empty()).count() +} + +#[test] +fn subtle_stress_single_unknown_dc_under_concurrency_logs_once() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let winners = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + + for _ in 0..128 { + let winners = Arc::clone(&winners); + workers.push(std::thread::spawn(move || { + if should_log_unknown_dc(31_333) { + winners.fetch_add(1, Ordering::Relaxed); + } + })); + } + + for worker in workers { + worker.join().expect("worker must not panic"); + } + + assert_eq!(winners.load(Ordering::Relaxed), 1); +} + +#[test] +fn subtle_light_fuzz_scope_hint_matches_oracle() { + fn oracle(input: &str) -> bool { + let Some(rest) = input.strip_prefix("scope_") else { + return false; + }; + !rest.is_empty() + 
&& rest.len() <= MAX_SCOPE_HINT_LEN + && rest.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-') + } + + let mut state: u64 = 0xC0FF_EE11_D15C_AFE5; + for _ in 0..4_096 { + state ^= state << 7; + state ^= state >> 9; + state ^= state << 8; + + let len = (state as usize % 72) + 1; + let mut s = String::with_capacity(len + 6); + if (state & 1) == 0 { + s.push_str("scope_"); + } else { + s.push_str("user_"); + } + + for idx in 0..len { + let v = ((state >> ((idx % 8) * 8)) & 0xff) as u8; + let ch = match v % 6 { + 0 => (b'a' + (v % 26)) as char, + 1 => (b'A' + (v % 26)) as char, + 2 => (b'0' + (v % 10)) as char, + 3 => '-', + 4 => '_', + _ => '.', + }; + s.push(ch); + } + + let got = validated_scope_hint(&s).is_some(); + assert_eq!(got, oracle(&s), "mismatch for input: {s}"); + } +} + +#[test] +fn subtle_light_fuzz_dc_resolution_never_panics_and_preserves_port() { + let mut state: u64 = 0x1234_5678_9ABC_DEF0; + + for _ in 0..2_048 { + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = if (state & 1) == 0 { 4 } else { 6 }; + cfg.network.ipv6 = Some((state & 2) != 0); + cfg.default_dc = Some(((state >> 8) as u8).max(1)); + + let dc_idx = (state as i16).wrapping_sub(16_384); + let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail"); + + assert_eq!( + resolved.port(), + crate::protocol::constants::TG_DATACENTER_PORT + ); + let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true); + assert_eq!(resolved.is_ipv6(), expect_v6); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn subtle_integration_parallel_same_dc_logs_one_line() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let rel_dir = format!("target/telemt-direct-relay-same-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + 
let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = std::fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let cfg = Arc::new(cfg); + let mut tasks = Vec::new(); + for _ in 0..32 { + let cfg = Arc::clone(&cfg); + tasks.push(tokio::spawn(async move { + let _ = get_dc_addr_static(31_777, cfg.as_ref()); + })); + } + for task in tasks { + task.await.expect("task must not panic"); + } + + for _ in 0..60 { + if let Ok(content) = std::fs::read_to_string(&abs_file) + && nonempty_line_count(&content) == 1 + { + return; + } + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + } + + let content = std::fs::read_to_string(&abs_file).unwrap_or_default(); + assert_eq!(nonempty_line_count(&content), 1); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn subtle_integration_parallel_unique_dcs_log_unique_lines() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let rel_dir = format!("target/telemt-direct-relay-unique-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = std::fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let cfg = Arc::new(cfg); + let dcs = [ + 31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908, + ]; + let mut tasks = Vec::new(); + + 
for dc in dcs { + let cfg = Arc::clone(&cfg); + tasks.push(tokio::spawn(async move { + let _ = get_dc_addr_static(dc, cfg.as_ref()); + })); + } + + for task in tasks { + task.await.expect("task must not panic"); + } + + for _ in 0..80 { + if let Ok(content) = std::fs::read_to_string(&abs_file) + && nonempty_line_count(&content) >= 8 + { + return; + } + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + } + + let content = std::fs::read_to_string(&abs_file).unwrap_or_default(); + assert!( + nonempty_line_count(&content) >= 8, + "expected at least one line per unique dc, content: {content}" + ); +} diff --git a/src/proxy/tests/handshake_adversarial_tests.rs b/src/proxy/tests/handshake_adversarial_tests.rs new file mode 100644 index 0000000..93832f7 --- /dev/null +++ b/src/proxy/tests/handshake_adversarial_tests.rs @@ -0,0 +1,563 @@ +use super::*; +use crate::crypto::sha256; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut 
target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg +} + +// ------------------------------------------------------------------ +// Mutational Bit-Flipping Tests (OWASP ASVS 5.1.4) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn mtproto_handshake_bit_flip_anywhere_rejected() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "11223344556677889900aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + // Baseline check + let res = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + match res { + HandshakeResult::Success(_) => {} + _ => panic!("Baseline failed: expected Success"), + } + + // Flip bits in the encrypted part (beyond the key material) + for byte_pos in SKIP_LEN..HANDSHAKE_LEN { + let mut h = base; + h[byte_pos] ^= 0x01; // Flip 1 bit + let res = handle_mtproto_handshake( + &h, + tokio::io::empty(), + 
tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "Flip at byte {byte_pos} bit 0 must be rejected" + ); + } +} + +// ------------------------------------------------------------------ +// Adversarial Probing / Timing Neutrality (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn mtproto_handshake_timing_neutrality_mocked() { + let secret_hex = "00112233445566778899aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap(); + + const ITER: usize = 50; + + let mut start = Instant::now(); + for _ in 0..ITER { + let _ = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + } + let duration_success = start.elapsed(); + + start = Instant::now(); + for i in 0..ITER { + let mut h = base; + h[SKIP_LEN + (i % 48)] ^= 0xFF; + let _ = handle_mtproto_handshake( + &h, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + } + let duration_fail = start.elapsed(); + + let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64) + .abs() + / ITER as f64; + + // Threshold (loose for CI) + assert!( + avg_diff_ms < 100.0, + "Timing difference too large: {} ms/iter", + avg_diff_ms + ); +} + +// ------------------------------------------------------------------ +// Stress Tests (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn auth_probe_throttle_saturation_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let 
now = Instant::now(); + + // Record enough failures for one IP to trigger backoff + let target_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)); + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(target_ip, now); + } + + assert!(auth_probe_is_throttled(target_ip, now)); + + // Stress test with many unique IPs + for i in 0..500u32 { + let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, (i % 256) as u8)); + auth_probe_record_failure(ip, now); + } + + let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}" + ); +} + +#[tokio::test] +async fn mtproto_handshake_abridged_prefix_rejected() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + handshake[0] = 0xef; // Abridged prefix + let config = ProxyConfig::default(); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); + + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + // MTProxy stops immediately on 0xef + assert!(matches!(res, HandshakeResult::BadClient { .. 
})); +} + +#[tokio::test] +async fn mtproto_handshake_preferred_user_mismatch_continues() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret1_hex = "11111111111111111111111111111111"; + let secret2_hex = "22222222222222222222222222222222"; + + let base = make_valid_mtproto_handshake(secret2_hex, ProtoTag::Secure, 1); + let mut config = ProxyConfig::default(); + config + .access + .users + .insert("user1".to_string(), secret1_hex.to_string()); + config + .access + .users + .insert("user2".to_string(), secret2_hex.to_string()); + config.access.ignore_time_skew = true; + config.general.modes.secure = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.4:12345".parse().unwrap(); + + // Even if we prefer user1, if user2 matches, it should succeed. + let res = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + Some("user1"), + ) + .await; + if let HandshakeResult::Success((_, _, success)) = res { + assert_eq!(success.user, "user2"); + } else { + panic!("Handshake failed even though user2 matched"); + } +} + +#[tokio::test] +async fn mtproto_handshake_concurrent_flood_stability() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "00112233445566778899aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let mut config = test_config_with_secret_hex(secret_hex); + config.access.ignore_time_skew = true; + let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); + let config = Arc::new(config); + + let mut tasks = Vec::new(); + for i in 0..50 { + let base = base; + let config = Arc::clone(&config); + let replay_checker = Arc::clone(&replay_checker); + let peer: SocketAddr = format!("192.0.2.{}:12345", (i % 254) + 1).parse().unwrap(); + + tasks.push(tokio::spawn(async move { + let 
res = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + matches!(res, HandshakeResult::Success(_)) + })); + } + + // We don't necessarily care if they all succeed (some might fail due to replay if they hit the same chunk), + // but the system must not panic or hang. + for task in tasks { + let _ = task.await.unwrap(); + } +} + +#[tokio::test] +async fn mtproto_replay_is_rejected_across_distinct_peers() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "0123456789abcdeffedcba9876543210"; + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + + let first_peer: SocketAddr = "198.51.100.10:41001".parse().unwrap(); + let second_peer: SocketAddr = "198.51.100.11:41002".parse().unwrap(); + + let first = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + first_peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let replay = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + second_peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(replay, HandshakeResult::BadClient { .. 
})); +} + +#[tokio::test] +async fn mtproto_blackhat_mutation_corpus_never_panics_and_stays_fail_closed() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "89abcdef012345670123456789abcdef"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(8192, Duration::from_secs(60)); + + for i in 0..512usize { + let mut mutated = base; + let pos = (SKIP_LEN + (i * 31) % (HANDSHAKE_LEN - SKIP_LEN)).min(HANDSHAKE_LEN - 1); + mutated[pos] ^= ((i as u8) | 1).rotate_left((i % 8) as u32); + let peer: SocketAddr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 18, (i / 254) as u8, (i % 254 + 1) as u8)), + 42000 + (i % 1000) as u16, + ); + + let res = tokio::time::timeout( + Duration::from_millis(250), + handle_mtproto_handshake( + &mutated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("fuzzed mutation must complete in bounded time"); + + assert!( + matches!( + res, + HandshakeResult::BadClient { .. 
} | HandshakeResult::Success(_) + ), + "mutation corpus must stay within explicit handshake outcomes" + ); + } +} + +#[tokio::test] +async fn auth_probe_success_clears_throttled_peer_state() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let target_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 90)); + let now = Instant::now(); + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(target_ip, now); + } + assert!(auth_probe_is_throttled(target_ip, now)); + + auth_probe_record_success(target_ip); + assert!( + !auth_probe_is_throttled(target_ip, now + Duration::from_millis(1)), + "successful auth must clear per-peer throttle state" + ); +} + +#[tokio::test] +async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "00112233445566778899aabbccddeeff"; + let mut invalid = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + invalid[SKIP_LEN + 3] ^= 0xff; + + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(64, Duration::from_secs(60)); + + for i in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 512) { + let peer: SocketAddr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new( + 10, + (i / 65535) as u8, + ((i / 255) % 255) as u8, + (i % 255 + 1) as u8, + )), + 43000 + (i % 20000) as u16, + ); + let res = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(res, HandshakeResult::BadClient { .. 
})); + } + + let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "probe map must remain bounded under invalid storm: {tracked}" + ); +} + +#[tokio::test] +async fn mtproto_property_style_multi_bit_mutations_fail_closed_or_auth_only() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "f0e1d2c3b4a5968778695a4b3c2d1e0f"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(10_000, Duration::from_secs(60)); + + let mut seed: u64 = 0xC0FF_EE12_3456_789A; + for i in 0..2_048usize { + let mut mutated = base; + for _ in 0..4 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let idx = SKIP_LEN + (seed as usize % (HANDSHAKE_LEN - SKIP_LEN)); + mutated[idx] ^= ((seed >> 11) as u8).wrapping_add(1); + } + + let peer: SocketAddr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 123, (i / 254) as u8, (i % 254 + 1) as u8)), + 45000 + (i % 2000) as u16, + ); + + let outcome = tokio::time::timeout( + Duration::from_millis(250), + handle_mtproto_handshake( + &mutated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("mutation iteration must complete in bounded time"); + + assert!( + matches!( + outcome, + HandshakeResult::BadClient { .. 
} | HandshakeResult::Success(_) + ), + "mutations must remain fail-closed/auth-only" + ); + } +} + +#[tokio::test] +#[ignore = "heavy soak; run manually"] +async fn mtproto_blackhat_20k_mutation_soak_never_panics() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(50_000, Duration::from_secs(120)); + + let mut seed: u64 = 0xA5A5_5A5A_DEAD_BEEF; + for i in 0..20_000usize { + let mut mutated = base; + for _ in 0..3 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let idx = SKIP_LEN + (seed as usize % (HANDSHAKE_LEN - SKIP_LEN)); + mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1); + } + + let peer: SocketAddr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(172, 31, (i / 254) as u8, (i % 254 + 1) as u8)), + 47000 + (i % 15000) as u16, + ); + + let _ = tokio::time::timeout( + Duration::from_millis(250), + handle_mtproto_handshake( + &mutated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("soak mutation must complete in bounded time"); + } +} diff --git a/src/proxy/tests/handshake_auth_probe_hardening_adversarial_tests.rs b/src/proxy/tests/handshake_auth_probe_hardening_adversarial_tests.rs new file mode 100644 index 0000000..d8fac4f --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_hardening_adversarial_tests.rs @@ -0,0 +1,187 @@ +use super::*; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn positive_preauth_throttle_activates_after_failure_threshold() { + let _guard = auth_probe_test_guard(); + 
clear_auth_probe_state_for_testing(); + + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 20)); + let now = Instant::now(); + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(ip, now); + } + + assert!( + auth_probe_is_throttled(ip, now), + "peer must be throttled once fail streak reaches threshold" + ); +} + +#[test] +fn negative_unrelated_peer_remains_unthrottled() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let attacker = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 12)); + let benign = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 13)); + let now = Instant::now(); + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(attacker, now); + } + + assert!(auth_probe_is_throttled(attacker, now)); + assert!( + !auth_probe_is_throttled(benign, now), + "throttle state must stay scoped to normalized peer key" + ); +} + +#[test] +fn edge_expired_entry_is_pruned_and_no_longer_throttled() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 41)); + let base = Instant::now(); + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(ip, base); + } + + let expired_at = base + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1); + assert!( + !auth_probe_is_throttled(ip, expired_at), + "expired entries must not keep throttling peers" + ); + + let state = auth_probe_state_map(); + assert!( + state.get(&normalize_auth_probe_ip(ip)).is_none(), + "expired lookup should prune stale state" + ); +} + +#[test] +fn adversarial_saturation_grace_requires_extra_failures_before_preauth_throttle() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip = IpAddr::V4(Ipv4Addr::new(198, 18, 0, 7)); + let now = Instant::now(); + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(ip, now); + } + auth_probe_note_saturation(now); + + assert!( + 
!auth_probe_should_apply_preauth_throttle(ip, now), + "during global saturation, peer must receive configured grace window" + ); + + for _ in 0..AUTH_PROBE_SATURATION_GRACE_FAILS { + auth_probe_record_failure(ip, now + Duration::from_millis(1)); + } + + assert!( + auth_probe_should_apply_preauth_throttle(ip, now + Duration::from_millis(1)), + "after grace failures are exhausted, preauth throttle must activate" + ); +} + +#[test] +fn integration_over_cap_insertion_keeps_probe_map_bounded() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 1024) { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + ((idx / 65_536) % 256) as u8, + ((idx / 256) % 256) as u8, + (idx % 256) as u8, + )); + auth_probe_record_failure(ip, now); + } + + let tracked = auth_probe_state_map().len(); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "probe map must remain hard bounded under insertion storm" + ); +} + +#[test] +fn light_fuzz_randomized_failures_preserve_cap_and_nonzero_streaks() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let mut seed = 0x4D53_5854_6F66_6175u64; + let now = Instant::now(); + + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + auth_probe_record_failure(ip, now + Duration::from_millis((seed & 0x3f) as u64)); + } + + let state = auth_probe_state_map(); + assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES); + for entry in state.iter() { + assert!(entry.value().fail_streak > 0); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_failure_flood_keeps_state_hard_capped() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let start = Instant::now(); + let mut tasks = Vec::new(); + + for 
worker in 0..8u8 { + tasks.push(tokio::spawn(async move { + for i in 0..4096u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + worker, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + auth_probe_record_failure(ip, start + Duration::from_millis((i % 4) as u64)); + } + })); + } + + for task in tasks { + task.await.expect("stress worker must not panic"); + } + + let tracked = auth_probe_state_map().len(); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "parallel failure flood must not exceed cap" + ); + + let probe = IpAddr::V4(Ipv4Addr::new(172, 3, 4, 5)); + let _ = auth_probe_is_throttled(probe, start + Duration::from_millis(2)); +} diff --git a/src/proxy/tests/handshake_fuzz_security_tests.rs b/src/proxy/tests/handshake_fuzz_security_tests.rs new file mode 100644 index 0000000..efb596b --- /dev/null +++ b/src/proxy/tests/handshake_fuzz_security_tests.rs @@ -0,0 +1,278 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::crypto::AesCtr; +use crate::crypto::sha256; +use crate::protocol::constants::ProtoTag; +use crate::stats::ReplayChecker; +use std::net::SocketAddr; +use std::sync::MutexGuard; +use tokio::time::{Duration as TokioDuration, timeout}; + +fn make_mtproto_handshake_with_proto_bytes( + secret_hex: &str, + proto_bytes: [u8; 4], + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + 
dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_bytes); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg +} + +fn auth_probe_test_guard() -> MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[tokio::test] +async fn mtproto_handshake_duplicate_digest_is_replayed_on_second_attempt() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "11223344556677889900aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + let first = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let second = handle_mtproto_handshake( + &base, + tokio::io::empty(), + 
tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(second, HandshakeResult::BadClient { .. })); + + clear_auth_probe_state_for_testing(); +} + +#[tokio::test] +async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "00112233445566778899aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap(); + + let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new(); + + corpus.push(make_mtproto_handshake_with_proto_bytes( + secret_hex, + [0x00, 0x00, 0x00, 0x00], + 1, + )); + corpus.push(make_mtproto_handshake_with_proto_bytes( + secret_hex, + [0xff, 0xff, 0xff, 0xff], + 1, + )); + corpus.push(make_valid_mtproto_handshake( + "ffeeddccbbaa99887766554433221100", + ProtoTag::Secure, + 1, + )); + + let mut seed = 0xF0F0_F00D_BAAD_CAFEu64; + for _ in 0..32 { + let mut mutated = base; + for _ in 0..4 { + seed = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); + mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, input) in corpus.into_iter().enumerate() { + let result = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &input, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("fuzzed handshake must complete in time"); + + assert!( + matches!(result, HandshakeResult::BadClient { .. 
}), + "corpus item {idx} must fail closed" + ); + } + + clear_auth_probe_state_for_testing(); +} + +#[tokio::test] +async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_rejected() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "99887766554433221100ffeeddccbbaa"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 4); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(256, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.44:45444".parse().unwrap(); + + let first = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("base handshake must not hang"); + assert!(matches!(first, HandshakeResult::Success(_))); + + let replay = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("duplicate handshake must not hang"); + assert!(matches!(replay, HandshakeResult::BadClient { .. 
})); + + let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new(); + + let mut prekey_flip = base; + prekey_flip[SKIP_LEN] ^= 0x80; + corpus.push(prekey_flip); + + let mut iv_flip = base; + iv_flip[SKIP_LEN + PREKEY_LEN] ^= 0x01; + corpus.push(iv_flip); + + let mut tail_flip = base; + tail_flip[SKIP_LEN + PREKEY_LEN + IV_LEN - 1] ^= 0x40; + corpus.push(tail_flip); + + let mut seed = 0xBADC_0FFE_EE11_4242u64; + for _ in 0..24 { + let mut mutated = base; + for _ in 0..3 { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); + mutated[idx] ^= ((seed >> 16) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, input) in corpus.iter().enumerate() { + let result = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + input, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("fuzzed handshake must complete in time"); + + assert!( + matches!(result, HandshakeResult::BadClient { .. 
}), + "mixed corpus item {idx} must fail closed" + ); + } + + clear_auth_probe_state_for_testing(); +} diff --git a/src/proxy/tests/handshake_saturation_poison_security_tests.rs b/src/proxy/tests/handshake_saturation_poison_security_tests.rs new file mode 100644 index 0000000..4c2ca5d --- /dev/null +++ b/src/proxy/tests/handshake_saturation_poison_security_tests.rs @@ -0,0 +1,71 @@ +use super::*; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn poison_saturation_mutex() { + let saturation = auth_probe_saturation_state(); + let poison_thread = std::thread::spawn(move || { + let _guard = saturation + .lock() + .expect("saturation mutex must be lockable for poison setup"); + panic!("intentional poison for saturation mutex resilience test"); + }); + let _ = poison_thread.join(); +} + +#[test] +fn auth_probe_saturation_note_recovers_after_mutex_poison() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + poison_saturation_mutex(); + + let now = Instant::now(); + auth_probe_note_saturation(now); + + assert!( + auth_probe_saturation_is_throttled_at_for_testing(now), + "poisoned saturation mutex must not disable saturation throttling" + ); +} + +#[test] +fn auth_probe_saturation_check_recovers_after_mutex_poison() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + poison_saturation_mutex(); + + { + let mut guard = auth_probe_saturation_state_lock(); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: Instant::now() + Duration::from_millis(10), + last_seen: Instant::now(), + }); + } + + assert!( + auth_probe_saturation_is_throttled_for_testing(), + "throttle check must recover poisoned saturation mutex and stay fail-closed" + ); +} + +#[test] +fn clear_auth_probe_state_clears_saturation_even_if_poisoned() 
{ + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + poison_saturation_mutex(); + + auth_probe_note_saturation(Instant::now()); + assert!(auth_probe_saturation_is_throttled_for_testing()); + + clear_auth_probe_state_for_testing(); + assert!( + !auth_probe_saturation_is_throttled_for_testing(), + "clear helper must clear saturation state even after poison" + ); +} diff --git a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs new file mode 100644 index 0000000..d06f63e --- /dev/null +++ b/src/proxy/tests/handshake_security_tests.rs @@ -0,0 +1,3490 @@ +use super::*; +use crate::crypto::{sha256, sha256_hmac}; +use dashmap::DashMap; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Barrier; + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec<u8> { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = 
Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec<u8> { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as 
u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg +} + +fn 
make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode for mtproto test helper"); + + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +#[test] +fn test_generate_tg_nonce() { + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; + + let rng = SecureRandom::new(); + let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + assert_eq!(nonce.len(), HANDSHAKE_LEN); + + let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); + assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); +} + +#[test] +fn test_encrypt_tg_nonce() { + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; + + let rng = 
SecureRandom::new(); + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let encrypted = encrypt_tg_nonce(&nonce); + + assert_eq!(encrypted.len(), HANDSHAKE_LEN); + assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); + assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); +} + +#[test] +fn test_handshake_success_drop_does_not_panic() { + let success = HandshakeSuccess { + user: "test".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Secure, + dec_key: [0xAA; 32], + dec_iv: 0xBBBBBBBB, + enc_key: [0xCC; 32], + enc_iv: 0xDDDDDDDD, + peer: "198.51.100.10:1234".parse().unwrap(), + is_tls: true, + }; + + assert_eq!(success.dec_key, [0xAA; 32]); + assert_eq!(success.enc_key, [0xCC; 32]); + + drop(success); +} + +#[test] +fn test_generate_tg_nonce_enc_dec_material_is_consistent() { + let client_enc_key = [0x34u8; 32]; + let client_enc_iv = 0xffeeddccbbaa00998877665544332211u128; + let rng = SecureRandom::new(); + + let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( + ProtoTag::Secure, + 7, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect(); + + let mut expected_tg_enc_key = [0u8; 32]; + expected_tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_tg_enc_iv_arr = [0u8; IV_LEN]; + expected_tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_tg_enc_iv = u128::from_be_bytes(expected_tg_enc_iv_arr); + + let mut expected_tg_dec_key = [0u8; 32]; + expected_tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut expected_tg_dec_iv_arr = [0u8; IV_LEN]; + expected_tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let expected_tg_dec_iv = u128::from_be_bytes(expected_tg_dec_iv_arr); + + assert_eq!(tg_enc_key, expected_tg_enc_key); + assert_eq!(tg_enc_iv, 
expected_tg_enc_iv); + assert_eq!(tg_dec_key, expected_tg_dec_key); + assert_eq!(tg_dec_iv, expected_tg_dec_iv); + assert_eq!( + i16::from_le_bytes([nonce[DC_IDX_POS], nonce[DC_IDX_POS + 1]]), + 7, + "Generated nonce must keep target dc index in protocol slot" + ); +} + +#[test] +fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() { + let client_enc_key = [0xABu8; 32]; + let client_enc_iv = 0x11223344556677889900aabbccddeeffu128; + let rng = SecureRandom::new(); + + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 9, + &client_enc_key, + client_enc_iv, + &rng, + true, + ); + + let mut expected = Vec::with_capacity(KEY_LEN + IV_LEN); + expected.extend_from_slice(&client_enc_key); + expected.extend_from_slice(&client_enc_iv.to_be_bytes()); + expected.reverse(); + + assert_eq!( + &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN], + expected.as_slice() + ); +} + +#[test] +fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() { + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; + + let rng = SecureRandom::new(); + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(&nonce); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let mut expected_enc_key = [0u8; 32]; + expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_enc_iv_arr = [0u8; IV_LEN]; + expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_enc_iv = u128::from_be_bytes(expected_enc_iv_arr); + + let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv); + let manual = manual_encryptor.encrypt(&nonce); + + assert_eq!(encrypted.len(), HANDSHAKE_LEN); + assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); + assert_eq!( + &encrypted[PROTO_TAG_POS..], + &manual[PROTO_TAG_POS..], + "Encrypted nonce suffix must match 
AES-CTR output with derived enc key/iv" + ); +} + +#[tokio::test] +async fn tls_replay_second_identical_handshake_is_rejected() { + let secret = [0x11u8; 16]; + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.21:44321".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 0); + + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let second = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(second, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn tls_replay_with_ignore_time_skew_and_small_boot_timestamp_is_still_blocked() { + let secret = [0x19u8; 16]; + let config = test_config_with_secret_hex("19191919191919191919191919191919"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.121:44321".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 1); + + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let replay = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(replay, HandshakeResult::BadClient { .. 
}), + "ignore_time_skew must not weaken replay rejection for small boot timestamps" + ); +} + +#[tokio::test] +async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() { + let secret = [0x77u8; 16]; + let config = Arc::new(test_config_with_secret_hex( + "77777777777777777777777777777777", + )); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let handshake = Arc::new(make_valid_tls_handshake(&secret, 0)); + + let mut tasks = Vec::new(); + for _ in 0..50 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let handshake = handshake.clone(); + tasks.push(tokio::spawn(async move { + handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.22:45000".parse().unwrap(), + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let mut success_count = 0usize; + for task in tasks { + let result = task.await.unwrap(); + if matches!(result, HandshakeResult::Success(_)) { + success_count += 1; + } else { + assert!(matches!(result, HandshakeResult::BadClient { .. 
        }));
        }
    }

    assert_eq!(
        success_count, 1,
        "Concurrent replay attempts must allow exactly one successful handshake"
    );
}

// A valid TLS handshake must succeed exactly once; replaying the identical
// digest from 128 rotated source peers/ports must always be rejected,
// proving the replay cache keys on the digest, not on the peer address.
#[tokio::test]
async fn tls_replay_matrix_rotating_peers_first_accept_then_rejects() {
    let secret = [0x52u8; 16];
    let config = test_config_with_secret_hex("52525252525252525252525252525252");
    let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let handshake = make_valid_tls_handshake(&secret, 17);

    let first_peer: SocketAddr = "198.51.100.31:44001".parse().unwrap();
    let first = handle_tls_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        first_peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;
    assert!(matches!(first, HandshakeResult::Success(_)));

    for i in 0..128u16 {
        // Rotate both the /24 host byte and the source port per attempt.
        let peer = SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(198, 51, 100, ((i % 250) + 1) as u8)),
            45000 + i,
        );
        let replay = handle_tls_handshake(
            &handshake,
            tokio::io::empty(),
            tokio::io::sink(),
            peer,
            &config,
            &replay_checker,
            &rng,
            None,
        )
        .await;
        assert!(
            matches!(replay, HandshakeResult::BadClient { .. }),
            "replay digest must be rejected regardless of source peer rotation"
        );
    }
}

// Concurrent churn over a shared replay checker: 128 byte-identical
// handshakes must yield exactly one success, while 128 handshakes with
// unique timestamp-derived digests must all pass (no false replay hits).
#[tokio::test]
async fn adversarial_tls_replay_churn_allows_only_unique_digests() {
    let secret = [0x5Au8; 16];
    let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a");
    config.access.ignore_time_skew = true;
    let config = Arc::new(config);
    let replay_checker = Arc::new(ReplayChecker::new(8192, Duration::from_secs(60)));
    let rng = Arc::new(SecureRandom::new());

    // Builds a structurally valid ClientHello whose digest embeds `timestamp`
    // (XORed into bytes 28..32, little-endian) so each timestamp yields a
    // distinct HMAC digest. `tag` varies the filler bytes.
    let make_tagged_handshake = |timestamp: u32, tag: u8| {
        let session_id_len: usize = 32;
        let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
        let mut handshake = vec![tag; len];

        handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
        // Digest field is zeroed before HMAC computation, then written back.
        handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
        let computed = sha256_hmac(&secret, &handshake);
        let mut digest = computed;
        let ts = timestamp.to_le_bytes();
        for i in 0..4 {
            digest[28 + i] ^= ts[i];
        }

        handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
            .copy_from_slice(&digest);
        handshake
    };

    let mut tasks = Vec::new();

    // 128 exact duplicates: only one should pass.
    let duplicated = Arc::new(make_valid_tls_handshake(&secret, 999));
    for i in 0..128u16 {
        let config = Arc::clone(&config);
        let replay_checker = Arc::clone(&replay_checker);
        let rng = Arc::clone(&rng);
        let duplicated = Arc::clone(&duplicated);
        tasks.push(tokio::spawn(async move {
            let peer = SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(203, 0, 113, ((i % 250) + 1) as u8)),
                46000 + i,
            );
            handle_tls_handshake(
                &duplicated,
                tokio::io::empty(),
                tokio::io::sink(),
                peer,
                &config,
                &replay_checker,
                &rng,
                None,
            )
            .await
        }));
    }

    // 128 unique timestamps: all should pass because HMAC digest differs.
    for i in 0..128u16 {
        let config = Arc::clone(&config);
        let replay_checker = Arc::clone(&replay_checker);
        let rng = Arc::clone(&rng);
        let handshake = make_tagged_handshake(10_000 + i as u32, (i as u8).wrapping_add(0x80));
        tasks.push(tokio::spawn(async move {
            let peer = SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(198, 18, 0, ((i % 250) + 1) as u8)),
                47000 + i,
            );
            handle_tls_handshake(
                &handshake,
                tokio::io::empty(),
                tokio::io::sink(),
                peer,
                &config,
                &replay_checker,
                &rng,
                None,
            )
            .await
        }));
    }

    let mut duplicate_success = 0usize;
    let mut duplicate_reject = 0usize;
    let mut unique_success = 0usize;
    let mut unique_reject = 0usize;

    // Tasks were pushed duplicates-first, so index < 128 identifies the
    // duplicate group.
    for (idx, task) in tasks.into_iter().enumerate() {
        let result = task.await.unwrap();
        let is_duplicate_group = idx < 128;
        match result {
            HandshakeResult::Success(_) => {
                if is_duplicate_group {
                    duplicate_success += 1;
                } else {
                    unique_success += 1;
                }
            }
            HandshakeResult::BadClient { .. } => {
                if is_duplicate_group {
                    duplicate_reject += 1;
                } else {
                    unique_reject += 1;
                }
            }
            HandshakeResult::Error(e) => panic!("unexpected handshake error in churn test: {e}"),
        }
    }

    assert_eq!(
        duplicate_success, 1,
        "duplicate replay group must allow exactly one successful handshake"
    );
    assert_eq!(
        duplicate_reject, 127,
        "duplicate replay group must reject all remaining replays"
    );
    assert_eq!(
        unique_success, 128,
        "unique digest group must fully pass under replay churn"
    );
    assert_eq!(
        unique_reject, 0,
        "unique digest group must not be falsely rejected as replay"
    );
}

// An invalid (wrong-HMAC) TLS probe must not add entries to the replay
// cache nor register hits, so scanners cannot pollute or probe the cache.
#[tokio::test]
async fn invalid_tls_probe_does_not_pollute_replay_cache() {
    let config = test_config_with_secret_hex("11111111111111111111111111111111");
    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.23:44322".parse().unwrap();

    let mut invalid = vec![0x42u8;
        tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
    // Declare a 32-byte session id so the record parses structurally.
    invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;

    let before = replay_checker.stats();

    let result = handle_tls_handshake(
        &invalid,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;
    let after = replay_checker.stats();

    assert!(matches!(result, HandshakeResult::BadClient { .. }));
    // Cache counters must be untouched by the rejected probe.
    assert_eq!(before.total_additions, after.total_additions);
    assert_eq!(before.total_hits, after.total_hits);
}

// A user configured with an empty secret must never authenticate, even if
// the client signs with the matching (empty) key material.
#[tokio::test]
async fn empty_decoded_secret_is_rejected() {
    let _guard = warned_secrets_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_warned_secrets_for_testing();
    let config = test_config_with_secret_hex("");
    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.24:44323".parse().unwrap();
    let handshake = make_valid_tls_handshake(&[], 0);

    let result = handle_tls_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    assert!(matches!(result, HandshakeResult::BadClient { .. }));
}

// A secret of the wrong decoded length (1 byte instead of 16) must be
// rejected even when the handshake is signed with that same short secret.
#[tokio::test]
async fn wrong_length_decoded_secret_is_rejected() {
    let _guard = warned_secrets_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_warned_secrets_for_testing();
    let config = test_config_with_secret_hex("aa");
    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.25:44324".parse().unwrap();
    let handshake = make_valid_tls_handshake(&[0xaau8], 0);

    let result = handle_tls_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    assert!(matches!(result, HandshakeResult::BadClient { .. }));
}

// An all-zero MTProto handshake must be rejected without touching the
// replay-cache counters (mirrors the TLS cache-pollution test).
#[tokio::test]
async fn invalid_mtproto_probe_does_not_pollute_replay_cache() {
    let config = test_config_with_secret_hex("11111111111111111111111111111111");
    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let peer: SocketAddr = "198.51.100.26:44325".parse().unwrap();
    let handshake = [0u8; HANDSHAKE_LEN];

    let before = replay_checker.stats();
    let result = handle_mtproto_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        false,
        None,
    )
    .await;
    let after = replay_checker.stats();

    assert!(matches!(result, HandshakeResult::BadClient { .. }));
    assert_eq!(before.total_additions, after.total_additions);
    assert_eq!(before.total_hits, after.total_hits);
}

// With one malformed user secret ("aa") and one valid user in the config,
// the valid user must still authenticate — a bad neighbor entry must not
// poison the whole user table.
#[tokio::test]
async fn mixed_secret_lengths_keep_valid_user_authenticating() {
    let _probe_guard = auth_probe_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let _guard = warned_secrets_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_warned_secrets_for_testing();
    clear_auth_probe_state_for_testing();
    let good_secret = [0x22u8; 16];
    let mut config = ProxyConfig::default();
    config.access.users.clear();
    config
        .access
        .users
        .insert("broken_user".to_string(), "aa".to_string());
    config.access.users.insert(
        "valid_user".to_string(),
        "22222222222222222222222222222222".to_string(),
    );
    config.access.ignore_time_skew = true;

    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.27:44326".parse().unwrap();
    let handshake = make_valid_tls_handshake(&good_secret, 0);

    let result = handle_tls_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    assert!(matches!(result, HandshakeResult::Success(_)));
}

#[tokio::test]
async fn
tls_sni_preferred_user_hint_selects_matching_identity_first() {
    // Two users share the same secret; the SNI hint ("user-b") must decide
    // which identity wins, so attribution is deterministic among decoys.
    let shared_secret = [0x3Bu8; 16];
    let mut config = ProxyConfig::default();
    config.access.users.clear();
    config.access.users.insert(
        "user-a".to_string(),
        "3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b".to_string(),
    );
    config.access.users.insert(
        "user-b".to_string(),
        "3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b".to_string(),
    );
    config.access.ignore_time_skew = true;

    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.188:44326".parse().unwrap();
    let handshake =
        make_valid_tls_client_hello_with_sni_and_alpn(&shared_secret, 0, "user-b", &[b"h2"]);

    let result = handle_tls_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    match result {
        HandshakeResult::Success((_, _, user)) => {
            assert_eq!(
                user, "user-b",
                "TLS SNI preferred-user hint must select matching identity before equivalent decoys"
            );
        }
        _ => panic!("TLS handshake must succeed for valid shared-secret SNI case"),
    }
}

// decode_user_secrets must keep the preferred user unique and first even
// when 4096 decoy users share its secret.
#[test]
fn stress_decode_user_secrets_keeps_preferred_user_first_in_large_set() {
    let mut config = ProxyConfig::default();
    config.access.users.clear();

    let preferred_user = "target-user.example".to_string();
    let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string();

    for i in 0..4096usize {
        config
            .access
            .users
            .insert(format!("decoy-{i:04}.example"), secret_hex.clone());
    }
    config
        .access
        .users
        .insert(preferred_user.clone(), secret_hex.clone());

    let decoded = decode_user_secrets(&config, Some(preferred_user.as_str()));
    assert_eq!(
        decoded.len(),
        config.access.users.len(),
        "decoded secret set must preserve full user cardinality under stress"
    );
    assert_eq!(
        decoded.first().map(|(name, _)| name.as_str()),
        Some(preferred_user.as_str()),
        "preferred user must be first even under adversarial large user sets"
    );
    assert_eq!(
        decoded
            .iter()
            .filter(|(name, _)| name == &preferred_user)
            .count(),
        1,
        "preferred user must appear exactly once in decoded list"
    );
}

// End-to-end variant of the stress above: the SNI hint must still select
// the preferred identity through a full handshake with 4096 decoys present.
#[tokio::test]
async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() {
    let shared_secret = [0x7Fu8; 16];
    let mut config = ProxyConfig::default();
    config.access.users.clear();
    config.access.ignore_time_skew = true;

    let preferred_user = "target-user.example".to_string();
    let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string();

    for i in 0..4096usize {
        config
            .access
            .users
            .insert(format!("decoy-{i:04}.example"), secret_hex.clone());
    }
    config
        .access
        .users
        .insert(preferred_user.clone(), secret_hex);

    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.189:44326".parse().unwrap();
    let handshake = make_valid_tls_client_hello_with_sni_and_alpn(
        &shared_secret,
        0,
        preferred_user.as_str(),
        &[b"h2"],
    );

    let result = handle_tls_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    match result {
        HandshakeResult::Success((_, _, user)) => {
            assert_eq!(
                user, preferred_user,
                "SNI preferred-user hint must remain stable under large user cardinality"
            );
        }
        _ => panic!("TLS handshake must succeed for valid preferred-user stress case"),
    }
}

// With alpn_enforce on, a ClientHello offering only h3 must be rejected
// even though its HMAC is valid.
#[tokio::test]
async fn alpn_enforce_rejects_unsupported_client_alpn() {
    let secret = [0x33u8; 16];
    let mut config = test_config_with_secret_hex("33333333333333333333333333333333");
    config.censorship.alpn_enforce = true;

    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.28:44327".parse().unwrap();
    let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);

    let result = handle_tls_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    assert!(matches!(result, HandshakeResult::BadClient { .. }));
}

// Offering h2 (even alongside h3) satisfies ALPN enforcement.
#[tokio::test]
async fn alpn_enforce_accepts_h2() {
    let secret = [0x44u8; 16];
    let mut config = test_config_with_secret_hex("44444444444444444444444444444444");
    config.censorship.alpn_enforce = true;

    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.29:44328".parse().unwrap();
    let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2", b"h3"]);

    let result = handle_tls_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    assert!(matches!(result, HandshakeResult::Success(_)));
}

// Three malformed-probe classes (truncated record, bad HMAC, ALPN mismatch)
// must each be rejected within a 200ms budget — no hangs on garbage input.
#[tokio::test]
async fn malformed_tls_classes_complete_within_bounded_time() {
    let secret = [0x55u8; 16];
    let mut config = test_config_with_secret_hex("55555555555555555555555555555555");
    config.censorship.alpn_enforce = true;

    let replay_checker = ReplayChecker::new(512, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.30:44329".parse().unwrap();

    let too_short = vec![0x16, 0x03, 0x01];

    let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
    bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01;

    let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);

    for probe in [too_short, bad_hmac, alpn_mismatch] {
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle_tls_handshake(
                &probe,
                tokio::io::empty(),
                tokio::io::sink(),
                peer,
                &config,
                &replay_checker,
                &rng,
                None,
            ),
        )
        .await
        .expect("Malformed TLS classes must be rejected within bounded time");

        assert!(matches!(result, HandshakeResult::BadClient { ..
        }));
    }
}

// A configured fixed 20ms server-hello delay must also apply to rejected
// (bad-HMAC) handshakes so rejects aren't distinguishable by faster timing.
// The 18ms lower bound leaves slack for timer granularity.
#[tokio::test]
async fn tls_invalid_hmac_respects_configured_anti_fingerprint_delay() {
    let secret = [0x5Au8; 16];
    let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a");
    config.censorship.server_hello_delay_min_ms = 20;
    config.censorship.server_hello_delay_max_ms = 20;

    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.32:44331".parse().unwrap();
    let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
    bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01;

    let started = Instant::now();
    let result = handle_tls_handshake(
        &bad_hmac,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    assert!(matches!(result, HandshakeResult::BadClient { .. }));
    assert!(
        started.elapsed() >= Duration::from_millis(18),
        "configured anti-fingerprint delay must apply to invalid TLS handshakes"
    );
}

// Same delay requirement for ALPN-mismatch rejects: the reject path must
// not be timing-distinguishable from the accept path.
#[tokio::test]
async fn tls_alpn_mismatch_respects_configured_anti_fingerprint_delay() {
    let secret = [0x6Bu8; 16];
    let mut config = test_config_with_secret_hex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b");
    config.censorship.alpn_enforce = true;
    config.censorship.server_hello_delay_min_ms = 20;
    config.censorship.server_hello_delay_max_ms = 20;

    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.33:44332".parse().unwrap();
    let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);

    let started = Instant::now();
    let result = handle_tls_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    assert!(matches!(result, HandshakeResult::BadClient { .. }));
    assert!(
        started.elapsed() >= Duration::from_millis(18),
        "configured anti-fingerprint delay must apply to ALPN-mismatch rejects"
    );
}

// Mean reject latency of the three malformed classes must land in the same
// (or adjacent) 10ms bucket, so a DPI cannot classify the failure reason by
// timing. Ignored by default: needs a low-jitter host.
#[tokio::test]
#[ignore = "timing-sensitive; run manually on low-jitter hosts"]
async fn malformed_tls_classes_share_close_latency_buckets() {
    const ITER: usize = 24;
    const BUCKET_MS: u128 = 10;

    let secret = [0x99u8; 16];
    let mut config = test_config_with_secret_hex("99999999999999999999999999999999");
    config.censorship.alpn_enforce = true;

    let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.31:44330".parse().unwrap();

    let too_short = vec![0x16, 0x03, 0x01];

    let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
    bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01;

    let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);

    let mut class_means_ms = Vec::new();
    for probe in [too_short, bad_hmac, alpn_mismatch] {
        let mut sum_micros: u128 = 0;
        for _ in 0..ITER {
            let started = Instant::now();
            let result = handle_tls_handshake(
                &probe,
                tokio::io::empty(),
                tokio::io::sink(),
                peer,
                &config,
                &replay_checker,
                &rng,
                None,
            )
            .await;
            let elapsed = started.elapsed();
            assert!(matches!(result, HandshakeResult::BadClient { .. }));
            sum_micros += elapsed.as_micros();
        }

        // Mean per class, converted micros -> millis.
        class_means_ms.push(sum_micros / ITER as u128 / 1_000);
    }

    let min_bucket = class_means_ms
        .iter()
        .map(|ms| ms / BUCKET_MS)
        .min()
        .unwrap();
    let max_bucket = class_means_ms
        .iter()
        .map(|ms| ms / BUCKET_MS)
        .max()
        .unwrap();

    assert!(
        max_bucket <= min_bucket + 1,
        "Malformed TLS classes diverged across latency buckets: means_ms={:?}",
        class_means_ms
    );
}

// Manual diagnostic: prints mean/min/p95/max latency per probe class under a
// fixed 20ms delay budget; run with --ignored --nocapture to inspect.
#[tokio::test]
#[ignore = "timing matrix; run manually with --ignored --nocapture"]
async fn timing_matrix_tls_classes_under_fixed_delay_budget() {
    const ITER: usize = 48;
    const BUCKET_MS: u128 = 10;

    let secret = [0x77u8; 16];
    let mut config = test_config_with_secret_hex("77777777777777777777777777777777");
    config.censorship.alpn_enforce = true;
    config.censorship.server_hello_delay_min_ms = 20;
    config.censorship.server_hello_delay_max_ms = 20;

    let rng = SecureRandom::new();
    let base_ip = std::net::Ipv4Addr::new(198, 51, 100, 34);

    let too_short = vec![0x16, 0x03, 0x01];
    let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
    bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01;
    let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
    let valid_h2 = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2"]);

    let classes = vec![
        ("too_short", too_short),
        ("bad_hmac", bad_hmac),
        ("alpn_mismatch", alpn_mismatch),
        ("valid_h2", valid_h2),
    ];

    for (class, probe) in classes {
        let mut samples_ms = Vec::with_capacity(ITER);
        for idx in 0..ITER {
            // Fresh throttle state and replay cache per sample so earlier
            // iterations don't bias later timings.
            clear_auth_probe_state_for_testing();
            let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60));
            let peer: SocketAddr = SocketAddr::from((base_ip, 44_000 + idx as u16));
            let started = Instant::now();
            let result = handle_tls_handshake(
                &probe,
                tokio::io::empty(),
                tokio::io::sink(),
                peer,
                &config,
                &replay_checker,
                &rng,
                None,
            )
            .await;
            let elapsed = started.elapsed();
            samples_ms.push(elapsed.as_millis());

            if class == "valid_h2" {
                assert!(matches!(result, HandshakeResult::Success(_)));
            } else {
                assert!(matches!(result, HandshakeResult::BadClient { .. }));
            }
        }

        samples_ms.sort_unstable();
        let sum: u128 = samples_ms.iter().copied().sum();
        let mean = sum as f64 / samples_ms.len() as f64;
        let min = samples_ms[0];
        let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize;
        let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)];
        let max = samples_ms[samples_ms.len() - 1];

        println!(
            "TIMING_MATRIX tls class={} mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
            class,
            mean,
            min,
            p95,
            max,
            (mean as u128) / BUCKET_MS
        );
    }
}

// The Secure proto tag arriving over TLS transport is gated by the `tls`
// mode flag, not the `secure` flag.
#[test]
fn secure_tag_requires_tls_mode_on_tls_transport() {
    let mut config = ProxyConfig::default();
    config.general.modes.classic = false;
    config.general.modes.secure = true;
    config.general.modes.tls = false;

    assert!(
        !mode_enabled_for_proto(&config, ProtoTag::Secure, true),
        "Secure tag over TLS must be rejected when tls mode is disabled"
    );

    config.general.modes.tls = true;
    assert!(
        mode_enabled_for_proto(&config, ProtoTag::Secure, true),
        "Secure tag over TLS must be accepted when tls mode is enabled"
    );
}

// The Secure proto tag on a direct (non-TLS) transport is gated by the
// `secure` mode flag.
#[test]
fn secure_tag_requires_secure_mode_on_direct_transport() {
    let mut config = ProxyConfig::default();
    config.general.modes.classic = false;
    config.general.modes.secure = false;
    config.general.modes.tls = true;

    assert!(
        !mode_enabled_for_proto(&config, ProtoTag::Secure, false),
        "Secure tag without TLS must be rejected when secure mode is disabled"
    );

    config.general.modes.secure = true;
    assert!(
        mode_enabled_for_proto(&config, ProtoTag::Secure, false),
        "Secure tag without TLS must be accepted when secure mode is enabled"
    );
}

// Exhaustive 3-flag x 2-transport x 3-tag matrix pinning the mode policy:
// Secure maps to tls/secure depending on transport; Intermediate and
// Abridged always map to the classic flag.
#[test]
fn mode_policy_matrix_is_stable_for_all_tag_transport_mode_combinations() {
    let tags = [ProtoTag::Secure, ProtoTag::Intermediate, ProtoTag::Abridged];

    for classic in [false, true] {
        for secure in [false, true] {
            for tls in [false, true] {
                let mut config = ProxyConfig::default();
                config.general.modes.classic = classic;
                config.general.modes.secure = secure;
                config.general.modes.tls = tls;

                for is_tls in [false, true] {
                    for tag in tags {
                        let expected = match (tag, is_tls) {
                            (ProtoTag::Secure, true) => tls,
                            (ProtoTag::Secure, false) => secure,
                            (ProtoTag::Intermediate | ProtoTag::Abridged, _) => classic,
                        };

                        assert_eq!(
                            mode_enabled_for_proto(&config, tag, is_tls),
                            expected,
                            "mode policy drifted for tag={:?}, transport_tls={}, modes=(classic={}, secure={}, tls={})",
                            tag,
                            is_tls,
                            classic,
                            secure,
                            tls
                        );
                    }
                }
            }
        }
    }
}

// ("a:b", "c") and ("a", "b:c") would collide if the dedupe key were a
// naive colon-join; both must be recorded as distinct warnings.
#[test]
fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() {
    let _guard = warned_secrets_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_warned_secrets_for_testing();

    warn_invalid_secret_once("a:b", "c", ACCESS_SECRET_BYTES, Some(1));
    warn_invalid_secret_once("a", "b:c", ACCESS_SECRET_BYTES, Some(2));

    let warned = INVALID_SECRET_WARNED
        .get()
        .expect("warned set must be initialized");
    let guard = warned.lock().expect("warned set lock must be available");
    assert_eq!(
        guard.len(),
        2,
        "(name, reason) pairs that stringify to the same colon-joined key must remain distinct"
    );
}

// Flooding the warning cache with MAX+32 distinct users must not grow it
// past WARNED_SECRET_MAX_ENTRIES.
#[test]
fn invalid_secret_warning_cache_is_bounded() {
    let _guard = warned_secrets_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_warned_secrets_for_testing();

    for idx in 0..(WARNED_SECRET_MAX_ENTRIES + 32) {
        let user = format!("warned_user_{idx}");
        warn_invalid_secret_once(&user, "invalid_length", ACCESS_SECRET_BYTES, Some(idx));
    }

    let warned = INVALID_SECRET_WARNED
        .get()
        .expect("warned set must be initialized");
    let guard = warned.lock().expect("warned set lock must be available");
    assert_eq!(
        guard.len(),
        WARNED_SECRET_MAX_ENTRIES,
        "invalid-secret warning cache must remain bounded"
    );
}

// A burst of invalid TLS probes from one IP must grow that IP's pre-auth
// failure streak up to the backoff threshold.
#[tokio::test]
async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() {
    let _guard = auth_probe_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_auth_probe_state_for_testing();

    let config = test_config_with_secret_hex("11111111111111111111111111111111");
    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.61:44361".parse().unwrap();

    let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
    invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;

    for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
        let result = handle_tls_handshake(
            &invalid,
            tokio::io::empty(),
            tokio::io::sink(),
            peer,
            &config,
            &replay_checker,
            &rng,
            None,
        )
        .await;
        assert!(matches!(result, HandshakeResult::BadClient { .. }));
    }

    assert!(
        auth_probe_fail_streak_for_testing(peer.ip())
            .is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
        "invalid probe burst must grow pre-auth failure streak to backoff threshold"
    );
}

// A valid handshake after a run of failures must reset that IP's streak to
// None (not merely decrement it).
#[tokio::test]
async fn successful_tls_handshake_clears_pre_auth_failure_streak() {
    let _guard = auth_probe_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_auth_probe_state_for_testing();

    let secret = [0x23u8; 16];
    let config = test_config_with_secret_hex("23232323232323232323232323232323");
    let replay_checker = ReplayChecker::new(256, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.62:44362".parse().unwrap();

    let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
    invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;

    // Stay one below the backoff threshold so every attempt still executes.
    for expected in 1..AUTH_PROBE_BACKOFF_START_FAILS {
        let result = handle_tls_handshake(
            &invalid,
            tokio::io::empty(),
            tokio::io::sink(),
            peer,
            &config,
            &replay_checker,
            &rng,
            None,
        )
        .await;
        assert!(matches!(result, HandshakeResult::BadClient { .. }));
        assert_eq!(
            auth_probe_fail_streak_for_testing(peer.ip()),
            Some(expected),
            "failure streak must grow before a successful authentication"
        );
    }

    let valid = make_valid_tls_handshake(&secret, 0);
    let success = handle_tls_handshake(
        &valid,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &config,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    assert!(matches!(success, HandshakeResult::Success(_)));
    assert_eq!(
        auth_probe_fail_streak_for_testing(peer.ip()),
        None,
        "successful authentication must clear accumulated pre-auth failures"
    );
}

// When the probe map is full of entries older than the retention window,
// recording a new IP must prune stale entries and admit the newcomer while
// staying within the capacity bound.
#[test]
fn auth_probe_capacity_prunes_stale_entries_for_new_ips() {
    let state = DashMap::new();
    let now = Instant::now();
    let stale_seen = now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1);

    for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES {
        let ip = IpAddr::V4(Ipv4Addr::new(
            10,
            1,
            ((idx >> 8) & 0xff) as u8,
            (idx & 0xff) as u8,
        ));
        state.insert(
            ip,
            AuthProbeState {
                fail_streak: 1,
                blocked_until: now,
                last_seen: stale_seen,
            },
        );
    }

    let newcomer = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 200));
    auth_probe_record_failure_with_state(&state, newcomer, now);

    assert_eq!(
        state.get(&newcomer).map(|entry| entry.fail_streak),
        Some(1),
        "stale-entry pruning must admit and track a new probe source"
    );
    assert!(
        state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
        "auth probe map must remain bounded after stale pruning"
    );
}

// With a map full of FRESH entries (nothing prunable by age), a newcomer
// must still be admitted by evicting the oldest entry, the cap must hold,
// and capacity pressure must flip on the coarse global throttle.
#[test]
fn auth_probe_capacity_fresh_full_map_still_tracks_newcomer_with_bounded_eviction() {
    let _guard = auth_probe_test_lock()
        .lock()
        .expect("auth probe test lock must be available");
    clear_auth_probe_state_for_testing();

    let state = DashMap::new();
    let now = Instant::now();

    for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES {
        let ip = IpAddr::V4(Ipv4Addr::new(
            172,
            16,
            ((idx >> 8) & 0xff) as u8,
            (idx & 0xff) as u8,
        ));
        state.insert(
            ip,
            AuthProbeState {
                fail_streak: 1,
                blocked_until: now,
                last_seen: now + Duration::from_millis(idx as u64 + 1),
            },
        );
    }

    // Re-insert 172.16.0.0 with the oldest last_seen so it is the clear
    // eviction candidate.
    let oldest = IpAddr::V4(Ipv4Addr::new(172, 16, 0, 0));
    state.insert(
        oldest,
        AuthProbeState {
            fail_streak: 1,
            blocked_until: now,
            last_seen: now - Duration::from_secs(5),
        },
    );

    let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 55));
    auth_probe_record_failure_with_state(&state, newcomer, now);

    assert!(
        state.get(&newcomer).is_some(),
        "fresh-at-cap auth probe map must still track a new source after bounded eviction"
    );
    assert!(
        state.get(&oldest).is_none(),
        "capacity eviction must remove the oldest tracked source first"
    );
    assert_eq!(
        state.len(),
        AUTH_PROBE_TRACK_MAX_ENTRIES,
        "auth probe map must stay at configured cap after bounded eviction"
    );
    assert!(
        auth_probe_saturation_is_throttled_at_for_testing(now),
        "capacity pressure should still activate coarse global pre-auth throttling"
    );
}

// 1024 rounds of newcomer churn against an at-capacity map: every newcomer
// must be tracked and the map size must never exceed the hard cap.
#[test]
fn stress_auth_probe_full_map_churn_keeps_bound_and_tracks_newcomers() {
    let _guard = auth_probe_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_auth_probe_state_for_testing();

    let state = DashMap::new();
    let base_now = Instant::now();

    for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES {
        let ip = IpAddr::V4(Ipv4Addr::new(
            10,
            2,
            ((idx >> 8) & 0xff) as u8,
            (idx & 0xff) as u8,
        ));
        state.insert(
            ip,
            AuthProbeState {
                fail_streak: 1,
                blocked_until: base_now,
                last_seen: base_now + Duration::from_millis((idx % 2048) as u64),
            },
        );
    }

    for step in 0..1024usize {
        let newcomer = IpAddr::V4(Ipv4Addr::new(
            203,
            0,
            ((step >> 8) & 0xff) as u8,
            (step & 0xff) as u8,
        ));
        let now = base_now + Duration::from_millis(10_000 + step as u64);
        auth_probe_record_failure_with_state(&state, newcomer, now);

        assert!(
            state.get(&newcomer).is_some(),
            "new source must still be tracked under sustained at-capacity churn"
        );
        assert_eq!(
            state.len(),
            AUTH_PROBE_TRACK_MAX_ENTRIES,
            "auth probe map size must stay hard-bounded at capacity"
        );
    }
}

// Starting ABOVE the hard cap (cap + 32 entries), the round-limited
// eviction must still reclaim capacity and track the newcomer.
#[test]
fn auth_probe_over_cap_churn_still_tracks_newcomer_after_round_limit() {
    let _guard = auth_probe_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_auth_probe_state_for_testing();

    let state = DashMap::new();
    let now = Instant::now();
    let initial = AUTH_PROBE_TRACK_MAX_ENTRIES + 32;

    for idx in 0..initial {
        let ip = IpAddr::V4(Ipv4Addr::new(
            10,
            6,
            ((idx >> 8) & 0xff) as u8,
            (idx & 0xff) as u8,
        ));
        state.insert(
            ip,
            AuthProbeState {
                fail_streak: 1,
                blocked_until: now,
                last_seen: now + Duration::from_millis((idx % 1024) as u64),
            },
        );
    }

    let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 114, 77));
    auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_secs(1));

    assert!(
        state.get(&newcomer).is_some(),
        "new probe source must still be tracked even when map starts above hard cap"
    );
    assert!(
        state.len() < initial + 1,
        "round-limited eviction path must still reclaim capacity under over-cap churn"
    );
}

// Eviction priority: low fail-streak entries must be sacrificed before
// high fail-streak (heavily penalized) entries, even older ones.
#[test]
fn auth_probe_capacity_prefers_evicting_low_fail_streak_entries_first() {
    let _guard = auth_probe_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_auth_probe_state_for_testing();

    let state = DashMap::new();
    let now = Instant::now();

    // Fill map at capacity with mostly high fail streak entries.
+ for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 20, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 9, + blocked_until: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), + }, + ); + } + + let low_fail = IpAddr::V4(Ipv4Addr::new(172, 21, 0, 1)); + state.insert( + low_fail, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now + Duration::from_secs(30), + }, + ); + + let high_fail_old = IpAddr::V4(Ipv4Addr::new(172, 21, 0, 2)); + state.insert( + high_fail_old, + AuthProbeState { + fail_streak: 12, + blocked_until: now, + last_seen: now - Duration::from_secs(10), + }, + ); + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 201)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!(state.get(&newcomer).is_some(), "new source must be tracked"); + assert!( + state.get(&low_fail).is_none(), + "least-penalized entry should be evicted before high-penalty entries" + ); + assert!( + state.get(&high_fail_old).is_some(), + "high fail-streak entry should be preserved under mixed-priority eviction" + ); +} + +#[test] +fn auth_probe_capacity_tie_breaker_evicts_oldest_with_equal_fail_streak() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES - 2) { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 30, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 5, + blocked_until: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), + }, + ); + } + + let oldest = IpAddr::V4(Ipv4Addr::new(172, 31, 0, 1)); + let newer = IpAddr::V4(Ipv4Addr::new(172, 31, 0, 2)); + state.insert( + oldest, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: 
now - Duration::from_secs(20), + }, + ); + state.insert( + newer, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now - Duration::from_secs(5), + }, + ); + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 202)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!(state.get(&newcomer).is_some(), "new source must be tracked"); + assert!( + state.get(&oldest).is_none(), + "among equal fail streak candidates, oldest entry must be evicted" + ); + assert!( + state.get(&newer).is_some(), + "newer equal-priority entry should be retained" + ); +} + +#[test] +fn stress_auth_probe_capacity_churn_preserves_high_fail_sentinels() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let base_now = Instant::now(); + + let sentinel_a = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 250)); + let sentinel_b = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 251)); + + state.insert( + sentinel_a, + AuthProbeState { + fail_streak: 20, + blocked_until: base_now, + last_seen: base_now - Duration::from_secs(30), + }, + ); + state.insert( + sentinel_b, + AuthProbeState { + fail_streak: 21, + blocked_until: base_now, + last_seen: base_now - Duration::from_secs(31), + }, + ); + + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES - 2) { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 4, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: base_now, + last_seen: base_now + Duration::from_millis((idx % 1024) as u64), + }, + ); + } + + for step in 0..1024usize { + let newcomer = IpAddr::V4(Ipv4Addr::new( + 203, + 1, + ((step >> 8) & 0xff) as u8, + (step & 0xff) as u8, + )); + let now = base_now + Duration::from_millis(10_000 + step as u64); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert_eq!( + state.len(), + AUTH_PROBE_TRACK_MAX_ENTRIES, + 
+            "auth probe map must remain hard-bounded at capacity"
+        );
+        assert!(
+            state.get(&sentinel_a).is_some() && state.get(&sentinel_b).is_some(),
+            "high fail-streak sentinels should survive low-streak newcomer churn"
+        );
+    }
+}
+
+// Two IPv6 sources inside the same /64 must normalize to one throttle bucket
+// whose fail streak accumulates across both addresses.
+#[test]
+fn auth_probe_ipv6_is_bucketed_by_prefix_64() {
+    let state = DashMap::new();
+    let now = Instant::now();
+
+    // Same 2001:db8:abcd:1234::/64, different interface identifiers.
+    let ip_a = IpAddr::V6("2001:db8:abcd:1234:1:2:3:4".parse().unwrap());
+    let ip_b = IpAddr::V6("2001:db8:abcd:1234:ffff:eeee:dddd:cccc".parse().unwrap());
+
+    auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_a), now);
+    auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_b), now);
+
+    let normalized = normalize_auth_probe_ip(ip_a);
+    assert_eq!(
+        state.len(),
+        1,
+        "IPv6 sources in the same /64 must share one pre-auth throttle bucket"
+    );
+    assert_eq!(
+        state.get(&normalized).map(|entry| entry.fail_streak),
+        Some(2),
+        "failures from the same /64 must accumulate in one throttle state"
+    );
+}
+
+// Distinct /64 prefixes must get fully independent buckets (no cross-prefix
+// penalty bleed).
+#[test]
+fn auth_probe_ipv6_different_prefixes_use_distinct_buckets() {
+    let state = DashMap::new();
+    let now = Instant::now();
+
+    // Prefixes differ in the fourth hextet (2222 vs 3333).
+    let ip_a = IpAddr::V6("2001:db8:1111:2222:1:2:3:4".parse().unwrap());
+    let ip_b = IpAddr::V6("2001:db8:1111:3333:1:2:3:4".parse().unwrap());
+
+    auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_a), now);
+    auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_b), now);
+
+    assert_eq!(
+        state.len(),
+        2,
+        "different IPv6 /64 prefixes must not share throttle buckets"
+    );
+    assert_eq!(
+        state
+            .get(&normalize_auth_probe_ip(ip_a))
+            .map(|entry| entry.fail_streak),
+        Some(1)
+    );
+    assert_eq!(
+        state
+            .get(&normalize_auth_probe_ip(ip_b))
+            .map(|entry| entry.fail_streak),
+        Some(1)
+    );
+}
+
+// A successful auth from any address in a /64 clears the prefix's shared
+// bucket, even when the earlier failure came from a different address in it.
+#[test]
+fn auth_probe_success_clears_whole_ipv6_prefix_bucket() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let now = Instant::now();
+    // Same /64, different hosts: failure recorded from one, success from the other.
+    let ip_fail = IpAddr::V6("2001:db8:aaaa:bbbb:1:2:3:4".parse().unwrap());
+    let ip_success = IpAddr::V6("2001:db8:aaaa:bbbb:ffff:eeee:dddd:cccc".parse().unwrap());
+
+    auth_probe_record_failure(ip_fail, now);
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(ip_fail),
+        Some(1),
+        "precondition: normalized prefix bucket must exist"
+    );
+
+    auth_probe_record_success(ip_success);
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(ip_fail),
+        None,
+        "success from the same /64 must clear the shared bucket"
+    );
+}
+
+// Eviction offset must be deterministic for a fixed (ip, now) pair yet differ
+// across peer IPs.
+// NOTE(review): the assert_ne! relies on two specific IPs not colliding under
+// the (presumably keyed) hash — a per-process random key could make this flaky
+// on rare seeds; confirm the hasher's seeding strategy.
+#[test]
+fn auth_probe_eviction_offset_varies_with_input() {
+    let now = Instant::now();
+    let ip1 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 10));
+    let ip2 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 11));
+
+    let a = auth_probe_eviction_offset(ip1, now);
+    let b = auth_probe_eviction_offset(ip1, now);
+    let c = auth_probe_eviction_offset(ip2, now);
+
+    assert_eq!(a, b, "same input must yield deterministic offset");
+    assert_ne!(a, c, "different peer IPs should not collapse to one offset");
+}
+
+// The offset must also mix in the timestamp: same IP, 1 ms apart, must not
+// map to the same offset.
+#[test]
+fn auth_probe_eviction_offset_changes_with_time_component() {
+    let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 77));
+    let now = Instant::now();
+    let later = now + Duration::from_millis(1);
+
+    let a = auth_probe_eviction_offset(ip, now);
+    let b = auth_probe_eviction_offset(ip, later);
+
+    assert_ne!(
+        a, b,
+        "eviction offset must incorporate timestamp entropy and not only peer IP"
+    );
+}
+
+// Over-cap (capacity + 64 entries) round-limited eviction: the newcomer is
+// still tracked, the high-streak sentinel survives, and the global saturation
+// throttle marker is activated.
+#[test]
+fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer_trackable() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let state = DashMap::new();
+    let now = Instant::now();
+    // Start well above the hard cap to force the round-limited eviction path.
+    let initial = AUTH_PROBE_TRACK_MAX_ENTRIES + 64;
+
+    let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 250));
+    state.insert(
+        sentinel,
+        AuthProbeState {
+            fail_streak: 25,
+            blocked_until: now,
+            last_seen: now - Duration::from_secs(30),
+        },
+    );
+
+    // Remaining over-cap population: low-threat filler entries.
+    for idx in 0..(initial - 1) {
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            10,
+            20,
+            ((idx >> 8) & 0xff) as u8,
+            (idx & 0xff) as u8,
+        ));
+        state.insert(
+            ip,
+            AuthProbeState {
+                fail_streak: 1,
+                blocked_until: now,
+                last_seen: now + Duration::from_millis((idx % 1024) as u64),
+            },
+        );
+    }
+
+    let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 40));
+    auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(1));
+
+    assert!(
+        state.get(&newcomer).is_some(),
+        "newcomer must still be tracked under over-cap pressure"
+    );
+    assert!(
+        state.get(&sentinel).is_some(),
+        "high fail-streak sentinel must survive round-limited eviction"
+    );
+    assert!(
+        auth_probe_saturation_is_throttled_at_for_testing(now + Duration::from_millis(1)),
+        "round-limited over-cap path must activate saturation throttle marker"
+    );
+}
+
+// A burst of truncated TLS records (only the 0x16 0x03 0x01 record prefix)
+// from one IP must each be rejected and drive that IP's fail streak up to at
+// least the backoff threshold.
+#[tokio::test]
+async fn gap_t01_short_tls_probe_burst_is_throttled() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let config = test_config_with_secret_hex("11111111111111111111111111111111");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap();
+
+    // TLS record header prefix only — far too short to be a ClientHello.
+    let too_short = vec![0x16, 0x03, 0x01];
+
+    for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
+        let result = handle_tls_handshake(
+            &too_short,
+            tokio::io::empty(),
+            tokio::io::sink(),
+            peer,
+            &config,
+            &replay_checker,
+            &rng,
+            None,
+        )
+        .await;
+        assert!(matches!(result, HandshakeResult::BadClient { .. }));
+    }
+
+    assert!(
+        auth_probe_fail_streak_for_testing(peer.ip())
+            .is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
+        "short TLS probe bursts must increase auth-probe fail streak"
+    );
+}
+
+// Stress: start 80 entries over capacity, then push 512 newcomers through;
+// the single very-high-streak sentinel must survive every step and each
+// newcomer must still be tracked.
+#[test]
+fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let state = DashMap::new();
+    let base_now = Instant::now();
+
+    let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 200));
+    state.insert(
+        sentinel,
+        AuthProbeState {
+            fail_streak: 30,
+            blocked_until: base_now,
+            last_seen: base_now - Duration::from_secs(60),
+        },
+    );
+
+    // Over-fill by 80 entries of low-threat filler.
+    for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 80) {
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            172,
+            22,
+            ((idx >> 8) & 0xff) as u8,
+            (idx & 0xff) as u8,
+        ));
+        state.insert(
+            ip,
+            AuthProbeState {
+                fail_streak: 1,
+                blocked_until: base_now,
+                last_seen: base_now + Duration::from_millis((idx % 2048) as u64),
+            },
+        );
+    }
+
+    for step in 0..512usize {
+        let newcomer = IpAddr::V4(Ipv4Addr::new(
+            203,
+            2,
+            ((step >> 8) & 0xff) as u8,
+            (step & 0xff) as u8,
+        ));
+        auth_probe_record_failure_with_state(
+            &state,
+            newcomer,
+            base_now + Duration::from_millis(step as u64 + 1),
+        );
+
+        assert!(
+            state.get(&sentinel).is_some(),
+            "step {step}: high-threat sentinel must not be starved by newcomer churn"
+        );
+        assert!(
+            state.get(&newcomer).is_some(),
+            "step {step}: newcomer must be tracked"
+        );
+    }
+}
+
+// Light fuzz: 128 rounds, each with a fresh full map of xorshift-randomized
+// low-threat entries plus one high-streak sentinel; eviction must always
+// prefer a low-threat entry over the sentinel.
+#[test]
+fn light_fuzz_auth_probe_overcap_eviction_prefers_less_threatening_entries() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let now = Instant::now();
+    // Fixed-seed xorshift state so the fuzz input is reproducible.
+    let mut s: u64 = 0xBADC_0FFE_EE11_2233;
+
+    for round in 0..128usize {
+        let state = DashMap::new();
+        let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 180));
+        state.insert(
+            sentinel,
+            AuthProbeState {
+                fail_streak: 18,
+                blocked_until: now,
+                last_seen: now - Duration::from_secs(5),
+            },
+        );
+
+        // Fill to capacity with pseudo-random low-threat entries.
+        for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES {
+            // xorshift step: cheap deterministic pseudo-randomness.
+            s ^= s << 7;
+            s ^= s >> 9;
+            s ^= s << 8;
+            let ip = IpAddr::V4(Ipv4Addr::new(
+                10,
+                ((idx >> 8) & 0xff) as u8,
+                (idx & 0xff) as u8,
+                (s & 0xff) as u8,
+            ));
+            state.insert(
+                ip,
+                AuthProbeState {
+                    fail_streak: 1,
+                    blocked_until: now,
+                    last_seen: now + Duration::from_millis((s & 1023) as u64),
+                },
+            );
+        }
+
+        let newcomer = IpAddr::V4(Ipv4Addr::new(
+            203,
+            10,
+            ((round >> 8) & 0xff) as u8,
+            (round & 0xff) as u8,
+        ));
+        auth_probe_record_failure_with_state(
+            &state,
+            newcomer,
+            now + Duration::from_millis(round as u64 + 1),
+        );
+
+        assert!(
+            state.get(&newcomer).is_some(),
+            "round {round}: newcomer should be tracked"
+        );
+        assert!(
+            state.get(&sentinel).is_some(),
+            "round {round}: high fail-streak sentinel should survive mixed low-threat pool"
+        );
+    }
+}
+
+// 4096 seeded-random (ip, timestamp) pairs: computing the eviction offset
+// twice for the same pair must always give the same value.
+#[test]
+fn light_fuzz_auth_probe_eviction_offset_is_deterministic_per_input_pair() {
+    let mut rng = StdRng::seed_from_u64(0xA11CE5EED);
+    let base = Instant::now();
+
+    for _ in 0..4096usize {
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            rng.random(),
+            rng.random(),
+            rng.random(),
+            rng.random(),
+        ));
+        let offset_ns = rng.random_range(0_u64..2_000_000);
+        let when = base + Duration::from_nanos(offset_ns);
+
+        let first = auth_probe_eviction_offset(ip, when);
+        let second = auth_probe_eviction_offset(ip, when);
+        assert_eq!(
+            first, second,
+            "eviction offset must be stable for identical (ip, now) pairs"
+        );
+    }
+}
+
+// Sequential attacker-style IPs must spread across eviction buckets: at least
+// half the buckets hit, and max/min occupancy skew bounded. Thresholds are
+// deliberately lenient to keep the statistical check non-flaky.
+#[test]
+fn adversarial_eviction_offset_spread_avoids_single_bucket_collapse() {
+    let modulus = AUTH_PROBE_TRACK_MAX_ENTRIES;
+    let mut bucket_hits = vec![0usize; modulus];
+    let now = Instant::now();
+
+    for idx in 0..8192usize {
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            100,
+            ((idx >> 8) & 0xff) as u8,
+            (idx & 0xff) as u8,
+            ((idx.wrapping_mul(37)) & 0xff) as u8,
+        ));
+        let bucket = auth_probe_eviction_offset(ip, now) % modulus;
+        bucket_hits[bucket] += 1;
+    }
+
+    let non_empty_buckets = bucket_hits.iter().filter(|&&hits| hits > 0).count();
+    assert!(
+        non_empty_buckets >= modulus / 2,
+        "adversarial sequential input should cover a broad bucket set (covered {non_empty_buckets}/{modulus})"
+    );
+
+    let max_hits = bucket_hits.iter().copied().max().unwrap_or(0);
+    let min_non_zero_hits = bucket_hits
+        .iter()
+        .copied()
+        .filter(|&hits| hits > 0)
+        .min()
+        .unwrap_or(0);
+    assert!(
+        max_hits <= min_non_zero_hits.saturating_mul(32).max(1),
+        "bucket skew is unexpectedly extreme for keyed hasher spread (max={max_hits}, min_non_zero={min_non_zero_hits})"
+    );
+}
+
+// 50k distinct IPs at one timestamp must yield at least 40k distinct offsets
+// (i.e. no pathological collision collapse in the keyed hash).
+#[test]
+fn stress_auth_probe_eviction_offset_high_volume_uniqueness_sanity() {
+    let now = Instant::now();
+    let mut seen = std::collections::HashSet::new();
+
+    for idx in 0..50_000usize {
+        let ip = IpAddr::V4(Ipv4Addr::new(
+            198,
+            ((idx >> 16) & 0xff) as u8,
+            ((idx >> 8) & 0xff) as u8,
+            (idx & 0xff) as u8,
+        ));
+        seen.insert(auth_probe_eviction_offset(ip, now));
+    }
+
+    assert!(
+        seen.len() >= 40_000,
+        "high-volume eviction offsets should not collapse excessively under keyed hashing"
+    );
+}
+
+// 128 tasks released simultaneously by a Barrier each record one failure for
+// the same IP; no increment may be lost (final streak == task count).
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let peer_ip: IpAddr = "198.51.100.90".parse().unwrap();
+    let tasks = 128usize;
+    let barrier = Arc::new(Barrier::new(tasks));
+    let mut handles = Vec::with_capacity(tasks);
+
+    for _ in 0..tasks {
+        let barrier = barrier.clone();
+        handles.push(tokio::spawn(async move {
+            // Maximize contention: all tasks record at (roughly) the same moment.
+            barrier.wait().await;
+            auth_probe_record_failure(peer_ip, Instant::now());
+        }));
+    }
+
+    for handle in handles {
+        handle
+            .await
+            .expect("concurrent failure recording task must not panic");
+    }
+
+    let streak =
+        auth_probe_fail_streak_for_testing(peer_ip)
+        .expect("tracked peer must exist after concurrent failure burst");
+    assert_eq!(
+        streak as usize, tasks,
+        "concurrent failures for one source must account every attempt"
+    );
+}
+
+// Isolation under fire: 96 concurrent invalid probes from unrelated IPs must
+// neither block a simultaneous valid handshake from the victim IP nor leave a
+// failure streak on it.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x31u8; 16];
+    let config = Arc::new(test_config_with_secret_hex(
+        "31313131313131313131313131313131",
+    ));
+    let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
+    let rng = Arc::new(SecureRandom::new());
+    let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap();
+    let valid = Arc::new(make_valid_tls_handshake(&secret, 0));
+
+    // Right length for a ClientHello, wrong digest content — guaranteed reject.
+    let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
+    invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
+    let invalid = Arc::new(invalid);
+
+    let mut noise_tasks = Vec::new();
+    for idx in 0..96u16 {
+        let config = config.clone();
+        let replay_checker = replay_checker.clone();
+        let rng = rng.clone();
+        let invalid = invalid.clone();
+        noise_tasks.push(tokio::spawn(async move {
+            // Spread noise across 203.0.113.1..=203.0.113.200 — never the victim IP.
+            let octet = ((idx % 200) + 1) as u8;
+            let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, octet)), 45000 + idx);
+            let result = handle_tls_handshake(
+                &invalid,
+                tokio::io::empty(),
+                tokio::io::sink(),
+                peer,
+                &config,
+                &replay_checker,
+                &rng,
+                None,
+            )
+            .await;
+            assert!(matches!(result, HandshakeResult::BadClient { .. }));
+        }));
+    }
+
+    let victim_config = config.clone();
+    let victim_replay_checker = replay_checker.clone();
+    let victim_rng = rng.clone();
+    let victim_valid = valid.clone();
+    let victim_task = tokio::spawn(async move {
+        handle_tls_handshake(
+            &victim_valid,
+            tokio::io::empty(),
+            tokio::io::sink(),
+            victim_peer,
+            &victim_config,
+            &victim_replay_checker,
+            &victim_rng,
+            None,
+        )
+        .await
+    });
+
+    for task in noise_tasks {
+        task.await.expect("noise task must not panic");
+    }
+
+    let victim_result = victim_task
+        .await
+        .expect("victim handshake task must not panic");
+    assert!(
+        matches!(victim_result, HandshakeResult::Success(_)),
+        "invalid probe noise from other IPs must not block a valid victim handshake"
+    );
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(victim_peer.ip()),
+        None,
+        "successful victim handshake must not retain pre-auth failure streak"
+    );
+}
+
+// A saturation marker whose last_seen is past the retention window must stop
+// throttling and be removed when queried.
+#[test]
+fn auth_probe_saturation_state_expires_after_retention_window() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let now = Instant::now();
+    let saturation = auth_probe_saturation_state();
+    {
+        let mut guard = saturation
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        // blocked_until is still in the future, but last_seen is expired —
+        // retention must win.
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(30),
+            last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1),
+        });
+    }
+
+    assert!(
+        !auth_probe_saturation_is_throttled_for_testing(),
+        "expired saturation state must stop throttling and self-clear"
+    );
+
+    let guard = saturation
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    assert!(guard.is_none(), "expired saturation state must be removed");
+}
+
+// An active global saturation marker alone (no per-IP state) must not reject
+// a valid authenticated TLS handshake.
+#[tokio::test]
+async fn global_saturation_marker_does_not_block_valid_tls_handshake() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned|
+            poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x41u8; 16];
+    let config = test_config_with_secret_hex("41414141414141414141414141414141");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "198.51.100.101:45101".parse().unwrap();
+
+    // Plant an active (non-expired) global saturation marker.
+    let now = Instant::now();
+    let saturation = auth_probe_saturation_state();
+    {
+        let mut guard = saturation
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now,
+        });
+    }
+
+    let valid = make_valid_tls_handshake(&secret, 0);
+    let result = handle_tls_handshake(
+        &valid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(
+        matches!(result, HandshakeResult::Success(_)),
+        "global saturation marker must not block valid authenticated TLS handshakes"
+    );
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(peer.ip()),
+        None,
+        "successful handshake under saturation marker must not retain per-ip probe failures"
+    );
+}
+
+// A saturation marker whose last_seen lies beyond the retention window is
+// dead state: a valid handshake must go through.
+#[tokio::test]
+async fn expired_global_saturation_allows_valid_tls_handshake() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x55u8; 16];
+    let config = test_config_with_secret_hex("55555555555555555555555555555555");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "198.51.100.102:45102".parse().unwrap();
+
+    let now = Instant::now();
+    let saturation = auth_probe_saturation_state();
+    {
+        let mut guard = saturation
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        // blocked_until in the future but last_seen past retention → expired.
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1),
+        });
+    }
+
+    let valid = make_valid_tls_handshake(&secret, 0);
+    let result = handle_tls_handshake(
+        &valid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(
+        matches!(result, HandshakeResult::Success(_)),
+        "expired saturation marker must not block valid handshake"
+    );
+}
+
+// Without a saturation marker, an active per-IP pre-auth throttle rejects
+// even a cryptographically valid TLS handshake from that IP.
+#[tokio::test]
+async fn valid_tls_is_blocked_by_per_ip_preauth_throttle_without_saturation() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x61u8; 16];
+    let config = test_config_with_secret_hex("61616161616161616161616161616161");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "198.51.100.103:45103".parse().unwrap();
+
+    // Pre-seed the peer as throttled (streak at threshold, block in the future).
+    auth_probe_state_map().insert(
+        normalize_auth_probe_ip(peer.ip()),
+        AuthProbeState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: Instant::now() + Duration::from_secs(5),
+            last_seen: Instant::now(),
+        },
+    );
+
+    let valid = make_valid_tls_handshake(&secret, 0);
+    let result = handle_tls_handshake(
+        &valid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(matches!(result, HandshakeResult::BadClient { ..
+    }));
+}
+
+// Under global saturation the policy flips: a valid handshake is admitted
+// even though the peer IP is currently per-IP throttled, and success clears
+// the peer's throttle state.
+#[tokio::test]
+async fn saturation_allows_valid_tls_even_when_peer_ip_is_currently_throttled() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x62u8; 16];
+    let config = test_config_with_secret_hex("62626262626262626262626262626262");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "198.51.100.104:45104".parse().unwrap();
+    let now = Instant::now();
+
+    // Both an active per-IP throttle and an active saturation marker.
+    auth_probe_state_map().insert(
+        normalize_auth_probe_ip(peer.ip()),
+        AuthProbeState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now,
+        },
+    );
+    {
+        let mut guard = auth_probe_saturation_state()
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now,
+        });
+    }
+
+    let valid = make_valid_tls_handshake(&secret, 0);
+    let result = handle_tls_handshake(
+        &valid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(matches!(result, HandshakeResult::Success(_)));
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(peer.ip()),
+        None,
+        "successful auth under saturation must clear the peer's throttled state"
+    );
+}
+
+// Saturation is not an amnesty for attackers: an invalid TLS probe is still
+// rejected and still increments the per-IP failure streak.
+#[tokio::test]
+async fn saturation_still_rejects_invalid_tls_probe_and_records_failure() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let config = test_config_with_secret_hex("63636363636363636363636363636363");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "198.51.100.105:45105".parse().unwrap();
+    let now = Instant::now();
+    {
+        let mut guard = auth_probe_saturation_state()
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now,
+        });
+    }
+
+    // ClientHello-shaped buffer with a garbage digest — must fail auth.
+    let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
+    invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
+
+    let result = handle_tls_handshake(
+        &invalid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(matches!(result, HandshakeResult::BadClient { .. }));
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(peer.ip()),
+        Some(1),
+        "invalid TLS during saturation must still increment per-ip failure tracking"
+    );
+}
+
+// Once a peer has spent its full saturation grace budget, further invalid TLS
+// probes are rejected pre-auth, with the streak frozen at the cap.
+#[tokio::test]
+async fn saturation_grace_exhaustion_preauth_throttles_repeated_invalid_tls_probe() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let config = test_config_with_secret_hex("63636363636363636363636363636363");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "198.51.100.205:45205".parse().unwrap();
+    let now = Instant::now();
+    // Streak pre-seeded at the grace cap: backoff-start + grace fails.
+    auth_probe_state_map().insert(
+        normalize_auth_probe_ip(peer.ip()),
+        AuthProbeState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS,
+            blocked_until: now + Duration::from_secs(1),
+            last_seen: now,
+        },
+    );
+    {
+        let mut guard = auth_probe_saturation_state()
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(1),
+            last_seen: now,
+        });
+    }
+
+    let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
+    invalid[tls::TLS_DIGEST_POS +
+        tls::TLS_DIGEST_LEN] = 32;
+
+    let result = handle_tls_handshake(
+        &invalid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+
+    assert!(matches!(result, HandshakeResult::BadClient { .. }));
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(peer.ip()),
+        Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS),
+        "pre-auth throttle under exhausted saturation grace must reject without re-processing invalid TLS"
+    );
+}
+
+// MTProto counterpart of the saturation-override rule: a valid secure-mode
+// MTProto handshake is admitted under saturation even with the peer IP
+// currently throttled, and success clears its throttle state.
+#[tokio::test]
+async fn saturation_allows_valid_mtproto_even_when_peer_ip_is_currently_throttled() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let secret_hex = "64646464646464646464646464646464";
+    let mut config = test_config_with_secret_hex(secret_hex);
+    // The handshake below is built with ProtoTag::Secure, so enable secure mode.
+    config.general.modes.secure = true;
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let peer: SocketAddr = "198.51.100.106:45106".parse().unwrap();
+    let now = Instant::now();
+
+    // Active per-IP throttle plus active saturation marker.
+    auth_probe_state_map().insert(
+        normalize_auth_probe_ip(peer.ip()),
+        AuthProbeState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now,
+        },
+    );
+    {
+        let mut guard = auth_probe_saturation_state()
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now,
+        });
+    }
+
+    let valid = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
+    let result = handle_mtproto_handshake(
+        &valid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        false,
+        None,
+    )
+    .await;
+
+    assert!(matches!(result, HandshakeResult::Success(_)));
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(peer.ip()),
+        None,
+        "successful mtproto auth under saturation must clear the peer's throttled state"
+    );
+}
+
+// Under saturation an invalid MTProto probe (all-zero handshake) is still
+// rejected and still recorded against the source IP.
+#[tokio::test]
+async fn saturation_still_rejects_invalid_mtproto_probe_and_records_failure() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let config = test_config_with_secret_hex("65656565656565656565656565656565");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let peer: SocketAddr = "198.51.100.107:45107".parse().unwrap();
+    let now = Instant::now();
+    {
+        let mut guard = auth_probe_saturation_state()
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(5),
+            last_seen: now,
+        });
+    }
+
+    // All-zero bytes of handshake length: structurally sized, cryptographically invalid.
+    let invalid = [0u8; HANDSHAKE_LEN];
+
+    let result = handle_mtproto_handshake(
+        &invalid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        false,
+        None,
+    )
+    .await;
+
+    assert!(matches!(result, HandshakeResult::BadClient { ..
+    }));
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(peer.ip()),
+        Some(1),
+        "invalid mtproto during saturation must still increment per-ip failure tracking"
+    );
+}
+
+// MTProto counterpart of grace exhaustion: with the streak already at the
+// grace cap, a further invalid probe is rejected pre-auth and the streak
+// stays frozen at the cap.
+#[tokio::test]
+async fn saturation_grace_exhaustion_preauth_throttles_repeated_invalid_mtproto_probe() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let config = test_config_with_secret_hex("65656565656565656565656565656565");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let peer: SocketAddr = "198.51.100.206:45206".parse().unwrap();
+    let now = Instant::now();
+    // Streak pre-seeded at the grace cap.
+    auth_probe_state_map().insert(
+        normalize_auth_probe_ip(peer.ip()),
+        AuthProbeState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS,
+            blocked_until: now + Duration::from_secs(1),
+            last_seen: now,
+        },
+    );
+    {
+        let mut guard = auth_probe_saturation_state()
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(1),
+            last_seen: now,
+        });
+    }
+
+    let invalid = [0u8; HANDSHAKE_LEN];
+    let result = handle_mtproto_handshake(
+        &invalid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        false,
+        None,
+    )
+    .await;
+
+    assert!(matches!(result, HandshakeResult::BadClient { .. }));
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(peer.ip()),
+        Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS),
+        "pre-auth throttle under exhausted saturation grace must reject without re-processing invalid MTProto"
+    );
+}
+
+// Grace-budget progression (TLS): each invalid probe under saturation grows
+// the streak toward the cap; once capped, further probes must not grow it.
+#[tokio::test]
+async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let config = test_config_with_secret_hex("70707070707070707070707070707070");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let peer: SocketAddr = "198.51.100.207:45207".parse().unwrap();
+    let now = Instant::now();
+    // Start at the backoff threshold with grace still unspent.
+    auth_probe_state_map().insert(
+        normalize_auth_probe_ip(peer.ip()),
+        AuthProbeState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(1),
+            last_seen: now,
+        },
+    );
+    {
+        let mut guard = auth_probe_saturation_state()
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(1),
+            last_seen: now,
+        });
+    }
+
+    let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
+    invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
+
+    // NOTE(review): the two expected streak values imply
+    // AUTH_PROBE_SATURATION_GRACE_FAILS == 2 for this loop to cover each
+    // increment exactly — confirm against the constant's definition.
+    for expected in [
+        AUTH_PROBE_BACKOFF_START_FAILS + 1,
+        AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS,
+    ] {
+        let result = handle_tls_handshake(
+            &invalid,
+            tokio::io::empty(),
+            tokio::io::sink(),
+            peer,
+            &config,
+            &replay_checker,
+            &rng,
+            None,
+        )
+        .await;
+        assert!(matches!(result, HandshakeResult::BadClient { ..
+        }));
+        assert_eq!(
+            auth_probe_fail_streak_for_testing(peer.ip()),
+            Some(expected)
+        );
+    }
+
+    // Re-pin the entry at the cap with a fresh block window before rechecking
+    // that no further increments happen.
+    {
+        let mut entry = auth_probe_state_map()
+            .get_mut(&normalize_auth_probe_ip(peer.ip()))
+            .expect("peer state must exist before exhaustion recheck");
+        entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS;
+        entry.blocked_until = Instant::now() + Duration::from_secs(1);
+        entry.last_seen = Instant::now();
+    }
+
+    let result = handle_tls_handshake(
+        &invalid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+    )
+    .await;
+    assert!(matches!(result, HandshakeResult::BadClient { .. }));
+    assert_eq!(
+        auth_probe_fail_streak_for_testing(peer.ip()),
+        Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS),
+        "once grace is exhausted, repeated invalid TLS must be pre-auth throttled without further fail-streak growth"
+    );
+}
+
+// Grace-budget progression (MTProto): mirror of the TLS progression test
+// above, using an all-zero MTProto handshake as the invalid probe.
+#[tokio::test]
+async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementing() {
+    let _guard = auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner());
+    clear_auth_probe_state_for_testing();
+
+    let config = test_config_with_secret_hex("71717171717171717171717171717171");
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let peer: SocketAddr = "198.51.100.208:45208".parse().unwrap();
+    let now = Instant::now();
+    // Start at the backoff threshold with grace still unspent.
+    auth_probe_state_map().insert(
+        normalize_auth_probe_ip(peer.ip()),
+        AuthProbeState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(1),
+            last_seen: now,
+        },
+    );
+    {
+        let mut guard = auth_probe_saturation_state()
+            .lock()
+            .unwrap_or_else(|poisoned| poisoned.into_inner());
+        *guard = Some(AuthProbeSaturationState {
+            fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
+            blocked_until: now + Duration::from_secs(1),
+            last_seen: now,
+        });
+    }
+
+    let invalid = [0u8; HANDSHAKE_LEN];
+
+    for expected in [
+        AUTH_PROBE_BACKOFF_START_FAILS + 1,
+        AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS,
+    ] {
+        let result = handle_mtproto_handshake(
+            &invalid,
+            tokio::io::empty(),
+            tokio::io::sink(),
+            peer,
+            &config,
+            &replay_checker,
+            false,
+            None,
+        )
+        .await;
+        assert!(matches!(result, HandshakeResult::BadClient { .. }));
+        assert_eq!(
+            auth_probe_fail_streak_for_testing(peer.ip()),
+            Some(expected)
+        );
+    }
+
+    // Re-pin at the cap with a fresh block window, then verify the streak is frozen.
+    {
+        let mut entry = auth_probe_state_map()
+            .get_mut(&normalize_auth_probe_ip(peer.ip()))
+            .expect("peer state must exist before exhaustion recheck");
+        entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS;
+        entry.blocked_until = Instant::now() + Duration::from_secs(1);
+        entry.last_seen = Instant::now();
+    }
+
+    let result = handle_mtproto_handshake(
+        &invalid,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        peer,
+        &config,
+        &replay_checker,
+        false,
+        None,
+    )
+    .await;
+    assert!(matches!(result, HandshakeResult::BadClient { ..
})); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "once grace is exhausted, repeated invalid MTProto must be pre-auth throttled without further fail-streak growth" + ); +} + +#[tokio::test] +async fn saturation_grace_boundary_still_admits_valid_tls_before_exhaustion() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x72u8; 16]; + let config = test_config_with_secret_hex("72727272727272727272727272727272"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.209:45209".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::Success(_)), + "valid TLS should still pass while peer remains within saturation grace budget" + ); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), None); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_blocks_valid_tls_until_backoff_expires() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + 
clear_auth_probe_state_for_testing(); + + let secret = [0x73u8; 16]; + let config = test_config_with_secret_hex("73737373737373737373737373737373"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.210:45210".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_millis(200), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let blocked = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(blocked, HandshakeResult::BadClient { .. 
})); + + tokio::time::sleep(Duration::from_millis(230)).await; + + let allowed = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(allowed, HandshakeResult::Success(_)), + "valid TLS should recover after peer-specific pre-auth backoff has elapsed" + ); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), None); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_is_shared_across_tls_and_mtproto_for_same_peer() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("74747474747474747474747474747474"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.211:45211".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_mtproto = [0u8; HANDSHAKE_LEN]; + + let tls_result = handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(tls_result, HandshakeResult::BadClient { .. 
})); + + let mtproto_result = handle_mtproto_handshake( + &invalid_mtproto, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(mtproto_result, HandshakeResult::BadClient { .. })); + + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "saturation grace exhaustion must gate both TLS and MTProto pre-auth paths for one peer" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_same_peer_invalid_tls_storm_does_not_bypass_saturation_grace_cap() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = Arc::new(test_config_with_secret_hex( + "75757575757575757575757575757575", + )); + let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let peer: SocketAddr = "198.51.100.212:45212".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_tls = Arc::new(invalid_tls); + + let mut tasks = Vec::new(); + for _ in 0..64usize { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); 
+ let invalid_tls = invalid_tls.clone(); + tasks.push(tokio::spawn(async move { + handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + for task in tasks { + let result = task.await.unwrap(); + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "same-peer invalid storm under exhausted grace must stay pre-auth throttled without fail-streak growth" + ); +} + +#[tokio::test] +async fn light_fuzz_saturation_grace_tls_invalid_inputs_never_authenticate_or_panic() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("76767676767676767676767676767676"); + let replay_checker = ReplayChecker::new(2048, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.213:45213".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut seeded = StdRng::seed_from_u64(0xD15EA5E5_u64); + for _ in 0..128usize { + let len = seeded.random_range(0usize..96usize); + let mut probe = vec![0u8; len]; + seeded.fill(&mut probe[..]); + + let result = handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + 
&replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + let streak = auth_probe_fail_streak_for_testing(peer.ip()) + .expect("peer should remain tracked after repeated invalid fuzz probes"); + assert!( + streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + "fuzzed invalid TLS probes under saturation must not reduce fail-streak below exhaustion threshold" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshakes() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret_hex = "66666666666666666666666666666666"; + let secret = [0x66u8; 16]; + let mut cfg = test_config_with_secret_hex(secret_hex); + cfg.general.modes.secure = true; + let config = Arc::new(cfg); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let now = Instant::now(); + + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid_tls = Arc::new(make_valid_tls_handshake(&secret, 0)); + let valid_mtproto = Arc::new(make_valid_mtproto_handshake( + secret_hex, + ProtoTag::Secure, + 3, + )); + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_tls = Arc::new(invalid_tls); + + let mut invalid_tls_tasks = Vec::new(); + for idx in 0..48u16 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let invalid_tls = invalid_tls.clone(); + 
invalid_tls_tasks.push(tokio::spawn(async move { + let octet = ((idx % 200) + 1) as u8; + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, octet)), 46000 + idx); + handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let valid_tls_task = { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let valid_tls = valid_tls.clone(); + tokio::spawn(async move { + handle_tls_handshake( + &valid_tls, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.108:45108".parse().unwrap(), + &config, + &replay_checker, + &rng, + None, + ) + .await + }) + }; + + let valid_mtproto_task = { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let valid_mtproto = valid_mtproto.clone(); + tokio::spawn(async move { + handle_mtproto_handshake( + &valid_mtproto, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.109:45109".parse().unwrap(), + &config, + &replay_checker, + false, + None, + ) + .await + }) + }; + + let mut bad_clients = 0usize; + for task in invalid_tls_tasks { + match task.await.unwrap() { + HandshakeResult::BadClient { .. 
} => bad_clients += 1, + HandshakeResult::Success(_) => panic!("invalid TLS probe unexpectedly authenticated"), + HandshakeResult::Error(err) => { + panic!("unexpected error in invalid TLS saturation burst test: {err}") + } + } + } + + let valid_tls_result = valid_tls_task.await.unwrap(); + assert!( + matches!(valid_tls_result, HandshakeResult::Success(_)), + "valid TLS probe must authenticate during saturation burst" + ); + + let valid_mtproto_result = valid_mtproto_task.await.unwrap(); + assert!( + matches!(valid_mtproto_result, HandshakeResult::Success(_)), + "valid MTProto probe must authenticate during saturation burst" + ); + + assert_eq!( + bad_clients, 48, + "all invalid TLS probes in mixed saturation burst must be rejected" + ); +} + +#[tokio::test] +async fn expired_saturation_keeps_per_ip_throttle_enforced_for_valid_tls() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x67u8; 16]; + let config = test_config_with_secret_hex("67676767676767676767676767676767"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.110:45110".parse().unwrap(); + let now = Instant::now(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1), + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + 
peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::BadClient { .. }), + "expired saturation marker must not disable per-ip pre-auth throttle" + ); +} diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs new file mode 100644 index 0000000..3e860e8 --- /dev/null +++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs @@ -0,0 +1,577 @@ +use super::*; +use std::collections::BTreeSet; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +#[derive(Clone, Copy)] +enum PathClass { + ConnectFail, + ConnectSuccess, + SlowBackend, +} + +fn mean_ms(samples: &[u128]) -> f64 { + if samples.is_empty() { + return 0.0; + } + let sum: u128 = samples.iter().copied().sum(); + sum as f64 / samples.len() as f64 +} + +fn percentile_ms(mut values: Vec, p_num: usize, p_den: usize) -> u128 { + values.sort_unstable(); + if values.is_empty() { + return 0; + } + let idx = ((values.len() - 1) * p_num) / p_den; + values[idx] +} + +fn bucketize_ms(values: &[u128], bucket_ms: u128) -> Vec { + values.iter().map(|v| *v / bucket_ms).collect() +} + +fn best_threshold_accuracy_u128(a: &[u128], b: &[u128]) -> f64 { + let min_v = *a.iter().chain(b.iter()).min().unwrap(); + let max_v = *a.iter().chain(b.iter()).max().unwrap(); + + let mut best = 0.0f64; + for t in min_v..=max_v { + let correct_a = a.iter().filter(|&&x| x <= t).count(); + let correct_b = b.iter().filter(|&&x| x > t).count(); + let acc = (correct_a + correct_b) as f64 / (a.len() + b.len()) as f64; + if acc > best { + best = acc; + } + } + best +} + +fn spread_u128(values: &[u128]) -> u128 { + if values.is_empty() { + return 0; + } + let min_v = *values.iter().min().unwrap(); + let max_v = *values.iter().max().unwrap(); + max_v - min_v +} + +fn interval_gap_usize(a: &BTreeSet, b: &BTreeSet) -> usize { + if 
a.is_empty() || b.is_empty() { + return 0; + } + + let a_min = *a.iter().next().unwrap(); + let a_max = *a.iter().next_back().unwrap(); + let b_min = *b.iter().next().unwrap(); + let b_max = *b.iter().next_back().unwrap(); + + if a_max < b_min { + b_min - a_max + } else if b_max < a_min { + a_min - b_max + } else { + 0 + } +} + +async fn collect_timing_samples(path: PathClass, timing_norm_enabled: bool, n: usize) -> Vec { + let mut out = Vec::with_capacity(n); + for _ in 0..n { + out.push(measure_masking_duration_ms(path, timing_norm_enabled).await); + } + out +} + +async fn measure_masking_duration_ms(path: PathClass, timing_norm_enabled: bool) -> u128 { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_timing_normalization_enabled = timing_norm_enabled; + config.censorship.mask_timing_normalization_floor_ms = 220; + config.censorship.mask_timing_normalization_ceiling_ms = 260; + + let accept_task = match path { + PathClass::ConnectFail => { + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + None + } + PathClass::ConnectSuccess => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + })) + } + PathClass::SlowBackend => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + tokio::time::sleep(Duration::from_millis(320)).await; + })) + } + }; + + let 
(client_reader, _client_writer) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + + let peer: SocketAddr = "198.51.100.230:57230".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"GET /ab-harness HTTP/1.1\r\nHost: x\r\n\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + + if let Some(task) = accept_task { + let _ = tokio::time::timeout(Duration::from_secs(2), task).await; + } + + started.elapsed().as_millis() +} + +async fn capture_above_cap_forwarded_len( + body_sent: usize, + above_cap_blur_enabled: bool, + above_cap_blur_max_bytes: usize, +) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + config.censorship.mask_shape_above_cap_blur = above_cap_blur_enabled; + config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (client_reader, mut client_writer) = duplex(64 * 1024); + let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024); + + let peer: SocketAddr = "198.51.100.231:57231".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = 
BeobachtenStore::new(); + + let mut initial = vec![0u8; 5 + body_sent]; + initial[0] = 0x16; + initial[1] = 0x03; + initial[2] = 0x01; + initial[3..5].copy_from_slice(&7000u16.to_be_bytes()); + initial[5..].fill(0x5A); + + let fallback_task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(4), fallback_task) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn integration_ab_harness_envelope_and_blur_improve_obfuscation_vs_baseline() { + const ITER: usize = 8; + + let mut baseline_fail = Vec::with_capacity(ITER); + let mut baseline_success = Vec::with_capacity(ITER); + let mut baseline_slow = Vec::with_capacity(ITER); + + let mut hardened_fail = Vec::with_capacity(ITER); + let mut hardened_success = Vec::with_capacity(ITER); + let mut hardened_slow = Vec::with_capacity(ITER); + + for _ in 0..ITER { + baseline_fail.push(measure_masking_duration_ms(PathClass::ConnectFail, false).await); + baseline_success.push(measure_masking_duration_ms(PathClass::ConnectSuccess, false).await); + baseline_slow.push(measure_masking_duration_ms(PathClass::SlowBackend, false).await); + + hardened_fail.push(measure_masking_duration_ms(PathClass::ConnectFail, true).await); + hardened_success.push(measure_masking_duration_ms(PathClass::ConnectSuccess, true).await); + hardened_slow.push(measure_masking_duration_ms(PathClass::SlowBackend, true).await); + } + + let baseline_means = [ + mean_ms(&baseline_fail), + mean_ms(&baseline_success), + mean_ms(&baseline_slow), + ]; + let hardened_means = [ + mean_ms(&hardened_fail), + mean_ms(&hardened_success), + mean_ms(&hardened_slow), + ]; + + let baseline_range = baseline_means + .iter() + .copied() + .fold((f64::INFINITY, 
f64::NEG_INFINITY), |(mn, mx), v| { + (mn.min(v), mx.max(v)) + }); + let hardened_range = hardened_means + .iter() + .copied() + .fold((f64::INFINITY, f64::NEG_INFINITY), |(mn, mx), v| { + (mn.min(v), mx.max(v)) + }); + + let baseline_spread = baseline_range.1 - baseline_range.0; + let hardened_spread = hardened_range.1 - hardened_range.0; + + println!( + "ab_harness_timing baseline_means={:?} hardened_means={:?} baseline_spread={:.2} hardened_spread={:.2}", + baseline_means, hardened_means, baseline_spread, hardened_spread + ); + + assert!( + hardened_spread < baseline_spread, + "timing envelope should reduce cross-path mean spread: baseline={baseline_spread:.2} hardened={hardened_spread:.2}" + ); + + let mut baseline_a = BTreeSet::new(); + let mut baseline_b = BTreeSet::new(); + let mut hardened_a = BTreeSet::new(); + let mut hardened_b = BTreeSet::new(); + + for _ in 0..24 { + baseline_a.insert(capture_above_cap_forwarded_len(5000, false, 0).await); + baseline_b.insert(capture_above_cap_forwarded_len(5040, false, 0).await); + + hardened_a.insert(capture_above_cap_forwarded_len(5000, true, 96).await); + hardened_b.insert(capture_above_cap_forwarded_len(5040, true, 96).await); + } + + let baseline_overlap = baseline_a.intersection(&baseline_b).count(); + let hardened_overlap = hardened_a.intersection(&hardened_b).count(); + let baseline_gap = interval_gap_usize(&baseline_a, &baseline_b); + let hardened_gap = interval_gap_usize(&hardened_a, &hardened_b); + + println!( + "ab_harness_length baseline_overlap={} hardened_overlap={} baseline_gap={} hardened_gap={} baseline_a={} baseline_b={} hardened_a={} hardened_b={}", + baseline_overlap, + hardened_overlap, + baseline_gap, + hardened_gap, + baseline_a.len(), + baseline_b.len(), + hardened_a.len(), + hardened_b.len() + ); + + assert_eq!( + baseline_overlap, 0, + "baseline above-cap classes should be disjoint" + ); + assert!( + hardened_a.len() > baseline_a.len() && hardened_b.len() > baseline_b.len(), + "above-cap 
blur should widen per-class emitted lengths: baseline_a={} baseline_b={} hardened_a={} hardened_b={}", + baseline_a.len(), + baseline_b.len(), + hardened_a.len(), + hardened_b.len() + ); + assert!( + hardened_overlap > baseline_overlap || hardened_gap < baseline_gap, + "above-cap blur should reduce class separability via direct overlap or tighter interval gap: baseline_overlap={} hardened_overlap={} baseline_gap={} hardened_gap={}", + baseline_overlap, + hardened_overlap, + baseline_gap, + hardened_gap + ); +} + +#[test] +fn timing_classifier_helper_bucketize_is_stable() { + let values = vec![219u128, 220, 239, 240, 259, 260]; + let got = bucketize_ms(&values, 20); + assert_eq!(got, vec![10, 11, 11, 12, 12, 13]); +} + +#[test] +fn timing_classifier_helper_percentile_is_monotonic() { + let samples = vec![210u128, 220, 230, 240, 250, 260, 270, 280]; + let p50 = percentile_ms(samples.clone(), 50, 100); + let p95 = percentile_ms(samples.clone(), 95, 100); + assert!(p95 >= p50); +} + +#[test] +fn timing_classifier_helper_threshold_accuracy_is_perfect_for_disjoint_sets() { + let a = vec![10u128, 11, 12, 13, 14]; + let b = vec![20u128, 21, 22, 23, 24]; + let acc = best_threshold_accuracy_u128(&a, &b); + assert!(acc >= 0.99); +} + +#[test] +fn timing_classifier_helper_threshold_accuracy_drops_for_identical_sets() { + let a = vec![10u128, 11, 12, 13, 14]; + let b = vec![10u128, 11, 12, 13, 14]; + let acc = best_threshold_accuracy_u128(&a, &b); + assert!( + acc <= 0.6, + "identical sets should not be strongly separable" + ); +} + +#[test] +fn timing_classifier_helper_bucketed_threshold_reduces_resolution() { + let raw_a = vec![221u128, 223, 225, 227, 229]; + let raw_b = vec![231u128, 233, 235, 237, 239]; + let raw_acc = best_threshold_accuracy_u128(&raw_a, &raw_b); + + let bucketed_a = bucketize_ms(&raw_a, 20); + let bucketed_b = bucketize_ms(&raw_b, 20); + let bucketed_acc = best_threshold_accuracy_u128(&bucketed_a, &bucketed_b); + + assert!(raw_acc >= bucketed_acc); +} + 
+#[tokio::test] +async fn timing_classifier_baseline_connect_fail_vs_slow_backend_is_highly_separable() { + let fail = collect_timing_samples(PathClass::ConnectFail, false, 8).await; + let slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await; + + let acc = best_threshold_accuracy_u128(&fail, &slow); + assert!( + acc >= 0.80, + "baseline timing classes should be separable enough" + ); +} + +#[tokio::test] +async fn timing_classifier_normalized_connect_fail_vs_slow_backend_reduces_separability() { + let baseline_fail = collect_timing_samples(PathClass::ConnectFail, false, 8).await; + let baseline_slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await; + let hardened_fail = collect_timing_samples(PathClass::ConnectFail, true, 8).await; + let hardened_slow = collect_timing_samples(PathClass::SlowBackend, true, 8).await; + + let baseline_acc = best_threshold_accuracy_u128(&baseline_fail, &baseline_slow); + let hardened_acc = best_threshold_accuracy_u128(&hardened_fail, &hardened_slow); + + assert!( + hardened_acc <= baseline_acc, + "normalization should not increase timing separability" + ); +} + +#[tokio::test] +async fn timing_classifier_bucketed_normalized_connect_fail_vs_slow_backend_is_bounded() { + let baseline_fail = collect_timing_samples(PathClass::ConnectFail, false, 10).await; + let baseline_slow = collect_timing_samples(PathClass::SlowBackend, false, 10).await; + let hardened_fail = collect_timing_samples(PathClass::ConnectFail, true, 10).await; + let hardened_slow = collect_timing_samples(PathClass::SlowBackend, true, 10).await; + + let baseline_acc = best_threshold_accuracy_u128( + &bucketize_ms(&baseline_fail, 20), + &bucketize_ms(&baseline_slow, 20), + ); + let hardened_acc = best_threshold_accuracy_u128( + &bucketize_ms(&hardened_fail, 20), + &bucketize_ms(&hardened_slow, 20), + ); + + assert!( + hardened_acc <= baseline_acc, + "normalized bucketed classifier should not outperform baseline: baseline={baseline_acc:.3} 
hardened={hardened_acc:.3}" + ); +} + +#[tokio::test] +async fn timing_classifier_normalized_connect_fail_samples_stay_in_sane_bounds() { + let samples = collect_timing_samples(PathClass::ConnectFail, true, 6).await; + for s in samples { + assert!((150..=1200).contains(&s), "sample out of sane bounds: {s}"); + } +} + +#[tokio::test] +async fn timing_classifier_normalized_connect_success_samples_stay_in_sane_bounds() { + let samples = collect_timing_samples(PathClass::ConnectSuccess, true, 6).await; + for s in samples { + assert!((150..=1200).contains(&s), "sample out of sane bounds: {s}"); + } +} + +#[tokio::test] +async fn timing_classifier_normalized_slow_backend_samples_stay_in_sane_bounds() { + let samples = collect_timing_samples(PathClass::SlowBackend, true, 6).await; + for s in samples { + assert!((150..=1400).contains(&s), "sample out of sane bounds: {s}"); + } +} + +#[tokio::test] +async fn timing_classifier_normalized_mean_bucket_delta_connect_fail_vs_connect_success_is_small() { + let fail = collect_timing_samples(PathClass::ConnectFail, true, 8).await; + let success = collect_timing_samples(PathClass::ConnectSuccess, true, 8).await; + let fail_mean = mean_ms(&fail); + let success_mean = mean_ms(&success); + let delta_bucket = ((fail_mean as i128 - success_mean as i128).abs()) / 20; + assert!( + delta_bucket <= 3, + "mean bucket delta too large: {delta_bucket}" + ); +} + +#[tokio::test] +async fn timing_classifier_normalized_p95_bucket_delta_connect_success_vs_slow_is_small() { + let success = collect_timing_samples(PathClass::ConnectSuccess, true, 10).await; + let slow = collect_timing_samples(PathClass::SlowBackend, true, 10).await; + let p95_success = percentile_ms(success, 95, 100); + let p95_slow = percentile_ms(slow, 95, 100); + let delta_bucket = ((p95_success as i128 - p95_slow as i128).abs()) / 20; + assert!( + delta_bucket <= 4, + "p95 bucket delta too large: {delta_bucket}" + ); +} + +#[tokio::test] +async fn 
timing_classifier_normalized_spread_is_not_worse_than_baseline_for_connect_fail() { + let baseline = collect_timing_samples(PathClass::ConnectFail, false, 8).await; + let hardened = collect_timing_samples(PathClass::ConnectFail, true, 8).await; + let baseline_spread = spread_u128(&baseline); + let hardened_spread = spread_u128(&hardened); + assert!( + hardened_spread <= baseline_spread.saturating_add(600), + "normalized spread exploded unexpectedly: baseline={baseline_spread} hardened={hardened_spread}" + ); +} + +#[tokio::test] +async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_under_normalization() +{ + const SAMPLE_COUNT: usize = 6; + + let pairs = [ + (PathClass::ConnectFail, PathClass::ConnectSuccess), + (PathClass::ConnectFail, PathClass::SlowBackend), + (PathClass::ConnectSuccess, PathClass::SlowBackend), + ]; + + let mut meaningful_improvement_seen = false; + let mut baseline_sum = 0.0f64; + let mut hardened_sum = 0.0f64; + let mut pair_count = 0usize; + let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64; + let tolerated_pair_regression = acc_quant_step + 0.03; + + for (a, b) in pairs { + let baseline_a = collect_timing_samples(a, false, SAMPLE_COUNT).await; + let baseline_b = collect_timing_samples(b, false, SAMPLE_COUNT).await; + let hardened_a = collect_timing_samples(a, true, SAMPLE_COUNT).await; + let hardened_b = collect_timing_samples(b, true, SAMPLE_COUNT).await; + + let baseline_acc = best_threshold_accuracy_u128( + &bucketize_ms(&baseline_a, 20), + &bucketize_ms(&baseline_b, 20), + ); + let hardened_acc = best_threshold_accuracy_u128( + &bucketize_ms(&hardened_a, 20), + &bucketize_ms(&hardened_b, 20), + ); + + // When baseline separability is near-random, tiny sample jitter can make + // hardened appear "worse" without indicating a real side-channel regression. + // Guard hard only on informative baseline pairs. 
+ if baseline_acc >= 0.75 { + assert!( + hardened_acc <= baseline_acc + tolerated_pair_regression, + "normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}" + ); + } + + println!( + "timing_classifier_pair baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated_pair_regression={tolerated_pair_regression:.3}" + ); + + if hardened_acc + 0.05 <= baseline_acc { + meaningful_improvement_seen = true; + } + + baseline_sum += baseline_acc; + hardened_sum += hardened_acc; + pair_count += 1; + } + + let baseline_avg = baseline_sum / pair_count as f64; + let hardened_avg = hardened_sum / pair_count as f64; + + assert!( + hardened_avg <= baseline_avg + 0.10, + "normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}" + ); + + // Optional signal only: do not require improvement on every run because + // noisy CI schedulers can flatten pairwise differences at low sample counts. 
+ let _ = meaningful_improvement_seen; +} + +#[tokio::test] +async fn timing_classifier_stress_parallel_sampling_finishes_and_stays_bounded() { + let mut tasks = Vec::new(); + for i in 0..24usize { + tasks.push(tokio::spawn(async move { + let class = match i % 3 { + 0 => PathClass::ConnectFail, + 1 => PathClass::ConnectSuccess, + _ => PathClass::SlowBackend, + }; + let sample = measure_masking_duration_ms(class, true).await; + assert!( + (100..=1600).contains(&sample), + "stress sample out of bounds: {sample}" + ); + })); + } + + for task in tasks { + tokio::time::timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + } +} diff --git a/src/proxy/tests/masking_adversarial_tests.rs b/src/proxy/tests/masking_adversarial_tests.rs new file mode 100644 index 0000000..ce2807a --- /dev/null +++ b/src/proxy/tests/masking_adversarial_tests.rs @@ -0,0 +1,806 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::proxy::relay::relay_bidirectional; +use crate::stats::Stats; +use crate::stats::beobachten::BeobachtenStore; +use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +// ------------------------------------------------------------------ +// Probing Indistinguishability (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_probes_indistinguishable_timing() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 80; // Should timeout/refuse + + let peer: SocketAddr = "192.0.2.10:443".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + // Test different probe types + let probes = vec![ + (b"GET / HTTP/1.1\r\nHost: x\r\n\r\n".to_vec(), "HTTP"), + (b"SSH-2.0-probe".to_vec(), "SSH"), + ( + 
vec![0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00], + "TLS-scanner", + ), + (vec![0x42; 5], "port-scanner"), + ]; + + for (probe, type_name) in probes { + let (client_reader, _client_writer) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + + let start = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = start.elapsed(); + + // We expect any outcome to take roughly MASK_TIMEOUT (50ms in tests) + // to mask whether the backend was reachable or refused. + assert!( + elapsed >= Duration::from_millis(30), + "Probe {type_name} finished too fast: {elapsed:?}" + ); + } +} + +// ------------------------------------------------------------------ +// Masking Budget Stress Tests (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_budget_stress_under_load() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; // Unlikely port + + let peer: SocketAddr = "192.0.2.20:443".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = Arc::new(BeobachtenStore::new()); + + let mut tasks = Vec::new(); + for _ in 0..50 { + let (client_reader, _client_writer) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let config = config.clone(); + let beobachten = Arc::clone(&beobachten); + + tasks.push(tokio::spawn(async move { + let start = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"probe", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + start.elapsed() + })); + } + + for task in tasks { + let elapsed = task.await.unwrap(); + assert!( + elapsed >= Duration::from_millis(30), + "Stress probe finished too 
fast: {elapsed:?}" + ); + } +} + +// ------------------------------------------------------------------ +// detect_client_type Fingerprint Check +// ------------------------------------------------------------------ + +#[test] +fn test_detect_client_type_boundary_cases() { + // 9 bytes = port-scanner + assert_eq!(detect_client_type(&[0x42; 9]), "port-scanner"); + // 10 bytes = unknown + assert_eq!(detect_client_type(&[0x42; 10]), "unknown"); + + // HTTP verbs without trailing space + assert_eq!(detect_client_type(b"GET/"), "port-scanner"); // because len < 10 + assert_eq!(detect_client_type(b"GET /path"), "HTTP"); +} + +// ------------------------------------------------------------------ +// Priority 2: Slowloris and Slow Read Attacks (OWASP ASVS 5.1.5) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_slowloris_client_idle_timeout_rejected() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut drip = [0u8; 1]; + let drip_read = + tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip)) + .await; + assert!( + drip_read.is_err() || drip_read.unwrap().is_err(), + "backend must not receive post-timeout slowloris drip bytes" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let beobachten = BeobachtenStore::new(); + let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap(); + let local: SocketAddr = 
"192.0.2.1:443".parse().unwrap(); + + let (mut client_writer, client_reader) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + + let handle = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(160)).await; + let _ = client_writer.write_all(b"X").await; + + handle.await.unwrap(); + accept_task.await.unwrap(); +} + +// ------------------------------------------------------------------ +// Priority 2: Fallback Server Down / Fingerprinting (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_fallback_down_mimics_timeout() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; // Unlikely port + + let (server_reader, server_writer) = duplex(1024); + let beobachten = BeobachtenStore::new(); + let peer: SocketAddr = "192.0.2.12:12345".parse().unwrap(); + let local: SocketAddr = "192.0.2.1:443".parse().unwrap(); + + let start = Instant::now(); + handle_bad_client( + server_reader, + server_writer, + b"GET / HTTP/1.1\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + + let elapsed = start.elapsed(); + // It should wait for MASK_TIMEOUT (50ms in tests) even if connection was refused immediately + assert!( + elapsed >= Duration::from_millis(40), + "Must respect connect budget even on failure: {:?}", + elapsed + ); +} + +// ------------------------------------------------------------------ +// Priority 2: SSRF Prevention (OWASP ASVS 5.1.2) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_ssrf_resolve_internal_ranges_blocked() { + use crate::network::dns_overrides::resolve_socket_addr; + + let blocked_ips = [ + "127.0.0.1", 
+ "169.254.169.254", + "10.0.0.1", + "192.168.1.1", + "0.0.0.0", + ]; + + for ip in blocked_ips { + assert!( + resolve_socket_addr(ip, 80).is_none(), + "runtime DNS overrides must not resolve unconfigured literal host targets" + ); + } +} + +#[tokio::test] +async fn masking_unknown_proxy_protocol_version_falls_back_to_v1_unknown_header() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut header = [0u8; 15]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(&header, b"PROXY UNKNOWN\r\n"); + + let mut payload = [0u8; 5]; + stream.read_exact(&mut payload).await.unwrap(); + assert_eq!(&payload, b"probe"); + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 255; + + let peer: SocketAddr = "198.51.100.77:50001".parse().unwrap(); + let local_addr: SocketAddr = "[2001:db8::10]:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + let (client_reader, _client_writer) = duplex(128); + let (_client_visible_reader, client_visible_writer) = duplex(128); + + handle_bad_client( + client_reader, + client_visible_writer, + b"probe", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn masking_zero_length_initial_data_does_not_hang_or_panic() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut one = [0u8; 1]; + let n = tokio::time::timeout(Duration::from_millis(150), stream.read(&mut one)) + .await + .unwrap() + .unwrap(); + 
assert_eq!( + n, 0, + "backend must observe clean EOF for empty initial payload" + ); + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let peer: SocketAddr = "203.0.113.70:50002".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (client_reader, client_writer) = duplex(64); + drop(client_writer); + let (_client_visible_reader, client_visible_writer) = duplex(64); + + handle_bad_client( + client_reader, + client_visible_writer, + b"", + peer, + local, + &config, + &beobachten, + ) + .await; + + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn masking_oversized_initial_payload_is_forwarded_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let payload = vec![0xA5u8; 32 * 1024]; + + let accept_task = tokio::spawn({ + let payload = payload.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; payload.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!( + observed, payload, + "large initial payload must stay byte-for-byte" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let peer: SocketAddr = "203.0.113.71:50003".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + let (client_reader, _client_writer) = duplex(64); + let (_client_visible_reader, client_visible_writer) = duplex(64); + + handle_bad_client( + client_reader, + client_visible_writer, + &payload, + peer, + local, + &config, + &beobachten, + ) + .await; + + 
accept_task.await.unwrap(); +} + +#[tokio::test] +async fn masking_refused_backend_keeps_constantish_timing_floor_under_burst() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + + let peer: SocketAddr = "203.0.113.72:50004".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + for _ in 0..16 { + let (client_reader, _client_writer) = duplex(128); + let (_client_visible_reader, client_visible_writer) = duplex(128); + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"GET / HTTP/1.1\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + assert!( + started.elapsed() >= Duration::from_millis(30), + "refused-backend path must keep timing floor to reduce fingerprinting" + ); + } +} + +#[tokio::test] +async fn masking_backend_half_close_then_client_half_close_completes_without_hang() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut pre = [0u8; 4]; + stream.read_exact(&mut pre).await.unwrap(); + assert_eq!(&pre, b"PING"); + stream.write_all(b"PONG").await.unwrap(); + stream.shutdown().await.unwrap(); + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let peer: SocketAddr = "203.0.113.73:50005".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client_writer, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(256); + + let handle = tokio::spawn(async move { + 
handle_bad_client( + client_reader, + client_visible_writer, + b"PING", + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + + let mut got = [0u8; 4]; + client_visible_reader.read_exact(&mut got).await.unwrap(); + assert_eq!(&got, b"PONG"); + + timeout(Duration::from_secs(2), handle) + .await + .expect("masking task must terminate after bilateral half-close") + .unwrap(); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn chaos_burst_reconnect_storm_for_masking_and_relay_concurrently() { + const MASKING_SESSIONS: usize = 48; + const RELAY_SESSIONS: usize = 48; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let backend_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + for _ in 0..MASKING_SESSIONS { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut req = [0u8; 32]; + stream.read_exact(&mut req).await.unwrap(); + assert!( + req.starts_with(b"GET /storm/"), + "masking backend must receive storm reconnect probes" + ); + stream.write_all(&backend_reply).await.unwrap(); + stream.shutdown().await.unwrap(); + } + } + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(config); + let beobachten = Arc::new(BeobachtenStore::new()); + let peer: SocketAddr = "198.51.100.200:55555".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let mut masking_tasks = Vec::with_capacity(MASKING_SESSIONS); + for i in 0..MASKING_SESSIONS { + let config = Arc::clone(&config); + let beobachten = Arc::clone(&beobachten); + let expected_reply = backend_reply.clone(); + 
masking_tasks.push(tokio::spawn(async move { + let mut probe = [0u8; 32]; + let template = format!("GET /storm/{i:04} HTTP/1.1\r\n\r\n"); + let bytes = template.as_bytes(); + probe[..bytes.len()].copy_from_slice(bytes); + + let (client_reader, client_writer) = duplex(256); + drop(client_writer); + let (mut client_visible_reader, client_visible_writer) = duplex(1024); + + let handle = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + let mut observed = vec![0u8; expected_reply.len()]; + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, expected_reply); + + timeout(Duration::from_secs(2), handle) + .await + .expect("masking reconnect task must complete") + .unwrap(); + })); + } + + let mut relay_tasks = Vec::with_capacity(RELAY_SESSIONS); + for i in 0..RELAY_SESSIONS { + relay_tasks.push(tokio::spawn(async move { + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "chaos-storm-relay", + stats, + None, + Arc::new(BufferPool::new()), + )); + + let c2s = vec![(i as u8).wrapping_add(1); 64]; + client_peer.write_all(&c2s).await.unwrap(); + let mut c2s_seen = vec![0u8; c2s.len()]; + server_peer.read_exact(&mut c2s_seen).await.unwrap(); + assert_eq!(c2s_seen, c2s); + + let s2c = vec![(i as u8).wrapping_add(17); 96]; + server_peer.write_all(&s2c).await.unwrap(); + let mut s2c_seen = vec![0u8; s2c.len()]; + client_peer.read_exact(&mut s2c_seen).await.unwrap(); + assert_eq!(s2c_seen, s2c); + + drop(client_peer); + drop(server_peer); + 
timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay reconnect task must complete") + .unwrap() + .unwrap(); + })); + } + + for task in masking_tasks { + timeout(Duration::from_secs(3), task) + .await + .expect("masking storm join must complete") + .unwrap(); + } + + for task in relay_tasks { + timeout(Duration::from_secs(3), task) + .await + .expect("relay storm join must complete") + .unwrap(); + } + + timeout(Duration::from_secs(3), backend_task) + .await + .expect("masking backend accept loop must complete") + .unwrap(); +} + +fn read_env_usize_or_default(name: &str, default: usize) -> usize { + match std::env::var(name) { + Ok(raw) => match raw.parse::() { + Ok(parsed) if parsed > 0 => parsed, + _ => default, + }, + Err(_) => default, + } +} + +#[tokio::test] +#[ignore = "heavy soak; run manually"] +async fn chaos_burst_reconnect_storm_for_masking_and_relay_multiwave_soak() { + let waves = read_env_usize_or_default("CHAOS_WAVES", 4); + let masking_per_wave = read_env_usize_or_default("CHAOS_MASKING_PER_WAVE", 160); + let relay_per_wave = read_env_usize_or_default("CHAOS_RELAY_PER_WAVE", 160); + let total_masking = waves * masking_per_wave; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let backend_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + for _ in 0..total_masking { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut req = [0u8; 32]; + stream.read_exact(&mut req).await.unwrap(); + assert!( + req.starts_with(b"GET /storm/"), + "mask backend must only receive storm probes" + ); + stream.write_all(&backend_reply).await.unwrap(); + stream.shutdown().await.unwrap(); + } + } + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + 
config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(config); + let beobachten = Arc::new(BeobachtenStore::new()); + let peer: SocketAddr = "198.51.100.201:56565".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + for wave in 0..waves { + let mut masking_tasks = Vec::with_capacity(masking_per_wave); + for i in 0..masking_per_wave { + let config = Arc::clone(&config); + let beobachten = Arc::clone(&beobachten); + let expected_reply = backend_reply.clone(); + masking_tasks.push(tokio::spawn(async move { + let mut probe = [0u8; 32]; + let template = format!("GET /storm/{wave:02}-{i:03}\r\n\r\n"); + let bytes = template.as_bytes(); + probe[..bytes.len()].copy_from_slice(bytes); + + let (client_reader, client_writer) = duplex(256); + drop(client_writer); + let (mut client_visible_reader, client_visible_writer) = duplex(1024); + + let handle = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + let mut observed = vec![0u8; expected_reply.len()]; + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, expected_reply); + + timeout(Duration::from_secs(3), handle) + .await + .expect("masking storm task must complete") + .unwrap(); + })); + } + + let mut relay_tasks = Vec::with_capacity(relay_per_wave); + for i in 0..relay_per_wave { + relay_tasks.push(tokio::spawn(async move { + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "chaos-multiwave-relay", 
+ stats, + None, + Arc::new(BufferPool::new()), + )); + + let c2s = vec![(wave as u8).wrapping_add(i as u8).wrapping_add(1); 32]; + client_peer.write_all(&c2s).await.unwrap(); + let mut c2s_seen = vec![0u8; c2s.len()]; + server_peer.read_exact(&mut c2s_seen).await.unwrap(); + assert_eq!(c2s_seen, c2s); + + let s2c = vec![(wave as u8).wrapping_add(i as u8).wrapping_add(17); 48]; + server_peer.write_all(&s2c).await.unwrap(); + let mut s2c_seen = vec![0u8; s2c.len()]; + client_peer.read_exact(&mut s2c_seen).await.unwrap(); + assert_eq!(s2c_seen, s2c); + + drop(client_peer); + drop(server_peer); + timeout(Duration::from_secs(3), relay_task) + .await + .expect("relay storm task must complete") + .unwrap() + .unwrap(); + })); + } + + for task in masking_tasks { + timeout(Duration::from_secs(6), task) + .await + .expect("masking wave task join must complete") + .unwrap(); + } + + for task in relay_tasks { + timeout(Duration::from_secs(6), task) + .await + .expect("relay wave task join must complete") + .unwrap(); + } + } + + timeout(Duration::from_secs(8), backend_task) + .await + .expect("mask backend must complete all accepted storm sessions") + .unwrap(); +} + +#[tokio::test] +#[ignore = "heavy soak; run manually"] +async fn masking_timing_bucket_soak_refused_backend_stays_within_narrow_band() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + + let peer: SocketAddr = "203.0.113.74:50006".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let mut samples = Vec::with_capacity(128); + for _ in 0..128 { + let (client_reader, _client_writer) = duplex(128); + let (_client_visible_reader, client_visible_writer) = duplex(128); + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"GET / HTTP/1.1\r\n", + peer, + local, + &config, + 
&beobachten, + ) + .await; + samples.push(started.elapsed().as_millis()); + } + + samples.sort_unstable(); + let p10 = samples[samples.len() / 10]; + let p90 = samples[(samples.len() * 9) / 10]; + assert!( + p90.saturating_sub(p10) <= 40, + "timing spread too wide for refused-backend masking path: p10={p10}ms p90={p90}ms" + ); +} diff --git a/src/proxy/tests/masking_aggressive_mode_security_tests.rs b/src/proxy/tests/masking_aggressive_mode_security_tests.rs new file mode 100644 index 0000000..a77fc14 --- /dev/null +++ b/src/proxy/tests/masking_aggressive_mode_security_tests.rs @@ -0,0 +1,107 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +async fn capture_forwarded_len_with_mode( + body_sent: usize, + close_client_after_write: bool, + aggressive_mode: bool, + above_cap_blur: bool, + above_cap_blur_max_bytes: usize, +) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_hardening_aggressive_mode = aggressive_mode; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + config.censorship.mask_shape_above_cap_blur = above_cap_blur; + config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (server_reader, mut client_writer) = duplex(64 * 1024); + let 
(_client_visible_reader, client_visible_writer) = duplex(64 * 1024); + let peer: SocketAddr = "198.51.100.248:57248".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&7000u16.to_be_bytes()); + probe[5..].fill(0x31); + + let fallback = tokio::spawn(async move { + handle_bad_client( + server_reader, + client_visible_writer, + &probe, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + if close_client_after_write { + client_writer.shutdown().await.unwrap(); + } else { + client_writer.write_all(b"keepalive").await.unwrap(); + tokio::time::sleep(Duration::from_millis(170)).await; + drop(client_writer); + } + + let _ = tokio::time::timeout(Duration::from_secs(4), fallback) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn aggressive_mode_shapes_backend_silent_non_eof_path() { + let body_sent = 17usize; + let floor = 512usize; + + let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await; + let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await; + + assert!(legacy < floor, "legacy mode should keep timeout path unshaped"); + assert!( + aggressive >= floor, + "aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})" + ); +} + +#[tokio::test] +async fn aggressive_mode_enforces_positive_above_cap_blur() { + let body_sent = 5000usize; + let base = 5 + body_sent; + + for _ in 0..48 { + let observed = capture_forwarded_len_with_mode(body_sent, true, true, true, 1).await; + assert!( + observed > base, + "aggressive mode must not emit exact base length when blur is enabled (observed={observed}, base={base})" + ); + } +} diff --git 
a/src/proxy/tests/masking_security_tests.rs b/src/proxy/tests/masking_security_tests.rs new file mode 100644 index 0000000..4519d85 --- /dev/null +++ b/src/proxy/tests/masking_security_tests.rs @@ -0,0 +1,1845 @@ +use super::*; +use crate::config::ProxyConfig; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncBufReadExt, BufReader, duplex}; +use tokio::net::TcpListener; +#[cfg(unix)] +use tokio::net::UnixListener; +use tokio::time::{Duration, Instant, sleep, timeout}; + +#[tokio::test] +async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + 
peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn tls_scanner_probe_keeps_http_like_fallback_surface() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04]; + let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.44:55221".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, backend_reply); + + let snapshot = 
beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[TLS-scanner]")); + assert!(snapshot.contains("198.51.100.44-1")); + accept_task.await.unwrap(); +} + +#[test] +fn detect_client_type_covers_ssh_port_scanner_and_unknown() { + assert_eq!(detect_client_type(b"SSH-2.0-OpenSSH_9.7"), "SSH"); + assert_eq!(detect_client_type(b"\x01\x02\x03"), "port-scanner"); + assert_eq!(detect_client_type(b"random-binary-payload"), "unknown"); +} + +#[test] +fn detect_client_type_len_boundary_9_vs_10_bytes() { + assert_eq!(detect_client_type(b"123456789"), "port-scanner"); + assert_eq!(detect_client_type(b"1234567890"), "unknown"); +} + +#[test] +fn build_mask_proxy_header_version_zero_disables_header() { + let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let header = build_mask_proxy_header(0, peer, local_addr); + assert!(header.is_none(), "version 0 must disable PROXY header"); +} + +#[test] +fn build_mask_proxy_header_v2_matches_builder_output() { + let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let expected = ProxyProtocolV2Builder::new() + .with_addrs(peer, local_addr) + .build(); + let actual = + build_mask_proxy_header(2, peer, local_addr).expect("v2 mode must produce a header"); + + assert_eq!(actual, expected, "v2 header bytes must be deterministic"); +} + +#[test] +fn build_mask_proxy_header_v1_mixed_ip_family_uses_generic_unknown_form() { + let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap(); + let local_addr: SocketAddr = "[2001:db8::1]:443".parse().unwrap(); + + let expected = ProxyProtocolV1Builder::new().build(); + let actual = + build_mask_proxy_header(1, peer, local_addr).expect("v1 mode must produce a header"); + + assert_eq!(actual, expected, "mixed-family v1 must use UNKNOWN form"); +} + +#[tokio::test] +async fn 
beobachten_records_scanner_class_when_mask_is_disabled() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = false; + + let peer: SocketAddr = "203.0.113.99:41234".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let initial = b"SSH-2.0-probe"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + beobachten + }); + + client_reader_side.write_all(b"noise").await.unwrap(); + drop(client_reader_side); + + let beobachten = timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[SSH]")); + assert!(snapshot.contains("203.0.113.99-1")); +} + +#[tokio::test] +async fn backend_unavailable_falls_back_to_silent_consume() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.11:42425".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = 
BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_reader_side.write_all(b"noise").await.unwrap(); + drop(client_reader_side); + + timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn backend_connect_refusal_waits_mask_connect_budget_before_fallback() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.12:42426".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n"; + + // Close client reader immediately to force the refusal path to rely on masking budget timing. 
+ let (client_reader_side, client_reader) = duplex(256); + drop(client_reader_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_millis(35), task) + .await + .expect_err("masking fallback must not complete before connect budget elapses"); + assert!( + started.elapsed() >= Duration::from_millis(35), + "fallback path must absorb immediate connect refusal into connect budget" + ); +} + +#[tokio::test] +async fn backend_reachable_fast_response_waits_mask_outcome_budget() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /ok HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.13:42427".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let 
(_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + assert!( + started.elapsed() >= Duration::from_millis(45), + "reachable mask path must also satisfy coarse outcome budget" + ); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn proxy_header_write_error_on_tcp_path_still_honors_coarse_outcome_budget() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /proxy-hdr-err HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.88:42430".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader_side, client_reader) = duplex(256); + drop(client_reader_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_millis(35), task).await.expect_err( + "proxy-header write error path should remain inside coarse masking budget window", + ); + assert!( + started.elapsed() >= Duration::from_millis(35), + "proxy-header write error path should avoid 
immediate-return timing signature" + ); + + accept_task.await.unwrap(); +} + +#[cfg(unix)] +#[tokio::test] +async fn proxy_header_write_error_on_unix_path_still_honors_coarse_outcome_budget() { + let sock_path = format!( + "/tmp/telemt-mask-unix-hdr-err-{}-{}.sock", + std::process::id(), + rand::random::() + ); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix-hdr-err HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.89:42431".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader_side, client_reader) = duplex(256); + drop(client_reader_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_millis(35), task).await.expect_err( + "unix proxy-header write error path should remain inside coarse masking budget window", + ); + assert!( + started.elapsed() >= Duration::from_millis(35), + "unix proxy-header write error path should avoid immediate-return timing signature" + ); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[cfg(unix)] +#[tokio::test] +async fn unix_socket_proxy_protocol_v1_header_is_sent_before_probe() { + let sock_path = format!( + "/tmp/telemt-mask-unix-v1-{}-{}.sock", + std::process::id(), + rand::random::() + ); + 
let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix-v1 HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut header_line).await.unwrap(); + let header_text = String::from_utf8(header_line).unwrap(); + assert!( + header_text.starts_with("PROXY "), + "must start with PROXY prefix" + ); + assert!( + header_text.ends_with("\r\n"), + "v1 header must end with CRLF" + ); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.51:51010".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, backend_reply); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[cfg(unix)] 
+#[tokio::test] +async fn unix_socket_proxy_protocol_v2_header_is_sent_before_probe() { + let sock_path = format!( + "/tmp/telemt-mask-unix-v2-{}-{}.sock", + std::process::id(), + rand::random::() + ); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix-v2 HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut sig = [0u8; 12]; + stream.read_exact(&mut sig).await.unwrap(); + assert_eq!( + &sig, b"\r\n\r\n\0\r\nQUIT\n", + "v2 signature must match spec" + ); + + let mut fixed = [0u8; 4]; + stream.read_exact(&mut fixed).await.unwrap(); + let addr_len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize; + let mut addr_block = vec![0u8; addr_len]; + stream.read_exact(&mut addr_block).await.unwrap(); + + let mut received_probe = vec![0u8; probe.len()]; + stream.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 2; + + let peer: SocketAddr = "203.0.113.52:51011".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + 
client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, backend_reply); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[tokio::test] +async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "203.0.113.14:42428".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"x", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + assert!( + started.elapsed() < Duration::from_millis(20), + "mask-disabled fallback should keep immediate EOF behavior" + ); +} + +#[tokio::test] +async fn backend_reachable_slow_response_not_padded_twice() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /slow HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + sleep(Duration::from_millis(90)).await; + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 
backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.15:42429".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + let elapsed = started.elapsed(); + + assert!(elapsed >= Duration::from_millis(85)); + assert!( + elapsed < Duration::from_millis(170), + "slow reachable backend should not incur an extra full budget after already exceeding it" + ); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn adversarial_enabled_refused_and_reachable_collapse_to_same_bucket() { + const ITER: usize = 20; + const BUCKET_MS: u128 = 10; + + let probe = b"GET /collapse HTTP/1.1\r\nHost: x\r\n\r\n"; + let peer: SocketAddr = "203.0.113.16:42430".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let mut refused = Vec::with_capacity(ITER); + for _ in 0..ITER { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + 
handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + refused.push(started.elapsed().as_millis()); + } + + let mut reachable = Vec::with_capacity(ITER); + for _ in 0..ITER { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe_vec = probe.to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut received).await.unwrap(); + stream.write_all(&backend_reply).await.unwrap(); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + reachable.push(started.elapsed().as_millis()); + accept_task.await.unwrap(); + } + + let refused_mean = refused.iter().copied().sum::() as f64 / refused.len() as f64; + let reachable_mean = reachable.iter().copied().sum::() as f64 / reachable.len() as f64; + let refused_bucket = (refused_mean as u128) / BUCKET_MS; + let reachable_bucket = (reachable_mean as u128) / BUCKET_MS; + + assert!( + refused_bucket.abs_diff(reachable_bucket) <= 1, + "enabled refused and reachable paths must collapse into the same coarse latency bucket" + ); +} + 
+#[tokio::test] +async fn light_fuzz_mask_enabled_outcomes_preserve_coarse_budget() { + let mut seed: u64 = 0xA5A5_5A5A_1337_4242; + let mut next = || { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + seed + }; + + let peer: SocketAddr = "203.0.113.17:42431".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + for _ in 0..40 { + let probe_len = (next() as usize % 96).saturating_add(8); + let mut probe = vec![0u8; probe_len]; + for byte in &mut probe { + *byte = next() as u8; + } + + let use_reachable = (next() & 1) == 0; + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(512); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + if use_reachable { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let probe_vec = probe.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut observed).await.unwrap(); + }); + + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + accept_task.await.unwrap(); + } else { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + + handle_bad_client( + 
client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + } + + assert!( + started.elapsed() >= Duration::from_millis(45), + "mask-enabled fallback must preserve coarse timing budget under varied probe shapes" + ); + } +} + +#[tokio::test] +async fn mask_disabled_consumes_client_data_without_response() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "198.51.100.12:45454".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let initial = b"scanner"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_reader_side + .write_all(b"untrusted payload") + .await + .unwrap(); + drop(client_reader_side); + + timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn proxy_protocol_v1_header_is_sent_before_probe() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut 
header_line).await.unwrap(); + let header_text = String::from_utf8(header_line.clone()).unwrap(); + assert!(header_text.starts_with("PROXY TCP4 ")); + assert!(header_text.ends_with("\r\n")); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.15:50001".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn proxy_protocol_v2_header_is_sent_before_probe() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut sig = [0u8; 12]; + stream.read_exact(&mut 
sig).await.unwrap(); + assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n"); + + let mut fixed = [0u8; 4]; + stream.read_exact(&mut fixed).await.unwrap(); + let addr_len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize; + + let mut addr_block = vec![0u8; addr_len]; + stream.read_exact(&mut addr_block).await.unwrap(); + + let mut received_probe = vec![0u8; probe.len()]; + stream.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 2; + + let peer: SocketAddr = "203.0.113.18:50004".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /mix HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) 
= listener.accept().await.unwrap(); + let mut reader = BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut header_line).await.unwrap(); + let header_text = String::from_utf8(header_line).unwrap(); + assert_eq!(header_text, "PROXY UNKNOWN\r\n"); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.20:50006".parse().unwrap(); + let local_addr: SocketAddr = "[::1]:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[cfg(unix)] +#[tokio::test] +async fn unix_socket_mask_path_forwards_probe_and_response() { + let sock_path = format!( + "/tmp/telemt-mask-test-{}-{}.sock", + std::process::id(), + rand::random::() + ); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ 
+ let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.30:50010".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, backend_reply); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[tokio::test] +async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "198.51.100.33:45455".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + b"slowloris", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_secs(1), task) + .await + 
.unwrap() + .unwrap(); +} + +#[tokio::test] +async fn mask_enabled_idle_relay_is_closed_by_idle_timeout_before_global_relay_timeout() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /idle HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + sleep(Duration::from_millis(300)).await; + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.34:45456".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(512); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed < Duration::from_millis(150), + "idle unauth relay must terminate on idle timeout instead of waiting for full relay timeout" + ); + + accept_task.await.unwrap(); +} + +struct PendingWriter; + +impl tokio::io::AsyncWrite for PendingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, 
_cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +struct DropTrackedPendingReader { + dropped: Arc, +} + +impl tokio::io::AsyncRead for DropTrackedPendingReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Poll::Pending + } +} + +impl Drop for DropTrackedPendingReader { + fn drop(&mut self) { + self.dropped.store(true, Ordering::SeqCst); + } +} + +struct DropTrackedPendingWriter { + dropped: Arc, +} + +impl tokio::io::AsyncWrite for DropTrackedPendingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl Drop for DropTrackedPendingWriter { + fn drop(&mut self) { + self.dropped.store(true, Ordering::SeqCst); + } +} + +#[tokio::test] +async fn proxy_header_write_timeout_returns_false() { + let mut writer = PendingWriter; + let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await; + assert!(!ok, "Proxy header writes that never complete must time out"); +} + +#[tokio::test] +async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stalls() { + let (mut client_feed_writer, client_feed_reader) = duplex(64); + let (mut client_visible_reader, client_visible_writer) = duplex(64); + let (mut backend_feed_writer, backend_feed_reader) = duplex(64); + + // Make client->mask direction immediately active so the c2m path blocks on PendingWriter. 
+ client_feed_writer.write_all(b"X").await.unwrap(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_feed_reader, + client_visible_writer, + backend_feed_reader, + PendingWriter, + b"", + false, + 0, + 0, + false, + 0, + false, + ) + .await; + }); + + // Allow relay tasks to start, then emulate mask backend response. + sleep(Duration::from_millis(20)).await; + backend_feed_writer + .write_all(b"HTTP/1.1 200 OK\r\n\r\n") + .await + .unwrap(); + backend_feed_writer.shutdown().await.unwrap(); + + let mut observed = vec![0u8; 19]; + timeout( + Duration::from_secs(1), + client_visible_reader.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, b"HTTP/1.1 200 OK\r\n\r\n"); + + relay.abort(); + let _ = relay.await; +} + +#[tokio::test] +async fn relay_to_mask_preserves_backend_response_after_client_half_close() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let request = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let backend_task = tokio::spawn({ + let request = request.clone(); + let response = response.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed_req = vec![0u8; request.len()]; + stream.read_exact(&mut observed_req).await.unwrap(); + assert_eq!(observed_req, request); + stream.write_all(&response).await.unwrap(); + stream.shutdown().await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap(); + let local_addr: SocketAddr = 
"127.0.0.1:443".parse().unwrap(); + + let (mut client_write, client_read) = duplex(1024); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + let beobachten = BeobachtenStore::new(); + + let fallback_task = tokio::spawn(async move { + handle_bad_client( + client_read, + client_visible_writer, + &request, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + + let mut observed_resp = vec![0u8; response.len()]; + timeout( + Duration::from_secs(1), + client_visible_reader.read_exact(&mut observed_resp), + ) + .await + .unwrap() + .unwrap(); + assert_eq!(observed_resp, response); + + timeout(Duration::from_secs(1), fallback_task) + .await + .unwrap() + .unwrap(); + timeout(Duration::from_secs(1), backend_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { + let reader_dropped = Arc::new(AtomicBool::new(false)); + let writer_dropped = Arc::new(AtomicBool::new(false)); + let mask_reader_dropped = Arc::new(AtomicBool::new(false)); + let mask_writer_dropped = Arc::new(AtomicBool::new(false)); + + let reader = DropTrackedPendingReader { + dropped: reader_dropped.clone(), + }; + let writer = DropTrackedPendingWriter { + dropped: writer_dropped.clone(), + }; + let mask_read = DropTrackedPendingReader { + dropped: mask_reader_dropped.clone(), + }; + let mask_write = DropTrackedPendingWriter { + dropped: mask_writer_dropped.clone(), + }; + + let timed = timeout( + Duration::from_millis(40), + relay_to_mask( + reader, + writer, + mask_read, + mask_write, + b"", + false, + 0, + 0, + false, + 0, + false, + ), + ) + .await; + + assert!(timed.is_err(), "stalled relay must be bounded by timeout"); + + assert!(reader_dropped.load(Ordering::SeqCst)); + assert!(writer_dropped.load(Ordering::SeqCst)); + assert!(mask_reader_dropped.load(Ordering::SeqCst)); + assert!(mask_writer_dropped.load(Ordering::SeqCst)); +} + 
+#[tokio::test] +#[ignore = "timing matrix; run manually with --ignored --nocapture"] +async fn timing_matrix_masking_classes_under_controlled_inputs() { + const ITER: usize = 24; + const BUCKET_MS: u128 = 10; + + let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n"; + let peer: SocketAddr = "203.0.113.40:51000".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + // Class 1: masking disabled with immediate EOF (fast fail-closed consume path). + let mut disabled_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + disabled_samples.push(started.elapsed().as_millis()); + } + + // Class 2: masking enabled, backend connect refused. 
+ let mut refused_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + refused_samples.push(started.elapsed().as_millis()); + } + + // Class 3: masking enabled, backend reachable and immediately responds. 
+ let mut reachable_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + let probe_vec = probe.to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe_vec); + stream.write_all(&backend_reply).await.unwrap(); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + reachable_samples.push(started.elapsed().as_millis()); + accept_task.await.unwrap(); + } + + fn summarize(samples_ms: &mut [u128]) -> (f64, u128, u128, u128) { + samples_ms.sort_unstable(); + let sum: u128 = samples_ms.iter().copied().sum(); + let mean = sum as f64 / samples_ms.len() as f64; + let min = samples_ms[0]; + let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize; + let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)]; + let max = samples_ms[samples_ms.len() - 1]; + (mean, min, p95, max) + } + + let (disabled_mean, disabled_min, disabled_p95, disabled_max) = + summarize(&mut disabled_samples); + let (refused_mean, refused_min, refused_p95, refused_max) = 
summarize(&mut refused_samples); + let (reachable_mean, reachable_min, reachable_p95, reachable_max) = + summarize(&mut reachable_samples); + + println!( + "TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + disabled_mean, + disabled_min, + disabled_p95, + disabled_max, + (disabled_mean as u128) / BUCKET_MS + ); + println!( + "TIMING_MATRIX masking class=enabled_refused_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + refused_mean, + refused_min, + refused_p95, + refused_max, + (refused_mean as u128) / BUCKET_MS + ); + println!( + "TIMING_MATRIX masking class=enabled_reachable_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + reachable_mean, + reachable_min, + reachable_p95, + reachable_max, + (reachable_mean as u128) / BUCKET_MS + ); +} + +#[tokio::test] +async fn backend_connect_refusal_completes_within_bounded_mask_budget() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.41:51001".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /bounded HTTP/1.1\r\nHost: x\r\n\r\n"; + + let (_client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= 
Duration::from_millis(45), + "connect refusal path must respect minimum masking budget" + ); + assert!( + elapsed < Duration::from_millis(500), + "connect refusal path must stay bounded and avoid unbounded stall" + ); +} + +#[tokio::test] +async fn reachable_backend_one_response_then_silence_is_cut_by_idle_timeout() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /oneshot HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let response = response.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&response).await.unwrap(); + sleep(Duration::from_millis(300)).await; + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.42:51002".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + let elapsed = started.elapsed(); + + let mut observed = vec![0u8; response.len()]; + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); + assert_eq!(observed, response); + assert!( + 
elapsed < Duration::from_millis(190), + "idle backend silence after first response must be cut by relay idle timeout" + ); + + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn adversarial_client_drip_feed_longer_than_idle_timeout_is_cut_off() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET /drip HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut extra = [0u8; 1]; + let read_res = timeout(Duration::from_millis(220), stream.read_exact(&mut extra)).await; + assert!( + read_res.is_err() || read_res.unwrap().is_err(), + "drip-fed post-probe byte arriving after idle timeout should not be forwarded" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.43:51003".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (mut client_writer_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let relay_task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + sleep(Duration::from_millis(160)).await; + let _ = client_writer_side.write_all(b"X").await; + drop(client_writer_side); + + timeout(Duration::from_secs(1), relay_task) + 
.await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} diff --git a/src/proxy/tests/masking_shape_above_cap_blur_security_tests.rs b/src/proxy/tests/masking_shape_above_cap_blur_security_tests.rs new file mode 100644 index 0000000..3f581e2 --- /dev/null +++ b/src/proxy/tests/masking_shape_above_cap_blur_security_tests.rs @@ -0,0 +1,102 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +async fn capture_forwarded_len( + body_sent: usize, + shape_hardening: bool, + above_cap_blur: bool, + above_cap_blur_max_bytes: usize, +) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = shape_hardening; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + config.censorship.mask_shape_above_cap_blur = above_cap_blur; + config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (server_reader, mut client_writer) = duplex(64 * 1024); + let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024); + let peer: SocketAddr = "198.51.100.220:57120".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + 
probe[3..5].copy_from_slice(&7000u16.to_be_bytes()); + probe[5..].fill(0x5A); + + let fallback = tokio::spawn(async move { + handle_bad_client( + server_reader, + client_visible_writer, + &probe, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(4), fallback) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn above_cap_blur_disabled_keeps_exact_above_cap_length() { + let body_sent = 5000usize; + let observed = capture_forwarded_len(body_sent, true, false, 0).await; + assert_eq!(observed, 5 + body_sent); +} + +#[tokio::test] +async fn above_cap_blur_enabled_adds_bounded_random_tail() { + let body_sent = 5000usize; + let base = 5 + body_sent; + let max_extra = 64usize; + + let mut saw_extra = false; + for _ in 0..20 { + let observed = capture_forwarded_len(body_sent, true, true, max_extra).await; + assert!(observed >= base, "observed={observed} base={base}"); + assert!( + observed <= base + max_extra, + "observed={observed} base={} max_extra={max_extra}", + base + ); + if observed > base { + saw_extra = true; + } + } + + assert!( + saw_extra, + "at least one run should produce above-cap blur bytes under randomization" + ); +} diff --git a/src/proxy/tests/masking_shape_bypass_blackhat_tests.rs b/src/proxy/tests/masking_shape_bypass_blackhat_tests.rs new file mode 100644 index 0000000..24ceea4 --- /dev/null +++ b/src/proxy/tests/masking_shape_bypass_blackhat_tests.rs @@ -0,0 +1,182 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +async fn capture_forwarded_len_with_optional_eof( + body_sent: usize, + shape_hardening: bool, + above_cap_blur: bool, + above_cap_blur_max_bytes: usize, + close_client_after_write: bool, +) -> usize { + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = shape_hardening; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + config.censorship.mask_shape_above_cap_blur = above_cap_blur; + config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (server_reader, mut client_writer) = duplex(64 * 1024); + let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024); + let peer: SocketAddr = "198.51.100.241:57241".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&7000u16.to_be_bytes()); + probe[5..].fill(0x73); + + let fallback = tokio::spawn(async move { + handle_bad_client( + server_reader, + client_visible_writer, + &probe, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + if close_client_after_write { + client_writer.shutdown().await.unwrap(); + } else { + client_writer.write_all(b"keepalive").await.unwrap(); + tokio::time::sleep(Duration::from_millis(170)).await; + drop(client_writer); + } + + let _ = tokio::time::timeout(Duration::from_secs(4), fallback) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + 
+#[tokio::test] +#[ignore = "red-team detector: shaping on non-EOF timeout path is disabled by design to prevent post-timeout tail leaks"] +async fn security_shape_padding_applies_without_client_eof_when_backend_silent() { + let body_sent = 17usize; + let hardened_floor = 512usize; + + let with_eof = capture_forwarded_len_with_optional_eof(body_sent, true, false, 0, true).await; + let without_eof = + capture_forwarded_len_with_optional_eof(body_sent, true, false, 0, false).await; + + assert!( + with_eof >= hardened_floor, + "EOF path should be shaped to floor (with_eof={with_eof}, floor={hardened_floor})" + ); + assert!( + without_eof >= hardened_floor, + "non-EOF path should also be shaped when backend is silent (without_eof={without_eof}, floor={hardened_floor})" + ); +} + +#[tokio::test] +#[ignore = "red-team detector: blur currently allows zero-extra sample by design within [0..=max] bound"] +async fn security_above_cap_blur_never_emits_exact_base_length() { + let body_sent = 5000usize; + let base = 5 + body_sent; + let max_blur = 1usize; + + for _ in 0..64 { + let observed = + capture_forwarded_len_with_optional_eof(body_sent, true, true, max_blur, true).await; + assert!( + observed > base, + "above-cap blur must add at least one byte when enabled (observed={observed}, base={base})" + ); + } +} + +#[tokio::test] +#[ignore = "red-team detector: shape padding currently depends on EOF, enabling idle-timeout bypass probes"] +async fn redteam_detector_shape_padding_must_not_depend_on_client_eof() { + let body_sent = 17usize; + let hardened_floor = 512usize; + + let with_eof = capture_forwarded_len_with_optional_eof(body_sent, true, false, 0, true).await; + let without_eof = + capture_forwarded_len_with_optional_eof(body_sent, true, false, 0, false).await; + + assert!( + with_eof >= hardened_floor, + "sanity check failed: EOF path should be shaped to floor (with_eof={with_eof}, floor={hardened_floor})" + ); + + assert!( + without_eof >= hardened_floor, + "strict 
anti-probing model expects shaping even without EOF; observed without_eof={without_eof}, floor={hardened_floor}" + ); +} + +#[tokio::test] +#[ignore = "red-team detector: zero-extra above-cap blur samples leak exact class boundary"] +async fn redteam_detector_above_cap_blur_must_never_emit_exact_base_length() { + let body_sent = 5000usize; + let base = 5 + body_sent; + let mut saw_exact_base = false; + let max_blur = 1usize; + + for _ in 0..96 { + let observed = + capture_forwarded_len_with_optional_eof(body_sent, true, true, max_blur, true).await; + if observed == base { + saw_exact_base = true; + break; + } + } + + assert!( + !saw_exact_base, + "strict anti-classifier model expects >0 blur always; observed exact base length leaks class" + ); +} + +#[tokio::test] +#[ignore = "red-team detector: disjoint above-cap ranges enable near-perfect size-class classification"] +async fn redteam_detector_above_cap_blur_ranges_for_far_classes_should_overlap() { + let mut a_min = usize::MAX; + let mut a_max = 0usize; + let mut b_min = usize::MAX; + let mut b_max = 0usize; + + for _ in 0..48 { + let a = capture_forwarded_len_with_optional_eof(5000, true, true, 64, true).await; + let b = capture_forwarded_len_with_optional_eof(7000, true, true, 64, true).await; + a_min = a_min.min(a); + a_max = a_max.max(a); + b_min = b_min.min(b); + b_max = b_max.max(b); + } + + let overlap = a_min <= b_max && b_min <= a_max; + assert!( + overlap, + "strict anti-classifier model expects overlapping output bands; class_a=[{a_min},{a_max}] class_b=[{b_min},{b_max}]" + ); +} diff --git a/src/proxy/tests/masking_shape_classifier_resistance_adversarial_tests.rs b/src/proxy/tests/masking_shape_classifier_resistance_adversarial_tests.rs new file mode 100644 index 0000000..5d494b8 --- /dev/null +++ b/src/proxy/tests/masking_shape_classifier_resistance_adversarial_tests.rs @@ -0,0 +1,339 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use 
tokio::time::Duration; + +async fn capture_forwarded_len( + body_sent: usize, + shape_hardening: bool, + above_cap_blur: bool, + above_cap_blur_max_bytes: usize, +) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = shape_hardening; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + config.censorship.mask_shape_above_cap_blur = above_cap_blur; + config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (client_reader, mut client_writer) = duplex(64 * 1024); + let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024); + + let mut initial = vec![0u8; 5 + body_sent]; + initial[0] = 0x16; + initial[1] = 0x03; + initial[2] = 0x01; + initial[3..5].copy_from_slice(&7000u16.to_be_bytes()); + initial[5..].fill(0x5A); + + let peer: SocketAddr = "198.51.100.250:57450".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let fallback = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(3), fallback) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + 
.unwrap() + .unwrap() +} + +fn best_threshold_accuracy(a: &[usize], b: &[usize]) -> f64 { + let min_v = *a.iter().chain(b.iter()).min().unwrap(); + let max_v = *a.iter().chain(b.iter()).max().unwrap(); + + let mut best = 0.0f64; + for t in min_v..=max_v { + let correct_a = a.iter().filter(|&&x| x <= t).count(); + let correct_b = b.iter().filter(|&&x| x > t).count(); + let acc = (correct_a + correct_b) as f64 / (a.len() + b.len()) as f64; + if acc > best { + best = acc; + } + } + best +} + +fn nearest_centroid_classifier_accuracy( + samples_a: &[usize], + samples_b: &[usize], + samples_c: &[usize], +) -> f64 { + let mean = |xs: &[usize]| -> f64 { xs.iter().copied().sum::() as f64 / xs.len() as f64 }; + + let ca = mean(samples_a); + let cb = mean(samples_b); + let cc = mean(samples_c); + + let mut correct = 0usize; + let mut total = 0usize; + + for &x in samples_a { + total += 1; + let xf = x as f64; + let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()]; + if d[0] <= d[1] && d[0] <= d[2] { + correct += 1; + } + } + + for &x in samples_b { + total += 1; + let xf = x as f64; + let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()]; + if d[1] <= d[0] && d[1] <= d[2] { + correct += 1; + } + } + + for &x in samples_c { + total += 1; + let xf = x as f64; + let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()]; + if d[2] <= d[0] && d[2] <= d[1] { + correct += 1; + } + } + + correct as f64 / total as f64 +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_blur_reduces_threshold_attack_accuracy() { + const SAMPLES: usize = 120; + const MAX_EXTRA: usize = 96; + const CLASS_A_BODY: usize = 5000; + const CLASS_B_BODY: usize = 5040; + + let mut baseline_a = Vec::with_capacity(SAMPLES); + let mut baseline_b = Vec::with_capacity(SAMPLES); + let mut hardened_a = Vec::with_capacity(SAMPLES); + let mut hardened_b = Vec::with_capacity(SAMPLES); + + for _ in 0..SAMPLES { + baseline_a.push(capture_forwarded_len(CLASS_A_BODY, true, false, 0).await); + 
baseline_b.push(capture_forwarded_len(CLASS_B_BODY, true, false, 0).await); + hardened_a.push(capture_forwarded_len(CLASS_A_BODY, true, true, MAX_EXTRA).await); + hardened_b.push(capture_forwarded_len(CLASS_B_BODY, true, true, MAX_EXTRA).await); + } + + let baseline_acc = best_threshold_accuracy(&baseline_a, &baseline_b); + let hardened_acc = best_threshold_accuracy(&hardened_a, &hardened_b); + + // Baseline classes are deterministic/non-overlapping -> near-perfect threshold attack. + assert!( + baseline_acc >= 0.99, + "baseline separability unexpectedly low: {baseline_acc:.3}" + ); + // Blur must materially reduce the best one-dimensional length classifier. + assert!( + hardened_acc <= 0.90, + "blur should degrade threshold attack accuracy, got {hardened_acc:.3}" + ); + assert!( + hardened_acc <= baseline_acc - 0.08, + "blur must reduce threshold accuracy by a meaningful margin: baseline={baseline_acc:.3}, hardened={hardened_acc:.3}" + ); +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_blur_increases_cross_class_overlap() { + const SAMPLES: usize = 96; + const MAX_EXTRA: usize = 96; + const CLASS_A_BODY: usize = 5000; + const CLASS_B_BODY: usize = 5040; + + let mut baseline_a = std::collections::BTreeSet::new(); + let mut baseline_b = std::collections::BTreeSet::new(); + let mut hardened_a = std::collections::BTreeSet::new(); + let mut hardened_b = std::collections::BTreeSet::new(); + + for _ in 0..SAMPLES { + baseline_a.insert(capture_forwarded_len(CLASS_A_BODY, true, false, 0).await); + baseline_b.insert(capture_forwarded_len(CLASS_B_BODY, true, false, 0).await); + hardened_a.insert(capture_forwarded_len(CLASS_A_BODY, true, true, MAX_EXTRA).await); + hardened_b.insert(capture_forwarded_len(CLASS_B_BODY, true, true, MAX_EXTRA).await); + } + + let baseline_overlap = baseline_a.intersection(&baseline_b).count(); + let hardened_overlap = hardened_a.intersection(&hardened_b).count(); + + assert_eq!(baseline_overlap, 0, "baseline classes should not 
overlap"); + assert!( + hardened_overlap >= 8, + "blur should create meaningful overlap between classes, got overlap={hardened_overlap}" + ); +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_parallel_probe_campaign_keeps_blur_bounds() { + const MAX_EXTRA: usize = 128; + + let mut tasks = Vec::new(); + for i in 0..64usize { + tasks.push(tokio::spawn(async move { + let body = 4300 + (i % 700); + let observed = capture_forwarded_len(body, true, true, MAX_EXTRA).await; + let base = 5 + body; + assert!( + observed >= base && observed <= base + MAX_EXTRA, + "campaign bounds violated for i={i}: observed={observed} base={base}" + ); + })); + } + + for task in tasks { + tokio::time::timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); + } +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_edge_max_extra_one_has_two_point_support() { + const BODY: usize = 5000; + const BASE: usize = 5 + BODY; + + let mut seen = std::collections::BTreeSet::new(); + for _ in 0..64 { + let observed = capture_forwarded_len(BODY, true, true, 1).await; + assert!( + observed == BASE || observed == BASE + 1, + "max_extra=1 must only produce two-point support" + ); + seen.insert(observed); + } + + assert_eq!( + seen.len(), + 2, + "both support points should appear under repeated sampling" + ); +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_negative_blur_without_shape_hardening_is_noop() { + const BODY_A: usize = 5000; + const BODY_B: usize = 5040; + + let mut as_observed = std::collections::BTreeSet::new(); + let mut bs_observed = std::collections::BTreeSet::new(); + for _ in 0..48 { + as_observed.insert(capture_forwarded_len(BODY_A, false, true, 96).await); + bs_observed.insert(capture_forwarded_len(BODY_B, false, true, 96).await); + } + + assert_eq!( + as_observed.len(), + 1, + "without shape hardening class A must stay deterministic" + ); + assert_eq!( + bs_observed.len(), + 1, + "without shape hardening class B must stay 
deterministic" + ); + assert_ne!( + as_observed, bs_observed, + "distinct classes should remain separable without shaping" + ); +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_attack_degrades_with_blur() + { + const SAMPLES: usize = 80; + const MAX_EXTRA: usize = 96; + const C1: usize = 5000; + const C2: usize = 5040; + const C3: usize = 5080; + + let mut base1 = Vec::with_capacity(SAMPLES); + let mut base2 = Vec::with_capacity(SAMPLES); + let mut base3 = Vec::with_capacity(SAMPLES); + let mut hard1 = Vec::with_capacity(SAMPLES); + let mut hard2 = Vec::with_capacity(SAMPLES); + let mut hard3 = Vec::with_capacity(SAMPLES); + + for _ in 0..SAMPLES { + base1.push(capture_forwarded_len(C1, true, false, 0).await); + base2.push(capture_forwarded_len(C2, true, false, 0).await); + base3.push(capture_forwarded_len(C3, true, false, 0).await); + + hard1.push(capture_forwarded_len(C1, true, true, MAX_EXTRA).await); + hard2.push(capture_forwarded_len(C2, true, true, MAX_EXTRA).await); + hard3.push(capture_forwarded_len(C3, true, true, MAX_EXTRA).await); + } + + let base_acc = nearest_centroid_classifier_accuracy(&base1, &base2, &base3); + let hard_acc = nearest_centroid_classifier_accuracy(&hard1, &hard2, &hard3); + + assert!( + base_acc >= 0.99, + "baseline centroid separability should be near-perfect" + ); + assert!( + hard_acc <= 0.88, + "blur should materially degrade 3-class centroid attack" + ); + assert!( + hard_acc <= base_acc - 0.1, + "accuracy drop should be meaningful" + ); +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_light_fuzz_bounds_hold_for_randomized_above_cap_campaign() + { + let mut s: u64 = 0xDEAD_BEEF_CAFE_BABE; + for _ in 0..96 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let body = 4097 + (s as usize % 2048); + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let max_extra = 1 + (s as usize % 128); + + let observed = capture_forwarded_len(body, true, true, max_extra).await; + let 
base = 5 + body; + assert!( + observed >= base && observed <= base + max_extra, + "fuzz bounds violated: body={body} observed={observed} max_extra={max_extra}" + ); + } +} diff --git a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs new file mode 100644 index 0000000..982fd26 --- /dev/null +++ b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs @@ -0,0 +1,417 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, empty, sink}; +use tokio::time::{Duration, sleep, timeout}; + +fn oracle_len( + total_sent: usize, + shape_enabled: bool, + ended_by_eof: bool, + initial_len: usize, + floor: usize, + cap: usize, +) -> usize { + if shape_enabled && ended_by_eof && initial_len > 0 { + next_mask_shape_bucket(total_sent, floor, cap) + } else { + total_sent + } +} + +async fn run_relay_case( + initial: Vec, + extra: Vec, + close_client: bool, + shape_enabled: bool, + floor: usize, + cap: usize, + above_cap_blur: bool, + above_cap_blur_max_bytes: usize, +) -> Vec { + let (client_reader, mut client_writer) = duplex(8192); + let (mut mask_observer, mask_writer) = duplex(8192); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + empty(), + mask_writer, + &initial, + shape_enabled, + floor, + cap, + above_cap_blur, + above_cap_blur_max_bytes, + false, + ) + .await; + }); + + if !extra.is_empty() { + client_writer.write_all(&extra).await.unwrap(); + } + + if close_client { + client_writer.shutdown().await.unwrap(); + } + + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + + if !close_client { + drop(client_writer); + } + + let mut observed = Vec::new(); + timeout( + Duration::from_secs(2), + mask_observer.read_to_end(&mut observed), + ) + .await + .unwrap() + .unwrap(); + observed +} + +#[tokio::test] +async fn masking_shape_guard_negative_timeout_path_never_shapes_even_with_blur_enabled() { + let initial = b"GET /timeout-path 
HTTP/1.1\r\n".to_vec(); + let extra = vec![0xCC; 700]; + let total = initial.len() + extra.len(); + + let observed = run_relay_case( + initial.clone(), + extra.clone(), + false, + true, + 512, + 4096, + true, + 1024, + ) + .await; + + assert_eq!(observed.len(), total, "timeout path must stay unshaped"); + assert_eq!(&observed[..initial.len()], initial.as_slice()); + assert_eq!(&observed[initial.len()..], extra.as_slice()); +} + +#[tokio::test] +async fn masking_shape_guard_positive_clean_eof_path_shapes_and_preserves_prefix() { + let initial = b"GET /ok HTTP/1.1\r\n".to_vec(); + let extra = vec![0x55; 300]; + let total = initial.len() + extra.len(); + + let observed = run_relay_case( + initial.clone(), + extra.clone(), + true, + true, + 512, + 4096, + false, + 0, + ) + .await; + + let expected_len = oracle_len(total, true, true, initial.len(), 512, 4096); + assert_eq!( + observed.len(), + expected_len, + "clean EOF path must be bucket-shaped" + ); + assert_eq!(&observed[..initial.len()], initial.as_slice()); + assert_eq!( + &observed[initial.len()..(initial.len() + extra.len())], + extra.as_slice() + ); +} + +#[tokio::test] +async fn masking_shape_guard_edge_empty_initial_remains_transparent_under_clean_eof() { + let initial = Vec::new(); + let extra = vec![0xA1; 257]; + + let observed = run_relay_case(initial, extra.clone(), true, true, 512, 4096, false, 0).await; + + assert_eq!( + observed.len(), + extra.len(), + "empty initial_data must never trigger shaping" + ); + assert_eq!(observed, extra); +} + +#[tokio::test] +async fn masking_shape_guard_light_fuzz_oracle_matches_for_eof_and_timeout_variants() { + let floor = 512usize; + let cap = 4096usize; + + // Deterministic xorshift to keep this fuzz test stable in CI. 
+ let mut s: u64 = 0x9E37_79B9_7F4A_7C15; + for _ in 0..96 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let initial_len = (s as usize) % 48; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let extra_len = (s as usize) % 1800; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let close_client = (s & 1) == 0; + + let initial = vec![0x42; initial_len]; + let extra = vec![0x99; extra_len]; + let total = initial_len + extra_len; + + let observed = run_relay_case( + initial.clone(), + extra.clone(), + close_client, + true, + floor, + cap, + false, + 0, + ) + .await; + + let expected = oracle_len(total, true, close_client, initial_len, floor, cap); + assert_eq!( + observed.len(), + expected, + "oracle mismatch: initial_len={initial_len} extra_len={extra_len} close_client={close_client}" + ); + + if initial_len > 0 { + assert_eq!(&observed[..initial_len], initial.as_slice()); + } + if extra_len > 0 { + assert_eq!( + &observed[initial_len..(initial_len + extra_len)], + extra.as_slice(), + "payload prefix must remain byte-for-byte before any optional shaping tail" + ); + } + } +} + +#[tokio::test] +async fn masking_shape_guard_stress_parallel_mixed_sessions_keep_oracle_and_no_hangs() { + let mut tasks = Vec::new(); + + for i in 0..48usize { + tasks.push(tokio::spawn(async move { + let initial_len = if i % 3 == 0 { 0 } else { 5 + (i % 19) }; + let extra_len = 64 + (i * 37 % 1300); + let close_client = i % 2 == 0; + + let initial = vec![i as u8; initial_len]; + let extra = vec![0xE0 | ((i as u8) & 0x0F); extra_len]; + let total = initial_len + extra_len; + + let observed = run_relay_case( + initial.clone(), + extra.clone(), + close_client, + true, + 512, + 4096, + false, + 0, + ) + .await; + + let expected = oracle_len(total, true, close_client, initial_len, 512, 4096); + assert_eq!( + observed.len(), + expected, + "stress oracle mismatch for worker={i} close_client={close_client}" + ); + + if initial_len > 0 { + assert_eq!(&observed[..initial_len], initial.as_slice()); + 
} + if extra_len > 0 { + assert_eq!( + &observed[initial_len..(initial_len + extra_len)], + extra.as_slice() + ); + } + })); + } + + for task in tasks { + timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); + } +} + +#[tokio::test] +async fn masking_shape_guard_integration_slow_drip_timeout_is_cut_without_tail_leak() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET /drip-guard HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut one = [0u8; 1]; + let r = timeout(Duration::from_millis(220), stream.read_exact(&mut one)).await; + assert!( + r.is_err() || r.unwrap().is_err(), + "no post-timeout drip/tail may reach backend" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + + let peer: SocketAddr = "198.51.100.245:53101".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (mut client_writer, client_reader) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + let beobachten = BeobachtenStore::new(); + + let relay = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + sleep(Duration::from_millis(160)).await; + let _ = 
client_writer.write_all(b"X").await; + + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn masking_shape_guard_above_cap_blur_statistical_quality_and_bounds() { + let base_len = 5005usize; // 5-byte header + 5000 payload + let max_extra = 64usize; + let mut extras = Vec::new(); + + for _ in 0..192 { + let observed = run_relay_case( + vec![0x16, 0x03, 0x01, 0x1B, 0x58], + vec![0xAA; 5000], + true, + true, + 512, + 4096, + true, + max_extra, + ) + .await; + + assert!( + observed.len() >= base_len && observed.len() <= base_len + max_extra, + "above-cap blur length must stay in bounded window" + ); + extras.push(observed.len() - base_len); + } + + let unique: std::collections::BTreeSet<_> = extras.iter().copied().collect(); + let mean = extras.iter().copied().sum::() as f64 / extras.len() as f64; + + // For uniform [0..=64], mean is ~32. Keep wide bounds to avoid CI flakiness. 
+ assert!( + (20.0..=44.0).contains(&mean), + "blur mean drifted too far from expected center, mean={mean:.2}" + ); + assert!( + unique.len() >= 16, + "blur distribution appears too low-entropy, unique_extras={}", + unique.len() + ); +} + +#[tokio::test] +async fn masking_shape_guard_above_cap_blur_parallel_stress_keeps_bounds() { + let max_extra = 96usize; + let mut tasks = Vec::new(); + + for i in 0..64usize { + tasks.push(tokio::spawn(async move { + let body_len = 4500 + (i % 256); + let base_len = 5 + body_len; + + let observed = run_relay_case( + vec![0x16, 0x03, 0x01, 0x1B, 0x58], + vec![0xA0 | ((i as u8) & 0x0F); body_len], + true, + true, + 512, + 4096, + true, + max_extra, + ) + .await; + + assert!( + observed.len() >= base_len && observed.len() <= base_len + max_extra, + "parallel blur bounds violated for worker={i}: observed_len={} base_len={} max_extra={}", + observed.len(), + base_len, + max_extra + ); + })); + } + + for task in tasks { + timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); + } +} + +#[tokio::test] +async fn masking_shape_guard_above_cap_blur_disabled_keeps_exact_length_even_on_clean_eof() { + let initial = vec![0x16, 0x03, 0x01, 0x1B, 0x58]; + let body = vec![0x77; 5200]; + let expected = initial.len() + body.len(); + + let observed = run_relay_case(initial, body, true, true, 512, 4096, false, 0).await; + assert_eq!( + observed.len(), + expected, + "without above-cap blur the output must remain exact even on clean EOF" + ); +} diff --git a/src/proxy/tests/masking_shape_guard_security_tests.rs b/src/proxy/tests/masking_shape_guard_security_tests.rs new file mode 100644 index 0000000..34a89c4 --- /dev/null +++ b/src/proxy/tests/masking_shape_guard_security_tests.rs @@ -0,0 +1,189 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, timeout}; + +#[tokio::test] +async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof() { 
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let client_payload = vec![0x7A; 64]; + + let accept_task = tokio::spawn({ + let expected = client_payload.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + assert_eq!( + got, expected, + "empty initial_data path must not inject shape padding" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + + let peer: SocketAddr = "203.0.113.90:52001".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client_writer, client_reader) = duplex(2048); + let (_client_visible_reader, client_visible_writer) = duplex(2048); + + let relay_task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + b"", + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + client_writer.write_all(&client_payload).await.unwrap(); + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(2), relay_task) + .await + .unwrap() + .unwrap(); + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn shape_guard_timeout_exit_does_not_append_padding_after_initial_probe() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET /timeout-shape-guard HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = 
initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut one = [0u8; 1]; + let read_res = timeout(Duration::from_millis(220), stream.read_exact(&mut one)).await; + assert!( + read_res.is_err() || read_res.unwrap().is_err(), + "idle-timeout path must not append shape padding after initial probe" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + + let peer: SocketAddr = "203.0.113.91:52002".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (_client_reader_side, client_reader) = duplex(2048); + let (_client_visible_reader, client_visible_writer) = duplex(2048); + + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn shape_guard_clean_eof_with_nonempty_initial_still_applies_bucket_padding() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET /shape-bucket HTTP/1.1\r\n".to_vec(); + let extra = vec![0x41; 31]; + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + let extra = extra.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + + let expected_prefix_len = 
initial.len() + extra.len(); + assert_eq!(&got[..initial.len()], initial.as_slice()); + assert_eq!(&got[initial.len()..expected_prefix_len], extra.as_slice()); + assert_eq!( + got.len(), + 512, + "clean EOF path should still shape to floor bucket" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + + let peer: SocketAddr = "203.0.113.92:52003".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client_writer, client_reader) = duplex(4096); + let (_client_visible_reader, client_visible_writer) = duplex(4096); + + let relay_task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + client_writer.write_all(&extra).await.unwrap(); + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(2), relay_task) + .await + .unwrap() + .unwrap(); + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs new file mode 100644 index 0000000..3c886ba --- /dev/null +++ b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs @@ -0,0 +1,133 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWrite, duplex, empty, sink}; + +struct CountingWriter { + written: usize, +} + +impl CountingWriter { + fn new() -> Self { + Self { written: 0 } + } +} + +impl AsyncWrite for CountingWriter { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut 
std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + self.written = self.written.saturating_add(buf.len()); + std::task::Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } +} + +#[test] +fn shape_bucket_clamps_to_cap_when_next_power_of_two_exceeds_cap() { + let bucket = next_mask_shape_bucket(1200, 1000, 1500); + assert_eq!(bucket, 1500); +} + +#[test] +fn shape_bucket_never_drops_below_total_for_valid_ranges() { + for total in [1usize, 32, 127, 512, 999, 1000, 1001, 1499, 1500, 1501] { + let bucket = next_mask_shape_bucket(total, 1000, 1500); + assert!( + bucket >= total || total >= 1500, + "bucket={bucket} total={total}" + ); + } +} + +#[tokio::test] +async fn maybe_write_shape_padding_writes_exact_delta() { + let mut writer = CountingWriter::new(); + maybe_write_shape_padding(&mut writer, 1200, true, 1000, 1500, false, 0, false).await; + assert_eq!(writer.written, 300); +} + +#[tokio::test] +async fn maybe_write_shape_padding_skips_when_disabled() { + let mut writer = CountingWriter::new(); + maybe_write_shape_padding(&mut writer, 1200, false, 1000, 1500, false, 0, false).await; + assert_eq!(writer.written, 0); +} + +#[tokio::test] +async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() { + let initial = vec![0x16, 0x03, 0x01, 0x04, 0x00]; + let extra = vec![0xAB; 1195]; + + let (client_reader, mut client_writer) = duplex(4096); + let (mut mask_observer, mask_writer) = duplex(4096); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + empty(), + mask_writer, + &initial, + true, + 1000, + 1500, + false, + 0, + false, + ) + .await; + }); + + client_writer.write_all(&extra).await.unwrap(); + 
client_writer.shutdown().await.unwrap(); + + relay.await.unwrap(); + + let mut observed = Vec::new(); + mask_observer.read_to_end(&mut observed).await.unwrap(); + assert_eq!(observed.len(), 1500); + assert_eq!(&observed[..5], &[0x16, 0x03, 0x01, 0x04, 0x00]); + assert!(observed[5..1200].iter().all(|b| *b == 0xAB)); + assert_eq!(observed[1200..].len(), 300); +} + +#[test] +fn shape_bucket_light_fuzz_monotonicity_and_bounds() { + let floor = 512usize; + let cap = 4096usize; + let mut prev = 0usize; + + for step in 1usize..=3000 { + let total = ((step * 37) ^ (step << 3)) % (cap + 512); + let bucket = next_mask_shape_bucket(total, floor, cap); + + if total < cap { + assert!(bucket >= total, "bucket={bucket} total={total}"); + assert!(bucket <= cap, "bucket={bucket} cap={cap}"); + } else { + assert_eq!(bucket, total, "above-cap totals must remain unchanged"); + } + + if total >= prev { + // For non-decreasing inputs, bucket class must not regress. + let prev_bucket = next_mask_shape_bucket(prev, floor, cap); + assert!(bucket >= prev_bucket || total >= cap); + } + + prev = total; + } +} diff --git a/src/proxy/tests/masking_timing_normalization_security_tests.rs b/src/proxy/tests/masking_timing_normalization_security_tests.rs new file mode 100644 index 0000000..327ba6a --- /dev/null +++ b/src/proxy/tests/masking_timing_normalization_security_tests.rs @@ -0,0 +1,126 @@ +use super::*; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +#[derive(Clone, Copy)] +enum MaskPath { + ConnectFail, + ConnectSuccess, + SlowBackend, +} + +async fn measure_bad_client_duration_ms(path: MaskPath, floor_ms: u64, ceiling_ms: u64) -> u128 { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = floor_ms; + 
config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms; + + let accept_task = match path { + MaskPath::ConnectFail => { + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + None + } + MaskPath::ConnectSuccess => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + })) + } + MaskPath::SlowBackend => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + tokio::time::sleep(Duration::from_millis(320)).await; + })) + } + }; + + let (client_reader, _client_writer) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + + let peer: SocketAddr = "198.51.100.221:57121".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"GET /timing-normalize HTTP/1.1\r\nHost: x\r\n\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + + if let Some(task) = accept_task { + let _ = tokio::time::timeout(Duration::from_secs(2), task).await; + } + + started.elapsed().as_millis() +} + +#[tokio::test] +async fn timing_normalization_envelope_applies_to_connect_fail_and_success() { + let floor = 160u64; + let ceiling = 180u64; + + let fail = measure_bad_client_duration_ms(MaskPath::ConnectFail, floor, ceiling).await; + let success = measure_bad_client_duration_ms(MaskPath::ConnectSuccess, floor, 
ceiling).await; + + assert!( + fail >= floor as u128, + "connect-fail duration below floor: {fail}ms < {floor}ms" + ); + assert!( + fail <= (ceiling + 60) as u128, + "connect-fail duration exceeded relaxed ceiling: {fail}ms > {}ms", + ceiling + 60 + ); + + assert!( + success >= floor as u128, + "connect-success duration below floor: {success}ms < {floor}ms" + ); + assert!( + success <= (ceiling + 60) as u128, + "connect-success duration exceeded relaxed ceiling: {success}ms > {}ms", + ceiling + 60 + ); + + let delta = fail.abs_diff(success); + assert!( + delta <= 80, + "timing normalization should reduce path divergence (delta={}ms)", + delta + ); +} + +#[tokio::test] +async fn timing_normalization_does_not_sleep_if_path_already_exceeds_ceiling() { + let floor = 120u64; + let ceiling = 150u64; + + let slow = measure_bad_client_duration_ms(MaskPath::SlowBackend, floor, ceiling).await; + + assert!( + slow >= 280, + "slow backend path should remain slow (got {slow}ms)" + ); + assert!( + slow <= 520, + "slow backend path should remain bounded in tests (got {slow}ms)" + ); +} diff --git a/src/proxy/tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs b/src/proxy/tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..3c4a342 --- /dev/null +++ b/src/proxy/tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs @@ -0,0 +1,200 @@ +use super::*; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +#[derive(Clone, Copy)] +enum TimingPath { + ConnectFail, + ConnectSuccess, + SlowBackend, +} + +async fn measure_path_duration_ms(path: TimingPath) -> u128 { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + + let maybe_accept = match path { + TimingPath::ConnectFail => { + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + 
None + } + TimingPath::ConnectSuccess => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + })) + } + TimingPath::SlowBackend => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + tokio::time::sleep(Duration::from_millis(350)).await; + })) + } + }; + + let peer: SocketAddr = "198.51.100.213:57013".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (client_reader, _client_writer) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + + if let Some(task) = maybe_accept { + let _ = tokio::time::timeout(Duration::from_secs(2), task).await; + } + + started.elapsed().as_millis() +} + +fn summarize(values: &[u128]) -> (u128, u128, f64) { + let min = *values.iter().min().unwrap_or(&0); + let max = *values.iter().max().unwrap_or(&0); + let sum: u128 = values.iter().copied().sum(); + let mean = if values.is_empty() { + 0.0 + } else { + sum as f64 / values.len() as f64 + }; + (min, max, mean) +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict path-indistinguishability target"] +async fn redteam_timing_01_connect_fail_success_slow_backend_must_be_within_10ms() { + const ITER: usize = 8; + + let mut fail = 
Vec::with_capacity(ITER); + let mut success = Vec::with_capacity(ITER); + let mut slow = Vec::with_capacity(ITER); + + for _ in 0..ITER { + fail.push(measure_path_duration_ms(TimingPath::ConnectFail).await); + success.push(measure_path_duration_ms(TimingPath::ConnectSuccess).await); + slow.push(measure_path_duration_ms(TimingPath::SlowBackend).await); + } + + let (_, fail_max, fail_mean) = summarize(&fail); + let (_, success_max, success_mean) = summarize(&success); + let (_, slow_max, slow_mean) = summarize(&slow); + + let global_min = *fail + .iter() + .chain(success.iter()) + .chain(slow.iter()) + .min() + .unwrap(); + let global_max = *fail + .iter() + .chain(success.iter()) + .chain(slow.iter()) + .max() + .unwrap(); + + println!( + "redteam_timing path=connect_fail mean_ms={:.2} max_ms={}", + fail_mean, fail_max + ); + println!( + "redteam_timing path=connect_success mean_ms={:.2} max_ms={}", + success_mean, success_max + ); + println!( + "redteam_timing path=slow_backend mean_ms={:.2} max_ms={}", + slow_mean, slow_max + ); + + assert!( + global_max.saturating_sub(global_min) <= 10, + "strict model expects all masking outcomes in one 10ms bucket: min={global_min} max={global_max}" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict classifier-separability target"] +async fn redteam_timing_02_path_classifier_centroid_accuracy_must_be_below_40pct() { + const ITER: usize = 12; + + let mut fail = Vec::with_capacity(ITER); + let mut success = Vec::with_capacity(ITER); + let mut slow = Vec::with_capacity(ITER); + + for _ in 0..ITER { + fail.push(measure_path_duration_ms(TimingPath::ConnectFail).await as f64); + success.push(measure_path_duration_ms(TimingPath::ConnectSuccess).await as f64); + slow.push(measure_path_duration_ms(TimingPath::SlowBackend).await as f64); + } + + let mean = |v: &Vec| -> f64 { v.iter().sum::() / v.len() as f64 }; + let c_fail = mean(&fail); + let c_success = mean(&success); + let c_slow = mean(&slow); + + let mut correct 
= 0usize; + let mut total = 0usize; + + let classify = |x: f64, c0: f64, c1: f64, c2: f64| -> usize { + let d0 = (x - c0).abs(); + let d1 = (x - c1).abs(); + let d2 = (x - c2).abs(); + if d0 <= d1 && d0 <= d2 { + 0 + } else if d1 <= d0 && d1 <= d2 { + 1 + } else { + 2 + } + }; + + for &x in &fail { + total += 1; + if classify(x, c_fail, c_success, c_slow) == 0 { + correct += 1; + } + } + for &x in &success { + total += 1; + if classify(x, c_fail, c_success, c_slow) == 1 { + correct += 1; + } + } + for &x in &slow { + total += 1; + if classify(x, c_fail, c_success, c_slow) == 2 { + correct += 1; + } + } + + let accuracy = correct as f64 / total as f64; + println!( + "redteam_timing_classifier accuracy={:.3} c_fail={:.2} c_success={:.2} c_slow={:.2}", + accuracy, c_fail, c_success, c_slow + ); + + assert!( + accuracy <= 0.40, + "strict model expects poor classifier; observed accuracy={accuracy:.3}" + ); +} diff --git a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs new file mode 100644 index 0000000..2c9f3f6 --- /dev/null +++ b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs @@ -0,0 +1,112 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use tokio::sync::Barrier; +use tokio::time::{Duration, timeout}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!( + "middle-blackhat-held-{}-{idx}", + std::process::id() + ))); + } + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "precondition: bounded lock cache 
must be saturated" + ); + + let (tx, _rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Close) + .await + .expect("queue prefill should succeed"); + + let pressure_seq_before = relay_pressure_event_seq(); + let pressure_errors = Arc::new(AtomicUsize::new(0)); + let mut pressure_workers = Vec::new(); + for _ in 0..16 { + let tx = tx.clone(); + let pressure_errors = Arc::clone(&pressure_errors); + pressure_workers.push(tokio::spawn(async move { + if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() { + pressure_errors.fetch_add(1, Ordering::Relaxed); + } + })); + } + + let stats = Arc::new(Stats::new()); + let user = format!("middle-blackhat-quota-race-{}", std::process::id()); + let gate = Arc::new(Barrier::new(16)); + + let mut quota_workers = Vec::new(); + for _ in 0..16u8 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let gate = Arc::clone(&gate); + quota_workers.push(tokio::spawn(async move { + gate.wait().await; + let user_lock = quota_user_lock(&user); + let _quota_guard = user_lock.lock().await; + + if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) { + return false; + } + stats.add_user_octets_to(&user, 1); + true + })); + } + + let mut ok_count = 0usize; + let mut denied_count = 0usize; + for worker in quota_workers { + let result = timeout(Duration::from_secs(2), worker) + .await + .expect("quota worker must finish") + .expect("quota worker must not panic"); + if result { + ok_count += 1; + } else { + denied_count += 1; + } + } + + for worker in pressure_workers { + timeout(Duration::from_secs(2), worker) + .await + .expect("pressure worker must finish") + .expect("pressure worker must not panic"); + } + + assert_eq!( + stats.get_user_total_octets(&user), + 1, + "black-hat campaign must not overshoot same-user quota under saturation" + ); + assert!(ok_count <= 1, "at most one quota contender may succeed"); + assert!( + denied_count >= 15, + "all remaining contenders must be quota-denied" + ); + + let 
pressure_seq_after = relay_pressure_event_seq(); + assert!( + pressure_seq_after > pressure_seq_before, + "queue pressure leg must trigger pressure accounting" + ); + assert!( + pressure_errors.load(Ordering::Relaxed) >= 1, + "at least one pressure worker should fail from persistent backpressure" + ); + + drop(retained); +} diff --git a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs new file mode 100644 index 0000000..fff26b4 --- /dev/null +++ b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs @@ -0,0 +1,708 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::crypto::SecureRandom; +use crate::stats::Stats; +use crate::stream::{BufferPool, PooledBuffer}; +use std::sync::Arc; +use tokio::io::AsyncReadExt; +use tokio::io::duplex; +use tokio::sync::mpsc; +use tokio::time::{Duration as TokioDuration, timeout}; + +fn make_pooled_payload(data: &[u8]) -> PooledBuffer { + let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + +#[tokio::test] +async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() { + let (mut read_side, write_side) = duplex(4096); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40]; + + write_client_payload( + &mut writer, + ProtoTag::Abridged, + RPC_FLAG_QUICKACK, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("abridged quickack payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = vec![0u8; 1 + payload.len()]; + read_side + .read_exact(&mut encrypted) + .await + 
.expect("must read serialized abridged frame"); + let plaintext = decryptor.decrypt(&encrypted); + + assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8)); + assert_eq!(&plaintext[1..], payload.as_slice()); +} + +#[tokio::test] +async fn write_client_payload_abridged_extended_header_is_encoded_correctly() { + let (mut read_side, write_side) = duplex(16 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + // Boundary where abridged switches to extended length encoding. + let payload = vec![0x5Au8; 0x7f * 4]; + + write_client_payload( + &mut writer, + ProtoTag::Abridged, + RPC_FLAG_QUICKACK, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("extended abridged payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = vec![0u8; 4 + payload.len()]; + read_side + .read_exact(&mut encrypted) + .await + .expect("must read serialized extended abridged frame"); + let plaintext = decryptor.decrypt(&encrypted); + + assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set"); + assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]); + assert_eq!(&plaintext[4..], payload.as_slice()); +} + +#[tokio::test] +async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() { + let (_read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + let err = write_client_payload( + &mut writer, + ProtoTag::Abridged, + 0, + &[1, 2, 3], + &rng, + &mut frame_buf, + ) + .await + .expect_err("misaligned abridged payload must be rejected"); + + let msg = format!("{err}"); + assert!( + msg.contains("4-byte aligned"), + "error should explain alignment 
contract, got: {msg}" + ); +} + +#[tokio::test] +async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() { + let (_read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + let err = write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &[9, 8, 7, 6, 5], + &rng, + &mut frame_buf, + ) + .await + .expect_err("misaligned secure payload must be rejected"); + + let msg = format!("{err}"); + assert!( + msg.contains("Secure payload must be 4-byte aligned"), + "error should be explicit for fail-closed triage, got: {msg}" + ); +} + +#[tokio::test] +async fn write_client_payload_intermediate_quickack_sets_length_msb() { + let (mut read_side, write_side) = duplex(4096); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = b"hello-middle-relay"; + + write_client_payload( + &mut writer, + ProtoTag::Intermediate, + RPC_FLAG_QUICKACK, + payload, + &rng, + &mut frame_buf, + ) + .await + .expect("intermediate quickack payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = vec![0u8; 4 + payload.len()]; + read_side + .read_exact(&mut encrypted) + .await + .expect("must read intermediate frame"); + let plaintext = decryptor.decrypt(&encrypted); + + let mut len_bytes = [0u8; 4]; + len_bytes.copy_from_slice(&plaintext[..4]); + let len_with_flags = u32::from_le_bytes(len_bytes); + assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set"); + assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len()); + assert_eq!(&plaintext[4..], payload); +} + +#[tokio::test] +async fn 
write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() { + let (mut read_side, write_side) = duplex(4096); + let key = [0u8; 32]; + let iv = 0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode. + + write_client_payload( + &mut writer, + ProtoTag::Secure, + RPC_FLAG_QUICKACK, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("secure quickack payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + // Secure mode adds 1..=3 bytes of randomized tail padding. + let mut encrypted_header = [0u8; 4]; + read_side + .read_exact(&mut encrypted_header) + .await + .expect("must read secure header"); + let decrypted_header = decryptor.decrypt(&encrypted_header); + let header: [u8; 4] = decrypted_header + .try_into() + .expect("decrypted secure header must be 4 bytes"); + let wire_len_raw = u32::from_le_bytes(header); + + assert_ne!( + wire_len_raw & 0x8000_0000, + 0, + "secure quickack bit must be set" + ); + + let wire_len = (wire_len_raw & 0x7fff_ffff) as usize; + assert!(wire_len >= payload.len()); + let padding_len = wire_len - payload.len(); + assert!( + (1..=3).contains(&padding_len), + "secure writer must add bounded random tail padding, got {padding_len}" + ); + + let mut encrypted_body = vec![0u8; wire_len]; + read_side + .read_exact(&mut encrypted_body) + .await + .expect("must read secure body"); + let decrypted_body = decryptor.decrypt(&encrypted_body); + assert_eq!(&decrypted_body[..payload.len()], payload.as_slice()); +} + +#[tokio::test] +#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"] +async fn write_client_payload_abridged_too_large_is_rejected_fail_closed() { + let (_read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 
0u128; + + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + // Exactly one 4-byte word above the encodable 24-bit abridged length range. + let payload = vec![0x00u8; (1 << 24) * 4]; + let err = write_client_payload( + &mut writer, + ProtoTag::Abridged, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect_err("oversized abridged payload must be rejected"); + + let msg = format!("{err}"); + assert!( + msg.contains("Abridged frame too large"), + "error must clearly indicate oversize fail-close path, got: {msg}" + ); +} + +#[tokio::test] +async fn write_client_ack_intermediate_is_little_endian() { + let (mut read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + + write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44) + .await + .expect("ack serialization should succeed"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = [0u8; 4]; + read_side + .read_exact(&mut encrypted) + .await + .expect("must read ack bytes"); + let plain = decryptor.decrypt(&encrypted); + assert_eq!(plain.as_slice(), &0x11_22_33_44u32.to_le_bytes()); +} + +#[tokio::test] +async fn write_client_ack_abridged_is_big_endian() { + let (mut read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + + write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF) + .await + .expect("ack serialization should succeed"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = [0u8; 4]; + read_side + .read_exact(&mut encrypted) + .await + .expect("must read ack bytes"); + let plain = decryptor.decrypt(&encrypted); + 
assert_eq!(plain.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes()); +} + +#[tokio::test] +async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() { + let (mut read_side, write_side) = duplex(1024 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = vec![0xABu8; 0x7e * 4]; + + write_client_payload( + &mut writer, + ProtoTag::Abridged, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("boundary payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = vec![0u8; 1 + payload.len()]; + read_side.read_exact(&mut encrypted).await.unwrap(); + let plain = decryptor.decrypt(&encrypted); + assert_eq!(plain[0], 0x7e); + assert_eq!(&plain[1..], payload.as_slice()); +} + +#[tokio::test] +async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() { + let (mut read_side, write_side) = duplex(16 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = vec![0x42u8; 0x80 * 4]; + + write_client_payload( + &mut writer, + ProtoTag::Abridged, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("extended payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = vec![0u8; 4 + payload.len()]; + read_side.read_exact(&mut encrypted).await.unwrap(); + let plain = decryptor.decrypt(&encrypted); + assert_eq!(plain[0], 0x7f); + assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]); + assert_eq!(&plain[4..], payload.as_slice()); +} + +#[tokio::test] +async fn write_client_payload_intermediate_zero_length_emits_header_only() { + let 
(mut read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + write_client_payload( + &mut writer, + ProtoTag::Intermediate, + 0, + &[], + &rng, + &mut frame_buf, + ) + .await + .expect("zero-length intermediate payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = [0u8; 4]; + read_side.read_exact(&mut encrypted).await.unwrap(); + let plain = decryptor.decrypt(&encrypted); + assert_eq!(plain.as_slice(), &[0, 0, 0, 0]); +} + +#[tokio::test] +async fn write_client_payload_intermediate_ignores_unrelated_flags() { + let (mut read_side, write_side) = duplex(1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = [7u8; 12]; + + write_client_payload( + &mut writer, + ProtoTag::Intermediate, + 0x4000_0000, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted = [0u8; 16]; + read_side.read_exact(&mut encrypted).await.unwrap(); + let plain = decryptor.decrypt(&encrypted); + let len = u32::from_le_bytes(plain[0..4].try_into().unwrap()); + assert_eq!(len, payload.len() as u32, "only quickack bit may affect header"); + assert_eq!(&plain[4..], payload.as_slice()); +} + +#[tokio::test] +async fn write_client_payload_secure_without_quickack_keeps_msb_clear() { + let (mut read_side, write_side) = duplex(4096); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = 
SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = [0x1Du8; 64]; + + write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted_header = [0u8; 4]; + read_side.read_exact(&mut encrypted_header).await.unwrap(); + let plain_header = decryptor.decrypt(&encrypted_header); + let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); + let wire_len_raw = u32::from_le_bytes(h); + assert_eq!(wire_len_raw & 0x8000_0000, 0, "quickack bit must stay clear"); +} + +#[tokio::test] +async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() { + let (mut read_side, write_side) = duplex(256 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let payload = [0x55u8; 100]; + let mut seen = [false; 4]; + + for _ in 0..96 { + write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("secure payload should serialize"); + writer.flush().await.expect("flush must succeed"); + + let mut encrypted_header = [0u8; 4]; + read_side.read_exact(&mut encrypted_header).await.unwrap(); + let plain_header = decryptor.decrypt(&encrypted_header); + let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); + let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize; + let padding_len = wire_len - payload.len(); + assert!((1..=3).contains(&padding_len)); + seen[padding_len] = true; + + let mut encrypted_body = vec![0u8; wire_len]; + read_side.read_exact(&mut encrypted_body).await.unwrap(); + let _ = decryptor.decrypt(&encrypted_body); + } + + let distinct = (1..=3).filter(|idx| seen[*idx]).count(); + assert!( + distinct >= 2, + "padding 
generator should not collapse to a single outcome under campaign" + ); +} + +#[tokio::test] +async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() { + let (mut read_side, write_side) = duplex(128 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); + let mut decryptor = AesCtr::new(&key, iv); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + let p1 = vec![1u8; 8]; + let p2 = vec![2u8; 16]; + let p3 = vec![3u8; 20]; + + write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf) + .await + .unwrap(); + write_client_payload( + &mut writer, + ProtoTag::Intermediate, + RPC_FLAG_QUICKACK, + &p2, + &rng, + &mut frame_buf, + ) + .await + .unwrap(); + write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf) + .await + .unwrap(); + writer.flush().await.unwrap(); + + // Frame 1: abridged short. + let mut e1 = vec![0u8; 1 + p1.len()]; + read_side.read_exact(&mut e1).await.unwrap(); + let d1 = decryptor.decrypt(&e1); + assert_eq!(d1[0], (p1.len() / 4) as u8); + assert_eq!(&d1[1..], p1.as_slice()); + + // Frame 2: intermediate with quickack. + let mut e2 = vec![0u8; 4 + p2.len()]; + read_side.read_exact(&mut e2).await.unwrap(); + let d2 = decryptor.decrypt(&e2); + let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap()); + assert_ne!(l2 & 0x8000_0000, 0); + assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len()); + assert_eq!(&d2[4..], p2.as_slice()); + + // Frame 3: secure with bounded tail. 
+ let mut e3h = [0u8; 4]; + read_side.read_exact(&mut e3h).await.unwrap(); + let d3h = decryptor.decrypt(&e3h); + let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize; + assert!(l3 >= p3.len()); + assert!((1..=3).contains(&(l3 - p3.len()))); + let mut e3b = vec![0u8; l3]; + read_side.read_exact(&mut e3b).await.unwrap(); + let d3b = decryptor.decrypt(&e3b); + assert_eq!(&d3b[..p3.len()], p3.as_slice()); +} + +#[test] +fn should_yield_sender_boundary_matrix_blackhat() { + assert!(!should_yield_c2me_sender(0, false)); + assert!(!should_yield_c2me_sender(0, true)); + assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); + assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); + assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); + assert!(should_yield_c2me_sender( + C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024), + true + )); +} + +#[test] +fn should_yield_sender_light_fuzz_matches_oracle() { + let mut s: u64 = 0xD00D_BAAD_F00D_CAFE; + for _ in 0..5000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let sent = (s as usize) & 0x1fff; + let backlog = (s & 1) != 0; + + let expected = backlog && sent >= C2ME_SENDER_FAIRNESS_BUDGET; + assert_eq!(should_yield_c2me_sender(sent, backlog), expected); + } +} + +#[test] +fn quota_would_be_exceeded_exact_remaining_one_byte() { + let stats = Stats::new(); + let user = "quota-edge"; + let quota = 100u64; + stats.add_user_octets_to(user, 99); + + assert!( + !quota_would_be_exceeded_for_user(&stats, user, Some(quota), 1), + "exactly remaining budget should be allowed" + ); + assert!( + quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2), + "one byte beyond remaining budget must be rejected" + ); +} + +#[test] +fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() { + let stats = Stats::new(); + let user = "quota-saturating-edge"; + let quota = u64::MAX - 3; + stats.add_user_octets_to(user, u64::MAX - 4); + + assert!( + 
quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2), + "saturating arithmetic edge must stay fail-closed" + ); +} + +#[test] +fn quota_exceeded_boundary_is_inclusive() { + let stats = Stats::new(); + let user = "quota-inclusive-boundary"; + stats.add_user_octets_to(user, 50); + + assert!(quota_exceeded_for_user(&stats, user, Some(50))); + assert!(!quota_exceeded_for_user(&stats, user, Some(51))); +} + +#[tokio::test] +async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() { + let (tx, mut rx) = mpsc::channel::(4); + enqueue_c2me_command(&tx, C2MeCommand::Close) + .await + .expect("close should enqueue on fast path"); + + let recv = timeout(TokioDuration::from_millis(50), rx.recv()) + .await + .expect("must receive close command") + .expect("close command should be present"); + assert!(matches!(recv, C2MeCommand::Close)); +} + +#[tokio::test] +async fn enqueue_c2me_data_full_then_drain_preserves_order() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[1]), + flags: 10, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let producer = tokio::spawn(async move { + enqueue_c2me_command( + &tx2, + C2MeCommand::Data { + payload: make_pooled_payload(&[2, 2]), + flags: 20, + }, + ) + .await + }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + + let first = rx.recv().await.expect("first item should exist"); + match first { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[1]); + assert_eq!(flags, 10); + } + C2MeCommand::Close => panic!("unexpected close as first item"), + } + + producer.await.unwrap().expect("producer should complete"); + + let second = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .expect("second item should exist"); + match second { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[2, 2]); + assert_eq!(flags, 20); + } + C2MeCommand::Close => panic!("unexpected close as 
second item"), + } +} diff --git a/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs b/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs new file mode 100644 index 0000000..dab0dff --- /dev/null +++ b/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs @@ -0,0 +1,191 @@ +use super::*; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::thread; + +#[test] +fn desync_all_full_bypass_does_not_initialize_or_grow_dedup_cache() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let initial_len = DESYNC_DEDUP.get().map(|dedup| dedup.len()).unwrap_or(0); + let now = Instant::now(); + + for i in 0..20_000u64 { + assert!( + should_emit_full_desync(0xD35E_D000_0000_0000u64 ^ i, true, now), + "desync_all_full path must always emit" + ); + } + + let after_len = DESYNC_DEDUP.get().map(|dedup| dedup.len()).unwrap_or(0); + assert_eq!( + after_len, initial_len, + "desync_all_full bypass must not allocate or accumulate dedup entries" + ); +} + +#[test] +fn desync_all_full_bypass_keeps_existing_dedup_entries_unchanged() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let seed_time = Instant::now() - Duration::from_secs(7); + dedup.insert(0xAAAABBBBCCCCDDDD, seed_time); + dedup.insert(0x1111222233334444, seed_time); + + let now = Instant::now(); + for i in 0..2048u64 { + assert!( + should_emit_full_desync(0xF011_F000_0000_0000u64 ^ i, true, now), + "desync_all_full must bypass suppression and dedup refresh" + ); + } + + assert_eq!( + dedup.len(), + 2, + "bypass path must not mutate dedup cardinality" + ); + assert_eq!( + *dedup + .get(&0xAAAABBBBCCCCDDDD) + .expect("seed key must remain"), + seed_time, + "bypass path must not refresh existing dedup 
timestamps" + ); + assert_eq!( + *dedup + .get(&0x1111222233334444) + .expect("seed key must remain"), + seed_time, + "bypass path must not touch unrelated dedup entries" + ); +} + +#[test] +fn edge_all_full_burst_does_not_poison_later_false_path_tracking() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + for i in 0..8192u64 { + assert!(should_emit_full_desync( + 0xABCD_0000_0000_0000 ^ i, + true, + now + )); + } + + let tracked_key = 0xDEAD_BEEF_0000_0001u64; + assert!( + should_emit_full_desync(tracked_key, false, now), + "first false-path event after all_full burst must still be tracked and emitted" + ); + + let dedup = DESYNC_DEDUP + .get() + .expect("false path should initialize dedup"); + assert!(dedup.get(&tracked_key).is_some()); +} + +#[test] +fn adversarial_mixed_sequence_true_steps_never_change_cache_len() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + for i in 0..256u64 { + dedup.insert(0x1000_0000_0000_0000 ^ i, Instant::now()); + } + + let mut seed = 0xC0DE_CAFE_BAAD_F00Du64; + for i in 0..4096u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let flag_all_full = (seed & 0x1) == 1; + let key = 0x7000_0000_0000_0000u64 ^ i ^ seed; + let before = dedup.len(); + let _ = should_emit_full_desync(key, flag_all_full, Instant::now()); + let after = dedup.len(); + + if flag_all_full { + assert_eq!(after, before, "all_full step must not mutate dedup length"); + } + } +} + +#[test] +fn light_fuzz_all_full_mode_always_emits_and_stays_bounded() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let mut seed = 0x1234_5678_9ABC_DEF0u64; + let before = 
DESYNC_DEDUP.get().map(|d| d.len()).unwrap_or(0); + + for _ in 0..20_000 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let key = seed ^ 0x55AA_55AA_55AA_55AAu64; + assert!(should_emit_full_desync(key, true, Instant::now())); + } + + let after = DESYNC_DEDUP.get().map(|d| d.len()).unwrap_or(0); + assert_eq!(after, before); + assert!(after <= DESYNC_DEDUP_MAX_ENTRIES); +} + +#[test] +fn stress_parallel_all_full_storm_does_not_grow_or_mutate_cache() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let seed_time = Instant::now() - Duration::from_secs(2); + for i in 0..1024u64 { + dedup.insert(0x8888_0000_0000_0000 ^ i, seed_time); + } + let before_len = dedup.len(); + + let emits = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + for worker in 0..16u64 { + let emits = Arc::clone(&emits); + workers.push(thread::spawn(move || { + let now = Instant::now(); + for i in 0..4096u64 { + let key = 0xFACE_0000_0000_0000u64 ^ (worker << 20) ^ i; + if should_emit_full_desync(key, true, now) { + emits.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("worker must not panic"); + } + + assert_eq!(emits.load(Ordering::Relaxed), 16 * 4096); + assert_eq!( + dedup.len(), + before_len, + "parallel all_full storm must not mutate cache len" + ); +} diff --git a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs new file mode 100644 index 0000000..3e0b30f --- /dev/null +++ b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs @@ -0,0 +1,816 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::atomic::AtomicU64; +use std::sync::{Arc, Mutex, OnceLock}; +use tokio::io::AsyncWriteExt; +use 
tokio::io::duplex; +use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xA000_0000 + conn_id, + conn_id, + user: format!("idle-test-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_millis(soft_ms), + hard_idle: Duration::from_millis(hard_ms), + grace_after_downstream_activity: Duration::from_millis(grace_ms), + legacy_frame_read_timeout: Duration::from_millis(hard_ms), + } +} + +fn idle_pressure_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> { + match idle_pressure_test_lock().lock() { + Ok(guard) => guard, + Err(poisoned) => poisoned.into_inner(), + } +} + +#[tokio::test] +async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() { + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut 
frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(40, 120, 0); + let last_downstream_activity_ms = AtomicU64::new(0); + + let start = TokioInstant::now(); + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("idle test must complete"); + + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut) + ); + let err_text = match result { + Err(ProxyError::Io(ref e)) => e.to_string(), + _ => String::new(), + }; + assert!( + err_text.contains("middle-relay hard idle timeout"), + "hard close must expose a clear timeout reason" + ); + assert!( + start.elapsed() >= TokioDuration::from_millis(80), + "hard timeout must not trigger before idle deadline window" + ); + assert_eq!(stats.get_relay_idle_soft_mark_total(), 1); + assert_eq!(stats.get_relay_idle_hard_close_total(), 1); +} + +#[tokio::test] +async fn idle_policy_downstream_activity_grace_extends_hard_deadline() { + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(30, 60, 100); + let last_downstream_activity_ms = AtomicU64::new(20); + + let start = TokioInstant::now(); + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, 
+ &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("grace test must complete"); + + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut) + ); + assert!( + start.elapsed() >= TokioDuration::from_millis(100), + "recent downstream activity must extend hard idle deadline" + ); +} + +#[tokio::test] +async fn relay_idle_policy_disabled_keeps_legacy_timeout_behavior() { + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics(3, Instant::now()); + let mut frame_counter = 0u64; + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + Duration::from_millis(60), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut) + ); + let err_text = match result { + Err(ProxyError::Io(ref e)) => e.to_string(), + _ => String::new(), + }; + assert!( + err_text.contains("middle-relay client frame read timeout"), + "legacy mode must keep expected timeout reason" + ); + assert_eq!(stats.get_relay_idle_soft_mark_total(), 0); + assert_eq!(stats.get_relay_idle_hard_close_total(), 0); +} + +#[tokio::test] +async fn adversarial_partial_frame_trickle_cannot_bypass_hard_idle_close() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(4, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(30, 90, 0); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = 
Vec::with_capacity(12); + plaintext.extend_from_slice(&8u32.to_le_bytes()); + plaintext.extend_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]); + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted[..1]) + .await + .expect("must write a single trickle byte"); + + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("partial frame trickle test must complete"); + + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut) + ); + assert_eq!( + frame_counter, 0, + "partial trickle must not count as a valid frame" + ); +} + +#[tokio::test] +async fn successful_client_frame_resets_soft_idle_mark() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(5, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + idle_state.soft_idle_marked = true; + let idle_policy = make_idle_policy(200, 300, 0); + let last_downstream_activity_ms = AtomicU64::new(0); + + let payload = [9u8, 8, 7, 6, 5, 4, 3, 2]; + let mut plaintext = Vec::with_capacity(4 + payload.len()); + plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("must write full encrypted frame"); + + let read = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + 
&idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + .expect("frame read must succeed") + .expect("frame must be returned"); + + assert_eq!(read.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); + assert!( + !idle_state.soft_idle_marked, + "a valid client frame must clear soft-idle mark" + ); +} + +#[tokio::test] +async fn protocol_desync_small_frame_updates_reason_counter() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics(6, Instant::now()); + let mut frame_counter = 0u64; + + let mut plaintext = Vec::with_capacity(7); + plaintext.extend_from_slice(&3u32.to_le_bytes()); + plaintext.extend_from_slice(&[1u8, 2, 3]); + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("must write frame"); + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Secure, + 1024, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small"))); + assert_eq!(stats.get_relay_protocol_desync_close_total(), 1); +} + +#[tokio::test] +async fn stress_many_idle_sessions_fail_closed_without_hang() { + let mut tasks = Vec::with_capacity(24); + + for idx in 0..24u64 { + tasks.push(tokio::spawn(async move { + let (reader, _writer) = duplex(256); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(100 + idx, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(20, 50, 10); + let last_downstream_activity_ms = AtomicU64::new(0); + 
+ let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("stress task must complete"); + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + assert_eq!(stats.get_relay_idle_hard_close_total(), 1); + assert_eq!(frame_counter, 0); + })); + } + + for task in tasks { + task.await.expect("stress task must not panic"); + } +} + +#[test] +fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(10)); + assert!(mark_relay_idle_candidate(11)); + assert_eq!(oldest_relay_idle_candidate(), Some(10)); + + note_relay_pressure_event(); + + let mut seen_for_newer = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(11, &mut seen_for_newer, &stats), + "newer idle candidate must not be evicted while older candidate exists" + ); + assert_eq!(oldest_relay_idle_candidate(), Some(10)); + + let mut seen_for_oldest = 0u64; + assert!( + maybe_evict_idle_candidate_on_pressure(10, &mut seen_for_oldest, &stats), + "oldest idle candidate must be evicted first under pressure" + ); + assert_eq!(oldest_relay_idle_candidate(), Some(11)); + assert_eq!(stats.get_relay_pressure_evict_total(), 1); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn pressure_does_not_evict_without_new_pressure_signal() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(21)); + let mut seen = relay_pressure_event_seq(); + + assert!( + !maybe_evict_idle_candidate_on_pressure(21, &mut 
seen, &stats), + "without new pressure signal, candidate must stay" + ); + assert_eq!(stats.get_relay_pressure_evict_total(), 0); + assert_eq!(oldest_relay_idle_candidate(), Some(21)); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + let mut seen_per_conn = std::collections::HashMap::new(); + for conn_id in 1000u64..1064u64 { + assert!(mark_relay_idle_candidate(conn_id)); + seen_per_conn.insert(conn_id, 0u64); + } + + for expected in 1000u64..1064u64 { + note_relay_pressure_event(); + + let mut seen = *seen_per_conn + .get(&expected) + .expect("per-conn pressure cursor must exist"); + assert!( + maybe_evict_idle_candidate_on_pressure(expected, &mut seen, &stats), + "expected conn_id {expected} must be evicted next by deterministic FIFO ordering" + ); + seen_per_conn.insert(expected, seen); + + let next = if expected == 1063 { + None + } else { + Some(expected + 1) + }; + assert_eq!(oldest_relay_idle_candidate(), next); + } + + assert_eq!(stats.get_relay_pressure_evict_total(), 64); + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(301)); + assert!(mark_relay_idle_candidate(302)); + assert!(mark_relay_idle_candidate(303)); + + let mut seen_301 = 0u64; + let mut seen_302 = 0u64; + let mut seen_303 = 0u64; + + // Single pressure event should authorize at most one eviction globally. 
+ note_relay_pressure_event(); + + let evicted_301 = maybe_evict_idle_candidate_on_pressure(301, &mut seen_301, &stats); + let evicted_302 = maybe_evict_idle_candidate_on_pressure(302, &mut seen_302, &stats); + let evicted_303 = maybe_evict_idle_candidate_on_pressure(303, &mut seen_303, &stats); + + let evicted_total = [evicted_301, evicted_302, evicted_303] + .iter() + .filter(|value| **value) + .count(); + + assert_eq!( + evicted_total, 1, + "single pressure event must not cascade-evict multiple idle candidates" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(401)); + assert!(mark_relay_idle_candidate(402)); + + let mut seen_oldest = 0u64; + let mut seen_next = 0u64; + + note_relay_pressure_event(); + + assert!( + maybe_evict_idle_candidate_on_pressure(401, &mut seen_oldest, &stats), + "oldest candidate must consume pressure budget first" + ); + + assert!( + !maybe_evict_idle_candidate_on_pressure(402, &mut seen_next, &stats), + "next candidate must not consume the same pressure budget" + ); + + assert_eq!( + stats.get_relay_pressure_evict_total(), + 1, + "single pressure budget must produce exactly one eviction" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + // Pressure happened before any idle candidate existed. 
+ note_relay_pressure_event(); + assert!(mark_relay_idle_candidate(501)); + + let mut seen = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(501, &mut seen, &stats), + "stale pressure (before soft-idle mark) must not evict newly marked candidate" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + note_relay_pressure_event(); + assert!(mark_relay_idle_candidate(511)); + assert!(mark_relay_idle_candidate(512)); + assert!(mark_relay_idle_candidate(513)); + + let mut seen_511 = 0u64; + let mut seen_512 = 0u64; + let mut seen_513 = 0u64; + + let evicted = [ + maybe_evict_idle_candidate_on_pressure(511, &mut seen_511, &stats), + maybe_evict_idle_candidate_on_pressure(512, &mut seen_512, &stats), + maybe_evict_idle_candidate_on_pressure(513, &mut seen_513, &stats), + ] + .iter() + .filter(|value| **value) + .count(); + + assert_eq!( + evicted, 0, + "stale pressure event must not evict any candidate from a newly marked batch" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + note_relay_pressure_event(); + + // Session A observed pressure while there were no candidates. + let mut seen_a = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(999_001, &mut seen_a, &stats), + "no candidate existed, so no eviction is possible" + ); + + // Candidate appears later; Session B must not be able to consume stale pressure. 
+ assert!(mark_relay_idle_candidate(521)); + let mut seen_b = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(521, &mut seen_b, &stats), + "once pressure is observed with empty candidate set, it must not be replayed later" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_must_not_survive_candidate_churn() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + note_relay_pressure_event(); + assert!(mark_relay_idle_candidate(531)); + clear_relay_idle_candidate(531); + assert!(mark_relay_idle_candidate(532)); + + let mut seen = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(532, &mut seen, &stats), + "stale pressure must not survive clear+remark churn cycles" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + { + let mut guard = relay_idle_candidate_registry() + .lock() + .expect("registry lock must be available"); + guard.pressure_event_seq = u64::MAX; + guard.pressure_consumed_seq = u64::MAX - 1; + } + + // A new pressure event should still be representable; saturating at MAX creates a permanent lockout. 
+ note_relay_pressure_event(); + let after = relay_pressure_event_seq(); + assert_ne!( + after, + u64::MAX, + "pressure sequence saturation must not permanently freeze event progression" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + { + let mut guard = relay_idle_candidate_registry() + .lock() + .expect("registry lock must be available"); + guard.pressure_event_seq = u64::MAX; + guard.pressure_consumed_seq = u64::MAX; + } + + note_relay_pressure_event(); + let first = relay_pressure_event_seq(); + note_relay_pressure_event(); + let second = relay_pressure_event_seq(); + + assert!( + second > first, + "distinct pressure events must remain distinguishable even at sequence boundary" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() +{ + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + let stats = Arc::new(Stats::new()); + let sessions = 16usize; + let rounds = 200usize; + let conn_ids: Vec = (10_000u64..10_000u64 + sessions as u64).collect(); + let mut seen_per_session = vec![0u64; sessions]; + + for conn_id in &conn_ids { + assert!(mark_relay_idle_candidate(*conn_id)); + } + + for round in 0..rounds { + note_relay_pressure_event(); + + let mut joins = Vec::with_capacity(sessions); + for (idx, conn_id) in conn_ids.iter().enumerate() { + let mut seen = seen_per_session[idx]; + let conn_id = *conn_id; + let stats = stats.clone(); + joins.push(tokio::spawn(async move { + let evicted = + maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); + (idx, conn_id, seen, evicted) + })); + } + + let mut evicted_this_round = 0usize; + 
let mut evicted_conn = None; + for join in joins { + let (idx, conn_id, seen, evicted) = join.await.expect("race task must not panic"); + seen_per_session[idx] = seen; + if evicted { + evicted_this_round += 1; + evicted_conn = Some(conn_id); + } + } + + assert!( + evicted_this_round <= 1, + "round {round}: one pressure event must never produce more than one eviction" + ); + if let Some(conn) = evicted_conn { + assert!( + mark_relay_idle_candidate(conn), + "round {round}: evicted conn must be re-markable as idle candidate" + ); + } + } + + assert!( + stats.get_relay_pressure_evict_total() <= rounds as u64, + "eviction total must never exceed number of pressure events" + ); + assert!( + stats.get_relay_pressure_evict_total() > 0, + "parallel race must still observe at least one successful eviction" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + let stats = Arc::new(Stats::new()); + let sessions = 12usize; + let rounds = 120usize; + let conn_ids: Vec = (20_000u64..20_000u64 + sessions as u64).collect(); + let mut seen_per_session = vec![0u64; sessions]; + + for conn_id in &conn_ids { + assert!(mark_relay_idle_candidate(*conn_id)); + } + + let mut expected_total_evictions = 0u64; + + for round in 0..rounds { + let empty_phase = round % 5 == 0; + if empty_phase { + for conn_id in &conn_ids { + clear_relay_idle_candidate(*conn_id); + } + } + + note_relay_pressure_event(); + + let mut joins = Vec::with_capacity(sessions); + for (idx, conn_id) in conn_ids.iter().enumerate() { + let mut seen = seen_per_session[idx]; + let conn_id = *conn_id; + let stats = stats.clone(); + joins.push(tokio::spawn(async move { + let evicted = + maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); 
+ (idx, conn_id, seen, evicted) + })); + } + + let mut evicted_this_round = 0usize; + let mut evicted_conn = None; + for join in joins { + let (idx, conn_id, seen, evicted) = join.await.expect("burst race task must not panic"); + seen_per_session[idx] = seen; + if evicted { + evicted_this_round += 1; + evicted_conn = Some(conn_id); + } + } + + if empty_phase { + assert_eq!( + evicted_this_round, 0, + "round {round}: empty candidate phase must not allow stale-pressure eviction" + ); + for conn_id in &conn_ids { + assert!(mark_relay_idle_candidate(*conn_id)); + } + } else { + assert!( + evicted_this_round <= 1, + "round {round}: pressure budget must cap at one eviction" + ); + if let Some(conn_id) = evicted_conn { + expected_total_evictions = expected_total_evictions.saturating_add(1); + assert!(mark_relay_idle_candidate(conn_id)); + } + } + } + + assert_eq!( + stats.get_relay_pressure_evict_total(), + expected_total_evictions, + "global pressure eviction counter must match observed per-round successful consumes" + ); + + clear_relay_idle_pressure_state_for_testing(); +} diff --git a/src/proxy/tests/middle_relay_length_cast_hardening_security_tests.rs b/src/proxy/tests/middle_relay_length_cast_hardening_security_tests.rs new file mode 100644 index 0000000..6c6644d --- /dev/null +++ b/src/proxy/tests/middle_relay_length_cast_hardening_security_tests.rs @@ -0,0 +1,75 @@ +use super::*; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; + +#[test] +fn intermediate_secure_wire_len_allows_max_31bit_payload() { + let (len_val, total) = compute_intermediate_secure_wire_len(0x7fff_fffe, 1, true) + .expect("31-bit wire length should be accepted"); + + assert_eq!(len_val, 0xffff_ffff, "quickack must use top bit only"); + assert_eq!(total, 0x8000_0003); +} + +#[test] +fn intermediate_secure_wire_len_rejects_length_above_31bit_limit() { + let err = compute_intermediate_secure_wire_len(0x7fff_ffff, 1, false) + .expect_err("wire length above 31-bit must fail closed"); + 
assert!( + format!("{err}").contains("frame too large"), + "error should identify oversize frame path" + ); +} + +#[test] +fn intermediate_secure_wire_len_rejects_addition_overflow() { + let err = compute_intermediate_secure_wire_len(usize::MAX, 1, false) + .expect_err("overflowing addition must fail closed"); + assert!( + format!("{err}").contains("overflow"), + "error should clearly report overflow" + ); +} + +#[test] +fn desync_forensics_len_bytes_marks_truncation_for_oversize_values() { + let (small_bytes, small_truncated) = desync_forensics_len_bytes(0x1020_3040); + assert_eq!(small_bytes, 0x1020_3040u32.to_le_bytes()); + assert!(!small_truncated); + + let (huge_bytes, huge_truncated) = desync_forensics_len_bytes(usize::MAX); + assert_eq!(huge_bytes, u32::MAX.to_le_bytes()); + assert!(huge_truncated); +} + +#[test] +fn report_desync_frame_too_large_preserves_full_length_in_error_message() { + let state = RelayForensicsState { + trace_id: 0x1234, + conn_id: 0x5678, + user: "middle-desync-oversize".to_string(), + peer: "198.51.100.55:443".parse().expect("valid test peer"), + peer_hash: 0xAABBCCDD, + started_at: Instant::now(), + bytes_c2me: 7, + bytes_me2c: Arc::new(AtomicU64::new(9)), + desync_all_full: false, + }; + + let huge_len = usize::MAX; + let err = report_desync_frame_too_large( + &state, + ProtoTag::Intermediate, + 3, + 1024, + huge_len, + None, + &Stats::new(), + ); + + let msg = format!("{err}"); + assert!( + msg.contains(&huge_len.to_string()), + "error must preserve full usize length for forensics" + ); +} diff --git a/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs b/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs new file mode 100644 index 0000000..d06e103 --- /dev/null +++ b/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs @@ -0,0 +1,131 @@ +use super::*; +use dashmap::DashMap; +use std::sync::Arc; + +#[test] +fn saturation_uses_stable_overflow_lock_without_cache_growth() { + let _guard = 
super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let prefix = format!("middle-quota-held-{}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); + + let user = format!("middle-quota-overflow-{}", std::process::id()); + let first = quota_user_lock(&user); + let second = quota_user_lock(&user); + + assert!( + Arc::ptr_eq(&first, &second), + "overflow user must get deterministic same lock while cache is saturated" + ); + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "overflow path must not grow bounded lock map" + ); + assert!( + map.get(&user).is_none(), + "overflow user should stay outside bounded lock map under saturation" + ); + + drop(retained); +} + +#[test] +fn overflow_striping_keeps_different_users_distributed() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let prefix = format!("middle-quota-dist-held-{}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + let a = quota_user_lock("middle-overflow-user-a"); + let b = quota_user_lock("middle-overflow-user-b"); + let c = quota_user_lock("middle-overflow-user-c"); + + let distinct = [ + Arc::as_ptr(&a) as usize, + Arc::as_ptr(&b) as usize, + Arc::as_ptr(&c) as usize, + ] + .iter() + .copied() + .collect::>() + .len(); + + assert!( + distinct >= 2, + "striped overflow lock set should avoid collapsing all users to one lock" + ); + + drop(retained); +} + +#[test] +fn reclaim_path_caches_new_user_after_stale_entries_drop() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); 
+ + let prefix = format!("middle-quota-reclaim-held-{}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + drop(retained); + + let user = format!("middle-quota-reclaim-user-{}", std::process::id()); + let got = quota_user_lock(&user); + assert!(map.get(&user).is_some()); + assert!( + Arc::strong_count(&got) >= 2, + "after reclaim, lock should be held both by caller and map" + ); +} + +#[test] +fn overflow_path_same_user_is_stable_across_parallel_threads() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!( + "middle-quota-thread-held-{}-{idx}", + std::process::id() + ))); + } + + let user = format!("middle-quota-overflow-thread-user-{}", std::process::id()); + let mut workers = Vec::new(); + for _ in 0..32 { + let user = user.clone(); + workers.push(std::thread::spawn(move || quota_user_lock(&user))); + } + + let first = workers + .remove(0) + .join() + .expect("thread must return lock handle"); + for worker in workers { + let got = worker.join().expect("thread must return lock handle"); + assert!( + Arc::ptr_eq(&first, &got), + "same overflow user should resolve to one striped lock even under contention" + ); + } + + drop(retained); +} diff --git a/src/proxy/tests/middle_relay_security_tests.rs b/src/proxy/tests/middle_relay_security_tests.rs new file mode 100644 index 0000000..1d3b736 --- /dev/null +++ b/src/proxy/tests/middle_relay_security_tests.rs @@ -0,0 +1,2517 @@ +use super::*; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::AesCtr; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use 
crate::proxy::handshake::HandshakeSuccess; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; +use crate::transport::middle_proxy::MePool; +use bytes::Bytes; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::collections::{HashMap, HashSet}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::sync::Mutex; +use std::thread; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::duplex; +use tokio::sync::Barrier; +use tokio::time::{Duration as TokioDuration, timeout}; + +fn make_pooled_payload(data: &[u8]) -> PooledBuffer { + let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + +fn make_pooled_payload_from(pool: &Arc, data: &[u8]) -> PooledBuffer { + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + +#[test] +fn should_yield_sender_only_on_budget_with_backlog() { + assert!(!should_yield_c2me_sender(0, true)); + assert!(!should_yield_c2me_sender( + C2ME_SENDER_FAIRNESS_BUDGET - 1, + true + )); + assert!(!should_yield_c2me_sender( + C2ME_SENDER_FAIRNESS_BUDGET, + false + )); + assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); +} + +#[tokio::test] +async fn enqueue_c2me_command_uses_try_send_fast_path() { + let (tx, mut rx) = mpsc::channel::(2); + enqueue_c2me_command( + &tx, + C2MeCommand::Data { + payload: make_pooled_payload(&[1, 2, 3]), + flags: 0, + }, + ) + .await + .unwrap(); + + let recv = timeout(TokioDuration::from_millis(50), rx.recv()) + .await + .unwrap() + .unwrap(); + match recv { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[1, 2, 3]); + assert_eq!(flags, 0); + } + 
C2MeCommand::Close => panic!("unexpected close command"),
    }
}

// Backpressure fallback: when the bounded queue is full, enqueue must degrade to a
// plain (waiting) send instead of dropping the frame.
#[tokio::test]
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
    let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
    tx.send(C2MeCommand::Data {
        payload: make_pooled_payload(&[9]),
        flags: 9,
    })
    .await
    .unwrap();

    let tx2 = tx.clone();
    let producer = tokio::spawn(async move {
        enqueue_c2me_command(
            &tx2,
            C2MeCommand::Data {
                payload: make_pooled_payload(&[7, 7]),
                flags: 7,
            },
        )
        .await
        .unwrap();
    });

    // Drain the pre-filled slot so the blocked producer can make progress.
    let _ = timeout(TokioDuration::from_millis(100), rx.recv())
        .await
        .unwrap();
    producer.await.unwrap();

    let recv = timeout(TokioDuration::from_millis(100), rx.recv())
        .await
        .unwrap()
        .unwrap();
    match recv {
        C2MeCommand::Data { payload, flags } => {
            assert_eq!(payload.as_ref(), &[7, 7]);
            assert_eq!(flags, 7);
        }
        C2MeCommand::Close => panic!("unexpected close command"),
    }
}

// Buffer-pool hygiene: a failed enqueue on a closed channel must recycle the
// pooled payload instead of leaking it.
#[tokio::test]
async fn enqueue_c2me_command_closed_channel_recycles_payload() {
    let pool = Arc::new(BufferPool::with_config(64, 4));
    let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]);
    let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
    drop(rx);

    let result = enqueue_c2me_command(&tx, C2MeCommand::Data { payload, flags: 0 }).await;

    assert!(result.is_err(), "closed queue must fail enqueue");
    drop(result);
    assert!(
        pool.stats().pooled >= 1,
        "payload must return to pool when enqueue fails on closed channel"
    );
}

// Buffer-pool hygiene: both the already-queued payload and the payload blocked
// in a full-queue send must be recycled when the receiver goes away.
#[tokio::test]
async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() {
    let pool = Arc::new(BufferPool::with_config(64, 4));
    let (tx, rx) = mpsc::channel::<C2MeCommand>(1);

    tx.send(C2MeCommand::Data {
        payload: make_pooled_payload_from(&pool, &[9]),
        flags: 1,
    })
    .await
    .unwrap();

    let tx2 = tx.clone();
    let pool2 = pool.clone();
    let blocked_send = tokio::spawn(async move {
        enqueue_c2me_command(
            &tx2,
            C2MeCommand::Data {
                payload: make_pooled_payload_from(&pool2, &[7, 7, 7]),
                flags: 2,
            },
        )
        .await
    });

    // Give the spawned sender time to block on the full queue before closing.
    tokio::time::sleep(TokioDuration::from_millis(10)).await;
    drop(rx);

    let result = timeout(TokioDuration::from_secs(1), blocked_send)
        .await
        .expect("blocked send task must finish")
        .expect("blocked send task must not panic");

    assert!(
        result.is_err(),
        "closing receiver while sender is blocked must fail enqueue"
    );
    drop(result);
    assert!(
        pool.stats().pooled >= 2,
        "both queued and blocked payloads must return to pool after channel close"
    );
}

// Liveness: a full queue with a stalled receiver must fail within the bounded
// timeout instead of wedging the data plane.
#[tokio::test]
async fn enqueue_c2me_command_full_queue_times_out_without_receiver_progress() {
    let (tx, _rx) = mpsc::channel::<C2MeCommand>(1);
    tx.send(C2MeCommand::Data {
        payload: make_pooled_payload(&[1]),
        flags: 0,
    })
    .await
    .unwrap();

    let started = Instant::now();
    let result = enqueue_c2me_command(
        &tx,
        C2MeCommand::Data {
            payload: make_pooled_payload(&[2, 2]),
            flags: 1,
        },
    )
    .await;

    assert!(
        result.is_err(),
        "enqueue must fail when queue stays full beyond bounded timeout"
    );
    assert!(
        started.elapsed() < TokioDuration::from_millis(400),
        "full-queue timeout must resolve promptly"
    );
}

// Memory bound: the desync dedup cache must evict rather than grow past its cap.
#[test]
fn desync_dedup_cache_is_bounded() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let now = Instant::now();
    for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
        assert!(
            should_emit_full_desync(key, false, now),
            "unique keys up to cap must be tracked"
        );
    }

    assert!(
        should_emit_full_desync(u64::MAX, false, now),
        "new key above cap must emit once after bounded eviction for forensic visibility"
    );

    assert!(
        !should_emit_full_desync(u64::MAX, false, now),
        "already tracked key inside dedup window must stay suppressed"
    );
}

// Cache identity: repeated lookups for one user must hand back the same Arc'd lock.
#[test]
fn quota_user_lock_cache_reuses_entry_for_same_user() {
    let _guard = super::quota_user_lock_test_scope();

    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();

    let a = quota_user_lock("quota-user-a");
    let b = quota_user_lock("quota-user-a");
    assert!(Arc::ptr_eq(&a, &b), "same user must reuse same quota lock");
}

// Memory bound: unique-user churn must never grow the lock cache past its cap.
#[test]
fn quota_user_lock_cache_is_bounded_under_unique_churn() {
    let _guard = super::quota_user_lock_test_scope();

    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();

    for idx in 0..(QUOTA_USER_LOCKS_MAX + 128) {
        let user = format!("quota-user-{idx}");
        let lock = quota_user_lock(&user);
        drop(lock);
    }

    assert!(
        map.len() <= QUOTA_USER_LOCKS_MAX,
        "quota lock cache must stay within configured bound"
    );
}

// Saturation overflow path: with every cached lock retained, an overflow user
// must get a deterministic striped lock without growing the cache. Retried
// because background eviction may race the precondition.
#[test]
fn quota_user_lock_cache_saturation_returns_stable_overflow_lock_without_growth() {
    let _guard = super::quota_user_lock_test_scope();

    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    for attempt in 0..8u32 {
        map.clear();

        let prefix = format!("quota-held-user-{}-{attempt}", std::process::id());
        let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
        for idx in 0..QUOTA_USER_LOCKS_MAX {
            let user = format!("{prefix}-{idx}");
            retained.push(quota_user_lock(&user));
        }

        if map.len() != QUOTA_USER_LOCKS_MAX {
            drop(retained);
            continue;
        }

        let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id());
        let overflow_a = quota_user_lock(&overflow_user);
        let overflow_b = quota_user_lock(&overflow_user);

        assert_eq!(
            map.len(),
            QUOTA_USER_LOCKS_MAX,
            "overflow acquisition must not grow cache past hard limit"
        );
        assert!(
            map.get(&overflow_user).is_none(),
            "overflow path should not cache new user lock when map is saturated and all entries are retained"
        );
        assert!(
            Arc::ptr_eq(&overflow_a, &overflow_b),
            "overflow user lock should use deterministic striping under saturation"
        );

        drop(retained);
        return;
    }

    panic!("unable to observe stable saturated lock-cache precondition after bounded retries");
}

// Quota race under lock-cache saturation: exactly one of two barrier-synced
// racers may win the last quota byte.
// NOTE(review): unlike the other quota tests, this one does not take
// `quota_user_lock_test_scope()` — confirm it cannot race siblings that clear
// QUOTA_USER_LOCKS concurrently.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_quota_race_under_lock_cache_saturation_still_allows_only_one_winner() {
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();

    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        let user = format!("quota-saturated-user-{idx}");
        retained.push(quota_user_lock(&user));
    }

    assert_eq!(
        map.len(),
        QUOTA_USER_LOCKS_MAX,
        "precondition: cache must be saturated for overflow-user race test"
    );

    let stats = Stats::new();
    let bytes_me2c = AtomicU64::new(0);
    let user = "gap-t04-saturated-lock-race-user";
    let barrier = Arc::new(Barrier::new(2));

    let one = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x55, 9101, barrier.clone());
    let two = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x66, 9102, barrier);
    let (r1, r2) = tokio::join!(one, two);

    assert!(
        matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }))
            && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })),
        "both racers must resolve cleanly without unexpected errors"
    );
    assert!(
        matches!(r1, Err(ProxyError::DataQuotaExceeded { .. }))
            || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })),
        "at least one racer must be quota-rejected even when lock cache is saturated"
    );
    assert_eq!(
        stats.get_user_total_octets(user),
        1,
        "saturated lock cache must not permit double-success quota overshoot"
    );

    drop(retained);
}

// Stress variant of the saturated-cache quota race: 128 rounds, fresh user per
// round, must never observe a double success.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_quota_race_under_lock_cache_saturation_never_allows_double_success() {
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();

    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        let user = format!("quota-saturated-stress-holder-{idx}");
        retained.push(quota_user_lock(&user));
    }

    let stats = Stats::new();
    let bytes_me2c = AtomicU64::new(0);

    for round in 0..128u64 {
        let user = format!("gap-t04-saturated-race-round-{round}");
        let barrier = Arc::new(Barrier::new(2));

        let one = run_quota_race_attempt(
            &stats,
            &bytes_me2c,
            &user,
            0x71,
            12_000 + round,
            barrier.clone(),
        );
        let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x72, 13_000 + round, barrier);

        let (r1, r2) = tokio::join!(one, two);
        assert!(
            matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }))
                && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })),
            "round {round}: racers must resolve cleanly"
        );
        assert!(
            matches!(r1, Err(ProxyError::DataQuotaExceeded { .. }))
                || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })),
            "round {round}: at least one racer must be quota-rejected"
        );
        assert_eq!(
            stats.get_user_total_octets(&user),
            1,
            "round {round}: saturated cache must still enforce exactly one forwarded byte"
        );
    }

    drop(retained);
}

// Forensics invariant: trace_id and conn_id are independent identifiers and the
// state struct must carry both through unchanged.
#[test]
fn adversarial_forensics_trace_id_should_not_alias_conn_id() {
    let now = Instant::now();
    let trace_id = 0x1122_3344_5566_7788;
    let conn_id = 0x8877_6655_4433_2211;
    let state = RelayForensicsState {
        trace_id,
        conn_id,
        user: "trace-user".to_string(),
        peer: "198.51.100.17:443".parse().unwrap(),
        peer_hash: 0x8877_6655_4433_2211,
        started_at: now,
        bytes_c2me: 0,
        bytes_me2c: Arc::new(AtomicU64::new(0)),
        desync_all_full: false,
    };

    assert_ne!(
        state.trace_id, state.conn_id,
        "security expectation: trace correlation should be independent of connection identity"
    );
    assert_eq!(state.trace_id, trace_id);
    assert_eq!(state.conn_id, conn_id);
}

// Wire format: the abridged ACK confirm word must be big-endian once the CTR
// layer is peeled off.
#[tokio::test]
async fn abridged_ack_uses_big_endian_confirm_bytes_after_decryption() {
    let (mut writer_side, reader_side) = duplex(8);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(reader_side, AesCtr::new(&key, iv), 8 * 1024);

    write_client_ack(&mut writer, ProtoTag::Abridged, 0x11_22_33_44)
        .await
        .expect("ack write must succeed");

    let mut observed = [0u8; 4];
    writer_side
        .read_exact(&mut observed)
        .await
        .expect("ack bytes must be readable");
    // Decrypt with an identical key/IV stream to recover the plaintext word.
    let mut decryptor = AesCtr::new(&key, iv);
    let decrypted = decryptor.decrypt(&observed);

    assert_eq!(
        decrypted,
        0x11_22_33_44u32.to_be_bytes(),
        "abridged ACK should encode confirm bytes in big-endian order"
    );
}

// Full-cache churn: once the dedup cache is full, only the first newcomer in an
// emit interval may produce a full forensic record.
#[test]
fn desync_dedup_full_cache_churn_stays_suppressed() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let now = Instant::now();
    for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
        assert!(should_emit_full_desync(key, false, now));
    }

    for offset in 0..2048u64 {
        let emitted = should_emit_full_desync(u64::MAX - offset, false, now);
        if offset == 0 {
            assert!(
                emitted,
                "first full-cache newcomer should emit for forensic visibility"
            );
        } else {
            assert!(
                !emitted,
                "full-cache newcomer churn inside emit interval must stay suppressed"
            );
        }
    }
}

// Hash determinism: identical input must hash identically within one process.
#[test]
fn dedup_hash_is_stable_for_same_input_within_process() {
    let sample = (
        "scope_user",
        hash_ip("198.51.100.7".parse().unwrap()),
        ProtoTag::Secure,
    );
    let first = hash_value(&sample);
    let second = hash_value(&sample);
    assert_eq!(
        first, second,
        "dedup hash must be stable within a process for cache lookups"
    );
}

// Collision resistance: a sweep over 2048 adjacent peer IPs must yield 2048
// distinct dedup keys.
#[test]
fn dedup_hash_resists_simple_collision_bursts_for_peer_ip_space() {
    let mut seen = HashSet::new();

    for octet in 1u16..=2048 {
        let third = ((octet / 256) & 0xff) as u8;
        let fourth = (octet & 0xff) as u8;
        let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third, fourth));
        let key = hash_value(&(
            "scope_user",
            hash_ip(ip),
            ProtoTag::Secure,
            DESYNC_ERROR_CLASS,
        ));
        seen.insert(key);
    }

    assert_eq!(
        seen.len(),
        2048,
        "adversarial peer-IP burst should not collapse dedup keys via trivial collisions"
    );
}

// Seeded fuzz: 8192 random inputs should produce at most one 64-bit collision.
#[test]
fn light_fuzz_dedup_hash_collision_rate_stays_negligible() {
    let mut rng = StdRng::seed_from_u64(0x9E37_79B9_A1B2_C3D4);
    let mut seen = HashSet::new();
    let samples = 8192usize;

    for _ in 0..samples {
        let user_seed: u64 = rng.random();
        let peer_seed: u64 = rng.random();
        let proto = if (peer_seed & 1) == 0 {
            ProtoTag::Secure
        } else {
            ProtoTag::Intermediate
        };
        let key = hash_value(&(user_seed, peer_seed, proto, DESYNC_ERROR_CLASS));
        seen.insert(key);
    }

    let collisions = samples - seen.len();
    assert!(
        collisions <= 1,
        "light fuzz collision count should remain negligible for 64-bit dedup keys"
    );
}

// Stress churn: cap + 8192 same-tick keys → cap emits, one rate-limited
// newcomer emit, and a still-bounded cache.
#[test]
fn stress_desync_dedup_churn_keeps_cache_hard_bounded() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let now = Instant::now();
    let total = DESYNC_DEDUP_MAX_ENTRIES + 8192;

    let mut emitted_count = 0usize;
    for key in 0..total as u64 {
        let emitted = should_emit_full_desync(key, false, now);
        if emitted {
            emitted_count += 1;
        }
    }

    assert_eq!(
        emitted_count,
        DESYNC_DEDUP_MAX_ENTRIES + 1,
        "after capacity is reached, same-tick newcomer churn must be rate-limited"
    );

    let len = DESYNC_DEDUP
        .get()
        .expect("dedup cache must be initialized by stress run")
        .len();
    assert!(
        len <= DESYNC_DEDUP_MAX_ENTRIES,
        "dedup cache must stay bounded under stress churn"
    );
}

// Rate limiting: full-cache newcomers emit once per interval, every interval.
#[test]
fn full_cache_newcomer_emission_is_rate_limited_but_periodic() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
    let base_now = Instant::now();

    for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
        dedup.insert(key, base_now - TokioDuration::from_millis(10));
    }

    // Same-tick newcomer storm: only the first should emit full forensic record.
    let mut burst_emits = 0usize;
    for i in 0..1024u64 {
        if should_emit_full_desync(10_000_000 + i, false, base_now) {
            burst_emits += 1;
        }
    }
    assert_eq!(
        burst_emits, 1,
        "full-cache newcomer burst must be bounded to a single full emit per interval"
    );

    // After each interval elapses, one newcomer may emit again.
    for step in 1..=6u64 {
        let t = base_now + DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL * step as u32;
        assert!(
            should_emit_full_desync(20_000_000 + step, false, t),
            "full-cache newcomer should re-emit once interval has elapsed"
        );
        assert!(
            !should_emit_full_desync(30_000_000 + step, false, t),
            "additional newcomers in the same interval tick must remain suppressed"
        );
    }
}

// Override: the desync_all_full flag must bypass dedup and rate limiting entirely.
#[test]
fn full_cache_mode_override_emits_every_event() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let now = Instant::now();
    for i in 0..10_000u64 {
        assert!(
            should_emit_full_desync(100_000_000 + i, true, now),
            "desync_all_full override must bypass dedup and rate-limit suppression"
        );
    }
}

// End-to-end stats: report_desync_frame_too_large must count every event while
// full forensic emission follows the full-cache rate limit.
#[test]
fn report_desync_stats_follow_rate_limited_full_cache_policy() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
    let base_now = Instant::now();
    for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
        dedup.insert(key, base_now - TokioDuration::from_millis(10));
    }

    let stats = Stats::new();
    let mut state = make_forensics_state();
    state.started_at = base_now;

    for i in 0..128u64 {
        state.peer_hash = 0xABC0_0000_0000_0000u64 ^ i;
        let _ = report_desync_frame_too_large(
            &state,
            ProtoTag::Secure,
            3,
            1024,
            4096,
            Some([0x16, 0x03, 0x03, 0x00]),
            &stats,
        );
    }

    assert_eq!(
        stats.get_desync_total(),
        128,
        "every detected desync must increment total counter"
    );
    assert_eq!(
        stats.get_desync_full_logged(),
        1,
        "same-interval full-cache newcomer storm must allow only one full forensic emit"
    );
    assert_eq!(
        stats.get_desync_suppressed(),
        127,
        "remaining same-interval full-cache newcomer events must be suppressed"
    );

    // After one full interval in real wall clock, a newcomer should emit again.
    thread::sleep(DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL + TokioDuration::from_millis(20));
    state.peer_hash = 0xDEAD_BEEF_DEAD_BEEFu64;
    let _ = report_desync_frame_too_large(
        &state,
        ProtoTag::Secure,
        4,
        1024,
        4097,
        Some([0x16, 0x03, 0x03, 0x01]),
        &stats,
    );

    assert_eq!(
        stats.get_desync_full_logged(),
        2,
        "full forensic emission must recover after rate-limit interval"
    );
}

// Thread-safety: 32 threads × 512 same-interval newcomers must still produce
// exactly one full emit.
#[test]
fn concurrent_full_cache_newcomer_storm_is_single_emit_per_interval() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
    let base_now = Instant::now();
    for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
        dedup.insert(key, base_now - TokioDuration::from_millis(10));
    }

    let emits = Arc::new(AtomicUsize::new(0));
    let mut workers = Vec::new();
    for worker_id in 0..32u64 {
        let emits = Arc::clone(&emits);
        workers.push(thread::spawn(move || {
            for i in 0..512u64 {
                let key = 0x7000_0000_0000_0000u64 ^ (worker_id << 20) ^ i;
                if should_emit_full_desync(key, false, base_now) {
                    emits.fetch_add(1, Ordering::Relaxed);
                }
            }
        }));
    }

    for worker in workers {
        worker.join().expect("worker thread must not panic");
    }

    assert_eq!(
        emits.load(Ordering::Relaxed),
        1,
        "concurrent same-interval full-cache storm must allow only one full forensic emit"
    );
}

// Oracle fuzz: the production gate must match a simple reference model of the
// rate limiter, including the fail-open branch for non-monotonic timestamps.
#[test]
fn light_fuzz_full_cache_rate_limit_oracle_matches_model() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
    let base_now = Instant::now();
    for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
        dedup.insert(key, base_now - TokioDuration::from_millis(10));
    }

    let mut rng = StdRng::seed_from_u64(0xD15EA5E5_F00DBAAD);
    let mut model_last_emit: Option<Instant> = None;

    for i in 0..4096u64 {
        let jitter_ms: u64 = rng.random_range(0..=3000);
        let t = base_now + TokioDuration::from_millis(jitter_ms);
        let key = 0x55AA_0000_0000_0000u64 ^ i ^ rng.random::<u64>();
        let actual = should_emit_full_desync(key, false, t);

        let expected = match model_last_emit {
            None => {
                model_last_emit = Some(t);
                true
            }
            Some(last) => {
                match t.checked_duration_since(last) {
                    Some(elapsed) if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL => {
                        model_last_emit = Some(t);
                        true
                    }
                    Some(_) => false,
                    None => {
                        // Match production fail-open behavior for non-monotonic synthetic input.
                        model_last_emit = Some(t);
                        true
                    }
                }
            }
        };

        assert_eq!(
            actual, expected,
            "full-cache rate-limit gate diverged from reference model under light fuzz"
        );
    }
}

// Poisoned-lock resilience: a poisoned gate mutex must suppress (fail-closed),
// not panic and not fail-open.
#[test]
fn full_cache_gate_lock_poison_is_fail_closed_without_panic() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
    let base_now = Instant::now();
    for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
        dedup.insert(key, base_now - TokioDuration::from_millis(10));
    }

    // Poison the full-cache gate lock intentionally.
    let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None));
    let _ = std::panic::catch_unwind(|| {
        let _lock = gate
            .lock()
            .expect("gate lock must be lockable before poison");
        panic!("intentional gate poison for fail-closed regression");
    });

    let emitted = should_emit_full_desync(0xFACE_0000_0000_0001, false, base_now);
    assert!(
        !emitted,
        "poisoned full-cache gate must fail-closed (suppress) instead of panic or fail-open"
    );
    assert!(
        dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES,
        "dedup cache must remain bounded even when gate lock is poisoned"
    );
}

// Clock regression: a timestamp earlier than the last emit must fail-open once
// and reset the gate, never panic.
#[test]
fn full_cache_non_monotonic_time_emits_and_resets_gate_safely() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
    let base_now = Instant::now();
    for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
        dedup.insert(key, base_now - TokioDuration::from_millis(10));
    }

    // First event seeds the gate.
    assert!(should_emit_full_desync(
        0xABCD_0000_0000_0001,
        false,
        base_now + TokioDuration::from_millis(900)
    ));

    // Synthetic earlier timestamp must not panic; it should fail-open and reset gate.
    assert!(should_emit_full_desync(
        0xABCD_0000_0000_0002,
        false,
        base_now + TokioDuration::from_millis(100)
    ));

    // Same instant again remains suppressed after reset.
    assert!(!should_emit_full_desync(
        0xABCD_0000_0000_0003,
        false,
        base_now + TokioDuration::from_millis(100)
    ));
}

// Eviction accounting: a full fresh cache admits a newcomer by evicting exactly
// one prior key — no mass eviction, no unbounded growth.
#[test]
fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
    let base_now = Instant::now();

    // Fill with fresh entries so stale-pruning does not apply.
    for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
        dedup.insert(key, base_now - TokioDuration::from_millis(10));
    }

    let before_keys: std::collections::HashSet<u64> = dedup.iter().map(|e| *e.key()).collect();

    let newcomer_key = u64::MAX;
    let emitted = should_emit_full_desync(newcomer_key, false, base_now);
    assert!(
        emitted,
        "new entry under full fresh cache must emit after bounded eviction"
    );
    assert!(
        dedup.get(&newcomer_key).is_some(),
        "new key must be inserted after bounded eviction"
    );

    let after_keys: std::collections::HashSet<u64> = dedup.iter().map(|e| *e.key()).collect();
    let removed_count = before_keys.difference(&after_keys).count();
    let added_count = after_keys.difference(&before_keys).count();

    assert_eq!(
        removed_count, 1,
        "full-cache insertion must evict exactly one prior key"
    );
    assert_eq!(
        added_count, 1,
        "full-cache insertion must add exactly one newcomer key"
    );
    assert!(
        dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES,
        "dedup cache must remain hard-bounded after full-cache churn"
    );
}

// Temporal fuzz: inside the dedup window a key stays suppressed; past the
// window at least one sample must re-emit.
#[test]
fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let key = 0xC0DE_CAFE_u64;
    let start = Instant::now();

    assert!(
        should_emit_full_desync(key, false, start),
        "first event for key must emit full forensic record"
    );

    // Deterministic pseudo-random time deltas around dedup window edge.
    let mut s: u64 = 0x1234_5678_9ABC_DEF0;
    for _ in 0..2048 {
        // xorshift step keeps the sequence deterministic across runs.
        s ^= s << 7;
        s ^= s >> 9;
        s ^= s << 8;

        let delta_ms = s % (DESYNC_DEDUP_WINDOW.as_millis() as u64 * 2 + 1);
        let now = start + TokioDuration::from_millis(delta_ms);
        let emitted = should_emit_full_desync(key, false, now);

        if delta_ms < DESYNC_DEDUP_WINDOW.as_millis() as u64 {
            assert!(
                !emitted,
                "events inside dedup window must remain suppressed"
            );
        } else {
            // Once window elapsed for this key, at least one sample should re-emit and refresh.
            if emitted {
                return;
            }
        }
    }

    panic!("expected at least one post-window sample to re-emit forensic record");
}

/// Baseline forensics state used by the relay tests; callers override fields as needed.
fn make_forensics_state() -> RelayForensicsState {
    RelayForensicsState {
        trace_id: 1,
        conn_id: 2,
        user: "test-user".to_string(),
        peer: "127.0.0.1:50000".parse::<std::net::SocketAddr>().unwrap(),
        peer_hash: 3,
        started_at: Instant::now(),
        bytes_c2me: 0,
        bytes_me2c: Arc::new(AtomicU64::new(0)),
        desync_all_full: false,
    }
}

/// Wraps `reader` in a CryptoReader with the all-zero test key/IV.
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where
    R: tokio::io::AsyncRead + Unpin,
{
    let key = [0u8; 32];
    let iv = 0u128;
    CryptoReader::new(reader, AesCtr::new(&key, iv))
}

/// Wraps `writer` in a CryptoWriter with the all-zero test key/IV and an 8 KiB buffer.
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    let key = [0u8; 32];
    let iv = 0u128;
    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}

/// Builds a minimal MePool from GeneralConfig defaults for abort-path tests.
async fn make_me_pool_for_abort_test(stats: Arc<Stats>) -> Arc<MePool> {
    let general = GeneralConfig::default();

    MePool::new(
        None,
        vec![1u8; 32],
        None,
        false,
        None,
        Vec::new(),
        1,
        None,
        12,
        1200,
        HashMap::new(),
        HashMap::new(),
        None,
        NetworkDecision::default(),
        None,
        Arc::new(SecureRandom::new()),
        stats,
        general.me_keepalive_enabled,
        general.me_keepalive_interval_secs,
        general.me_keepalive_jitter_secs,
        general.me_keepalive_payload_random,
        general.rpc_proxy_req_every,
        general.me_warmup_stagger_enabled,
        general.me_warmup_step_delay_ms,
        general.me_warmup_step_jitter_ms,
        general.me_reconnect_max_concurrent_per_dc,
        general.me_reconnect_backoff_base_ms,
        general.me_reconnect_backoff_cap_ms,
        general.me_reconnect_fast_retry_count,
        general.me_single_endpoint_shadow_writers,
        general.me_single_endpoint_outage_mode_enabled,
        general.me_single_endpoint_outage_disable_quarantine,
        general.me_single_endpoint_outage_backoff_min_ms,
        general.me_single_endpoint_outage_backoff_max_ms,
        general.me_single_endpoint_shadow_rotate_every_secs,
        general.me_floor_mode,
        general.me_adaptive_floor_idle_secs,
        general.me_adaptive_floor_min_writers_single_endpoint,
        general.me_adaptive_floor_min_writers_multi_endpoint,
        general.me_adaptive_floor_recover_grace_secs,
        general.me_adaptive_floor_writers_per_core_total,
        general.me_adaptive_floor_cpu_cores_override,
        general.me_adaptive_floor_max_extra_writers_single_per_core,
        general.me_adaptive_floor_max_extra_writers_multi_per_core,
        general.me_adaptive_floor_max_active_writers_per_core,
        general.me_adaptive_floor_max_warm_writers_per_core,
        general.me_adaptive_floor_max_active_writers_global,
        general.me_adaptive_floor_max_warm_writers_global,
        general.hardswap,
        general.me_pool_drain_ttl_secs,
        general.me_instadrain,
        general.me_pool_drain_threshold,
        general.me_pool_drain_soft_evict_enabled,
        general.me_pool_drain_soft_evict_grace_secs,
        general.me_pool_drain_soft_evict_per_writer,
        general.me_pool_drain_soft_evict_budget_per_core,
        general.me_pool_drain_soft_evict_cooldown_ms,
        general.effective_me_pool_force_close_secs(),
        general.me_pool_min_fresh_ratio,
        general.me_hardswap_warmup_delay_min_ms,
        general.me_hardswap_warmup_delay_max_ms,
        general.me_hardswap_warmup_extra_passes,
        general.me_hardswap_warmup_pass_backoff_base_ms,
        general.me_bind_stale_mode,
        general.me_bind_stale_ttl_secs,
        general.me_secret_atomic_snapshot,
        general.me_deterministic_writer_sort,
        MeWriterPickMode::default(),
        general.me_writer_pick_sample_size,
        MeSocksKdfPolicy::default(),
        general.me_writer_cmd_channel_capacity,
        general.me_route_channel_capacity,
        general.me_route_backpressure_base_timeout_ms,
        general.me_route_backpressure_high_timeout_ms,
        general.me_route_backpressure_high_watermark_pct,
        general.me_reader_route_data_wait_ms,
        general.me_health_interval_ms_unhealthy,
        general.me_health_interval_ms_healthy,
        general.me_warn_rate_limit_ms,
        MeRouteNoWriterMode::default(),
        general.me_route_no_writer_wait_ms,
        general.me_route_inline_recovery_attempts,
        general.me_route_inline_recovery_wait_ms,
    )
}

/// Encrypts `plaintext` with the same zero key/IV the test reader uses, so the
/// reader's CTR stream decrypts it back to the original bytes.
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
    let key = [0u8; 32];
    let iv = 0u128;
    let mut cipher = AesCtr::new(&key, iv);
    cipher.encrypt(plaintext)
}

// Timeout path: a peer that never sends the frame header must trip TimedOut.
#[tokio::test]
async fn read_client_payload_times_out_on_header_stall() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("middle relay test lock must be available");
    let (reader, _writer) = duplex(1024);
    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let forensics = make_forensics_state();
    let mut frame_counter = 0;

    let result = read_client_payload(
        &mut crypto_reader,
        ProtoTag::Intermediate,
        1024,
        TokioDuration::from_millis(25),
        &buffer_pool,
        &forensics,
        &mut frame_counter,
        &stats,
    )
    .await;

    assert!(
        matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut),
        "stalled header read must time out"
    );
}

// Timeout path: header arrives but the body never does — must also trip TimedOut.
#[tokio::test]
async fn read_client_payload_times_out_on_payload_stall() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("middle relay test lock must be available");
    let (reader, mut writer) = duplex(1024);
    let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]);
    writer.write_all(&encrypted_len).await.unwrap();

    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let forensics = make_forensics_state();
    let mut frame_counter = 0;

    let result = read_client_payload(
        &mut crypto_reader,
        ProtoTag::Intermediate,
        1024,
        TokioDuration::from_millis(25),
        &buffer_pool,
        &forensics,
        &mut frame_counter,
        &stats,
    )
    .await;

    assert!(
        matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut),
        "stalled payload body read must time out"
    );
}

// Large-frame correctness: a frame several pool-buffer sizes long must be read
// byte-exact with no truncation.
#[tokio::test]
async fn read_client_payload_large_intermediate_frame_is_exact() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("middle relay test lock must be available");

    let (reader, mut writer) = duplex(262_144);
    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let forensics = make_forensics_state();
    let mut frame_counter = 0;

    let payload_len = buffer_pool.buffer_size().saturating_mul(3).max(65_537);
    let mut plaintext = Vec::with_capacity(4 + payload_len);
    plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes());
    plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(31)));

    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();

    let read = read_client_payload(
        &mut crypto_reader,
        ProtoTag::Intermediate,
        payload_len + 16,
        TokioDuration::from_secs(1),
        &buffer_pool,
        &forensics,
        &mut frame_counter,
        &stats,
    )
    .await
    .expect("payload read must succeed")
    .expect("frame must be present");

    let (frame, quickack) = read;
    assert!(!quickack, "quickack flag must be unset");
    assert_eq!(
        frame.len(),
        payload_len,
        "payload size must match wire length"
    );
    for (idx, byte) in frame.iter().enumerate() {
        assert_eq!(*byte, (idx as u8).wrapping_mul(31));
    }
    assert_eq!(frame_counter, 1, "exactly one frame must be counted");
}

// Secure framing: trailing padding bytes after the payload must be stripped.
#[tokio::test]
async fn read_client_payload_secure_strips_tail_padding_bytes() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("middle relay test lock must be available");

    let (reader, mut writer) = duplex(1024);
    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let forensics = make_forensics_state();
    let mut frame_counter = 0;

    let payload = [0x11u8, 0x22, 0x33, 0x44, 0xaa, 0xbb, 0xcc, 0xdd];
    let tail = [0xeeu8, 0xff, 0x99];
    let wire_len = payload.len() + tail.len();

    let mut plaintext = Vec::with_capacity(4 + wire_len);
    plaintext.extend_from_slice(&(wire_len as u32).to_le_bytes());
    plaintext.extend_from_slice(&payload);
    plaintext.extend_from_slice(&tail);
    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();

    let read = read_client_payload(
        &mut crypto_reader,
        ProtoTag::Secure,
        1024,
        TokioDuration::from_secs(1),
        &buffer_pool,
        &forensics,
        &mut frame_counter,
        &stats,
    )
    .await
    .expect("secure payload read must succeed")
    .expect("secure frame must be present");

    let (frame, quickack) = read;
    assert!(!quickack, "quickack flag must be unset");
    assert_eq!(frame.as_ref(), &payload);
    assert_eq!(frame_counter, 1, "one secure frame must be counted");
}

// Fail-closed: a secure wire length below the 4-byte minimum must be rejected.
#[tokio::test]
async fn read_client_payload_secure_rejects_wire_len_below_4() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("middle relay test lock must be available");

    let (reader, mut writer) = duplex(1024);
    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let forensics = make_forensics_state();
    let mut frame_counter = 0;

    let mut plaintext = Vec::with_capacity(7);
    plaintext.extend_from_slice(&3u32.to_le_bytes());
    plaintext.extend_from_slice(&[1u8, 2, 3]);
    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();

    let result = read_client_payload(
        &mut crypto_reader,
        ProtoTag::Secure,
        1024,
        TokioDuration::from_secs(1),
        &buffer_pool,
        &forensics,
        &mut frame_counter,
        &stats,
    )
    .await;

    assert!(
        matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small: 3")),
        "secure wire length below 4 must be fail-closed by the frame-too-small guard"
    );
}

// Zero-length intermediate frames are keepalive noise: skip them and deliver
// the next real frame, counting only one.
#[tokio::test]
async fn read_client_payload_intermediate_skips_zero_len_frame() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("middle relay test lock must be available");

    let (reader, mut writer) = duplex(1024);
    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let forensics = make_forensics_state();
    let mut frame_counter = 0;

    let payload = [7u8, 6, 5, 4, 3, 2, 1, 0];
    let mut plaintext = Vec::with_capacity(4 + 4 + payload.len());
    plaintext.extend_from_slice(&0u32.to_le_bytes());
    plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes());
    plaintext.extend_from_slice(&payload);
    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();

    let read = read_client_payload(
        &mut crypto_reader,
        ProtoTag::Intermediate,
        1024,
        TokioDuration::from_secs(1),
        &buffer_pool,
        &forensics,
        &mut frame_counter,
        &stats,
    )
    .await
    .expect("intermediate payload read must succeed")
    .expect("frame must be present");

    let (frame, quickack) = read;
    assert!(!quickack, "quickack flag must be unset");
    assert_eq!(frame.as_ref(), &payload);
    assert_eq!(frame_counter, 1, "zero-length frame must be skipped");
}

// Abridged extended-length framing: the 0x80 quickack bit on the marker byte
// (0x7f | 0x80 == 0xff) must be propagated with the full-length payload.
#[tokio::test]
async fn read_client_payload_abridged_extended_len_sets_quickack() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("middle relay test lock must be available");

    let (reader, mut writer) = duplex(4096);
    let mut crypto_reader = make_crypto_reader(reader);
    let buffer_pool = Arc::new(BufferPool::new());
    let stats = Stats::new();
    let forensics = make_forensics_state();
    let mut frame_counter = 0;

    let payload_len = 4 * 130;
    let len_words = (payload_len / 4) as u32;
    let mut plaintext = Vec::with_capacity(1 + 3 + payload_len);
    plaintext.push(0xff | 0x80);
    let lw = len_words.to_le_bytes();
    plaintext.extend_from_slice(&lw[..3]);
    plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_add(17)));

    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();

    let read = read_client_payload(
        &mut crypto_reader,
        ProtoTag::Abridged,
        payload_len + 16,
        TokioDuration::from_secs(1),
        &buffer_pool,
        &forensics,
        &mut frame_counter,
        &stats,
    )
    .await
    .expect("abridged payload read must succeed")
    .expect("frame must be present");

    let (frame, quickack) = read;
    assert!(
        quickack,
        "quickack bit must be propagated from abridged header"
    );
    assert_eq!(frame.len(), payload_len);
    assert_eq!(frame_counter, 1, "one abridged frame must be counted");
}

// Pool hygiene: after a frame is emitted and dropped, its buffer must land back
// in the pool rather than draining it.
#[tokio::test]
async fn read_client_payload_returns_buffer_to_pool_after_emit() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("middle relay test lock must be available");

    let pool = Arc::new(BufferPool::with_config(64, 8));
    pool.preallocate(1);
    assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer");

    let (reader, mut writer) = duplex(4096);
    let mut crypto_reader = make_crypto_reader(reader);
    let stats = Stats::new();
    let forensics = make_forensics_state();
    let mut frame_counter = 0;

    // Force growth beyond default pool buffer size to catch ownership-take regressions.
    let payload_len = 257usize;
    let mut plaintext = Vec::with_capacity(4 + payload_len);
    plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes());
    plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13)));

    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();

    let _ = read_client_payload(
        &mut crypto_reader,
        ProtoTag::Intermediate,
        payload_len + 8,
        TokioDuration::from_secs(1),
        &pool,
        &forensics,
        &mut frame_counter,
        &stats,
    )
    .await
    .expect("payload read must succeed")
    .expect("frame must be present");

    assert_eq!(frame_counter, 1);
    let pool_stats = pool.stats();
    assert!(
        pool_stats.pooled >= 1,
        "emitted payload buffer must be returned to pool to avoid pool drain"
    );
}

// Pool ownership: while the emitted frame is alive, its buffer stays checked
// out of the pool; it returns only on frame drop.
#[tokio::test]
async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .expect("middle relay test lock must be available");

    let pool = Arc::new(BufferPool::with_config(64, 2));
    pool.preallocate(1);
    assert_eq!(
        pool.stats().pooled,
        1,
        "one pooled buffer must be available"
    );

    let (reader, mut writer) = duplex(1024);
    let mut crypto_reader = make_crypto_reader(reader);
    let stats = Stats::new();
    let forensics = make_forensics_state();
    let mut frame_counter = 0;

    let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48];
    let mut plaintext = Vec::with_capacity(4 + payload.len());
    plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes());
    plaintext.extend_from_slice(&payload);
    let encrypted = encrypt_for_reader(&plaintext);
    writer.write_all(&encrypted).await.unwrap();

    let (frame, quickack) = read_client_payload(
        &mut crypto_reader,
        ProtoTag::Intermediate,
        1024,
        TokioDuration::from_secs(1),
        &pool,
        &forensics,
        &mut frame_counter,
        &stats,
    )
    .await
    .expect("payload read must succeed")
    .expect("frame must be present");

assert!(!quickack); + assert_eq!(frame.as_ref(), &payload); + assert_eq!( + pool.stats().pooled, + 0, + "buffer must stay checked out while frame payload is alive" + ); + + drop(frame); + assert!( + pool.stats().pooled >= 1, + "buffer must return to pool only after frame drop" + ); +} + +#[tokio::test] +async fn enqueue_c2me_close_unblocks_after_queue_drain() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[0x41]), + flags: 0, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let close_task = + tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + + let first = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .expect("first queued item must be present"); + assert!(matches!(first, C2MeCommand::Data { .. })); + + close_task + .await + .unwrap() + .expect("close enqueue must succeed after drain"); + + let second = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .expect("close command must follow after queue drain"); + assert!(matches!(second, C2MeCommand::Close)); +} + +#[tokio::test] +async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() { + let (tx, rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[0x42]), + flags: 0, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let close_task = + tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + drop(rx); + + let result = timeout(TokioDuration::from_secs(1), close_task) + .await + .expect("close task must finish") + .expect("close task must not panic"); + assert!( + result.is_err(), + "close enqueue must fail cleanly when receiver is dropped under pressure" + ); +} + +#[tokio::test] +async fn process_me_writer_response_ack_obeys_flush_policy() { + let (writer_side, 
_reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + let immediate = process_me_writer_response( + MeResponse::Ack(0x11223344), + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + None, + 0, + &bytes_me2c, + 77, + true, + false, + ) + .await + .expect("ack response must be processed"); + + assert!(matches!( + immediate, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes: 4, + flush_immediately: true, + } + )); + + let delayed = process_me_writer_response( + MeResponse::Ack(0x55667788), + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + None, + 0, + &bytes_me2c, + 77, + false, + false, + ) + .await + .expect("ack response must be processed"); + + assert!(matches!( + delayed, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes: 4, + flush_immediately: false, + } + )); +} + +#[tokio::test] +async fn process_me_writer_response_data_updates_byte_accounting() { + let (writer_side, _reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9]; + let outcome = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload.clone()), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + None, + 0, + &bytes_me2c, + 88, + false, + false, + ) + .await + .expect("data response must be processed"); + + assert!(matches!( + outcome, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes, + flush_immediately: false, + } if bytes == payload.len() + )); + assert_eq!( + bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), + payload.len() as u64, + "ME->C byte 
accounting must increase by emitted payload size" + ); +} + +#[tokio::test] +async fn process_me_writer_response_data_enforces_live_user_quota() { + let (writer_side, mut reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + stats.add_user_octets_from("quota-user", 10); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![1u8, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "quota-user", + Some(12), + 0, + &bytes_me2c, + 89, + false, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "quota-user"), + "ME->client runtime path must terminate when live user quota is crossed" + ); + + let mut raw = [0u8; 1]; + assert!( + timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) + .await + .is_err(), + "quota exhaustion must not write any ciphertext to the client stream" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoot_limit() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + let user = "quota-race-user"; + + let (writer_side_a, _reader_side_a) = duplex(1024); + let (writer_side_b, _reader_side_b) = duplex(1024); + let mut writer_a = make_crypto_writer(writer_side_a); + let mut writer_b = make_crypto_writer(writer_side_b); + let mut frame_buf_a = Vec::new(); + let mut frame_buf_b = Vec::new(); + let rng_a = SecureRandom::new(); + let rng_b = SecureRandom::new(); + + let fut_a = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x11]), + }, + &mut writer_a, + ProtoTag::Intermediate, + &rng_a, + &mut frame_buf_a, + &stats, + user, + Some(1), + 0, + &bytes_me2c, + 91, + false, + false, 
+ ); + let fut_b = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x22]), + }, + &mut writer_b, + ProtoTag::Intermediate, + &rng_b, + &mut frame_buf_b, + &stats, + user, + Some(1), + 0, + &bytes_me2c, + 92, + false, + false, + ); + + let (result_a, result_b) = tokio::join!(fut_a, fut_b); + + assert!( + matches!(result_a, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") + || matches!(result_a, Ok(_)), + "concurrent quota test must complete without panicking" + ); + assert!( + matches!(result_b, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") + || matches!(result_b, Ok(_)), + "concurrent quota test must complete without panicking" + ); + assert!( + stats.get_user_total_octets(user) <= 1, + "same-user concurrent middle-relay responses must not overshoot the configured quota" + ); +} + +#[tokio::test] +async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() + { + let (writer_side, mut reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + stats.add_user_octets_to("partial-quota-user", 3); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![1u8, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "partial-quota-user", + Some(4), + 0, + &bytes_me2c, + 90, + false, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "partial-quota-user"), + "ME->client runtime path must reject oversized payloads before writing" + ); + + let mut raw = [0u8; 1]; + assert!( + timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) + .await + .is_err(), + "oversized payloads must not leak any partial 
ciphertext to the client stream" + ); +} + +#[tokio::test] +async fn middle_relay_abort_midflight_releases_route_gauge() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "abort-middle-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50001".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool, + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_snapshot, + 0xdecafbad, + )); + + let started = tokio::time::timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_me() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await; + assert!( + started.is_ok(), + "middle relay must increment route gauge before abort" + ); + + relay_task.abort(); + let joined = relay_task.await; + assert!( + joined.is_err(), + "aborted middle relay task must return join error" + ); + + tokio::time::sleep(TokioDuration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_me(), + 0, + "route gauge must be released when middle relay task is aborted mid-flight" + ); + + drop(client_side); +} + +#[tokio::test] +async fn 
middle_relay_cutover_midflight_releases_route_gauge() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "cutover-middle-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50003".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool, + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_snapshot, + 0xfeed_beef, + )); + + tokio::time::timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_me() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("middle relay must increment route gauge before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Direct).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task) + .await + .expect("middle relay must terminate after cutover") + .expect("middle relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover should terminate middle relay session" + ); + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == 
ROUTE_SWITCH_ERROR_MSG + ), + "client-visible cutover error must stay generic and avoid route-internal metadata" + ); + + assert_eq!( + stats.get_current_connections_me(), + 0, + "route gauge must be released when middle relay exits on cutover" + ); + + drop(client_side); +} + +async fn run_quota_race_attempt( + stats: &Stats, + bytes_me2c: &AtomicU64, + user: &str, + payload: u8, + conn_id: u64, + barrier: Arc, +) -> Result { + let (writer_side, _reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + barrier.wait().await; + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![payload]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + stats, + user, + Some(1), + 0, + bytes_me2c, + conn_id, + false, + false, + ) + .await +} + +#[tokio::test] +async fn abridged_max_extended_length_fails_closed_without_panic_or_partial_read() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(256); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let plaintext = vec![0x7f, 0xff, 0xff, 0xff]; + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Abridged, + 4096, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!( + result.is_err(), + "oversized abridged length must fail closed" + ); + assert_eq!( + frame_counter, 0, + "oversized frame must not be counted as accepted" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn 
deterministic_quota_race_exactly_one_succeeds_and_one_is_rejected() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + let user = "gap-t04-race-user"; + let barrier = Arc::new(Barrier::new(2)); + + let f1 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x11, 5001, barrier.clone()); + let f2 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x22, 5002, barrier); + + let (r1, r2) = tokio::join!(f1, f2); + + assert!( + matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "first racer must either finish or fail closed on quota" + ); + assert!( + matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "second racer must either finish or fail closed on quota" + ); + assert!( + matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) + || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), + "at least one racer must be quota-rejected" + ); + assert_eq!( + stats.get_user_total_octets(user), + 1, + "same-user race must forward/account exactly one payload byte" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_quota_race_bursts_never_allow_double_success_per_round() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + for round in 0..128u64 { + let user = format!("gap-t04-race-burst-{round}"); + let barrier = Arc::new(Barrier::new(2)); + + let one = run_quota_race_attempt( + &stats, + &bytes_me2c, + &user, + 0x33, + 6000 + round, + barrier.clone(), + ); + let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x44, 7000 + round, barrier); + + let (r1, r2) = tokio::join!(one, two); + assert!( + matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) + && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "round {round}: racers must resolve cleanly without unexpected errors" + ); + assert!( + matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) + || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), + "round {round}: at least one racer must be quota-rejected" + ); + assert_eq!( + stats.get_user_total_octets(&user), + 1, + "round {round}: same-user total octets must remain exactly 1 (single forwarded winner)" + ); + } +} + +#[tokio::test] +async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { + let session_count = 6usize; + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let mut relay_tasks = Vec::with_capacity(session_count); + let mut client_sides = Vec::with_capacity(session_count); + + for idx in 0..session_count { + let (server_side, client_side) = duplex(64 * 1024); + client_sides.push(client_side); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: format!("cutover-storm-middle-user-{idx}"), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + 52000 + idx as u16, + ), + is_tls: false, + }; + + relay_tasks.push(tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + "127.0.0.1:443".parse().unwrap(), + rng.clone(), + route_runtime.subscribe(), + route_snapshot, + 0xB000_0000 + idx as u64, + ))); + } + + tokio::time::timeout(TokioDuration::from_secs(4), async { + loop { + if stats.get_current_connections_me() == session_count as u64 { + break; + } + 
tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("all middle sessions must become active before cutover storm"); + + let route_runtime_flipper = route_runtime.clone(); + let flipper = tokio::spawn(async move { + for step in 0..64u32 { + let mode = if (step & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = route_runtime_flipper.set_mode(mode); + tokio::time::sleep(TokioDuration::from_millis(15)).await; + } + }); + + for relay_task in relay_tasks { + let relay_result = tokio::time::timeout(TokioDuration::from_secs(10), relay_task) + .await + .expect("middle relay task must finish under cutover storm") + .expect("middle relay task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "storm-cutover termination must remain generic for all middle sessions" + ); + } + + flipper.abort(); + let _ = flipper.await; + + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route gauge must return to zero after cutover storm" + ); + + drop(client_sides); +} + +#[tokio::test] +async fn secure_padding_distribution_in_relay_writer() { + timeout(TokioDuration::from_secs(10), async { + let (mut client_side, relay_side) = duplex(512 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024); + let rng = Arc::new(SecureRandom::new()); + let mut frame_buf = Vec::new(); + let mut decryptor = AesCtr::new(&key, iv); + + let mut padding_counts = [0usize; 4]; + let iterations = 180usize; + let payload = vec![0xAAu8; 100]; // 4-byte aligned + + for _ in 0..iterations { + write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("payload write must succeed"); + writer + .flush() + .await + .expect("writer flush must complete so encrypted frame becomes readable"); + + let mut len_buf = [0u8; 4]; 
+ client_side + .read_exact(&mut len_buf) + .await + .expect("must read encrypted secure length"); + let decrypted_len_bytes = decryptor.decrypt(&len_buf); + let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes + .try_into() + .expect("decrypted length must be 4 bytes"); + let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize; + + assert!( + wire_len >= payload.len(), + "wire length must include at least payload bytes" + ); + let padding_len = wire_len - payload.len(); + assert!(padding_len >= 1 && padding_len <= 3); + padding_counts[padding_len] += 1; + + // Drain and decrypt frame bytes so CTR state stays aligned across writes. + let mut trash = vec![0u8; wire_len]; + client_side + .read_exact(&mut trash) + .await + .expect("must read encrypted secure frame body"); + let _ = decryptor.decrypt(&trash); + } + + for p in 1..=3 { + let count = padding_counts[p]; + assert!( + count > iterations / 8, + "padding length {p} is under-represented ({count}/{iterations})" + ); + } + }) + .await + .expect("secure padding distribution test exceeded runtime budget"); +} + +#[tokio::test] +async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() { + let (client_reader_side, client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + + // Create an ME pool. 
+ let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + // ConnRegistry ids are monotonic; reserve one id so we can predict the + // next session conn_id and close it deterministically without relying on + // writer-bound views such as active_conn_ids(). + let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + "127.0.0.1:443".parse().unwrap(), + rng.clone(), + route_runtime.subscribe(), + route_runtime.snapshot(), + 0x1234_5678, + )); + + // Wait until session startup is visible, then unregister the predicted + // conn_id to close the per-session ME response channel. 
+ timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("ME session must start before channel close simulation"); + + me_pool.registry().unregister(target_conn_id).await; + + drop(client_writer_side); + + let result = timeout(TokioDuration::from_secs(2), session_task) + .await + .expect("Session task must terminate after ME drop and client EOF") + .expect("Session task must not panic"); + + assert!( + result.is_ok(), + "Session should complete cleanly after ME drop when client closes, got: {:?}", + result + ); +} + +#[tokio::test] +async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + // Predict the next conn_id so we can force-drop its ME channel deterministically. 
+ let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: "test-user-cutover".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let runtime_clone = route_runtime.clone(); + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + runtime_clone.subscribe(), + runtime_clone.snapshot(), + 0xC001_CAFE, + )); + + timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("ME session must start before race trigger"); + + // Race ME channel drop with route cutover and assert generic client-visible outcome. 
+ me_pool.registry().unregister(target_conn_id).await; + assert!( + route_runtime.set_mode(RelayRouteMode::Direct).is_some(), + "cutover must advance generation" + ); + + let relay_result = timeout(TokioDuration::from_secs(6), session_task) + .await + .expect("session must terminate under ME-drop + cutover race") + .expect("session task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "race outcome must remain generic and not leak ME internals, got: {:?}", + relay_result + ); +} + +#[tokio::test] +async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + for round in 0..32u64 { + let (client_reader_side, client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + + let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: format!("stress-me-drop-eof-{round}"), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config, + buffer_pool, + 
"127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_runtime.snapshot(), + 0xD00D_0000 + round, + )); + + timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("session must start before forced drop in burst round"); + + me_pool.registry().unregister(target_conn_id).await; + drop(client_writer_side); + + let result = timeout(TokioDuration::from_secs(2), session_task) + .await + .expect("burst round session must terminate quickly") + .expect("burst round session must not panic"); + + assert!( + result.is_ok(), + "burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}", + result + ); + } +} diff --git a/src/proxy/tests/middle_relay_stub_completion_security_tests.rs b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs new file mode 100644 index 0000000..2635a28 --- /dev/null +++ b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs @@ -0,0 +1,168 @@ +use super::*; +use crate::stream::BufferPool; +use std::collections::HashSet; +use std::sync::Arc; +use tokio::time::{Duration as TokioDuration, timeout}; + +fn make_pooled_payload(data: &[u8]) -> PooledBuffer { + let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + +#[test] +#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"] +fn should_emit_full_desync_filters_duplicates() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let key = 0x4D04_0000_0000_0001_u64; + let base = Instant::now(); + + assert!( + should_emit_full_desync(key, false, base), + "first occurrence must emit full forensic 
record" + ); + assert!( + !should_emit_full_desync(key, false, base), + "duplicate at same timestamp must be suppressed" + ); + + let within_window = base + DESYNC_DEDUP_WINDOW - TokioDuration::from_millis(1); + assert!( + !should_emit_full_desync(key, false, within_window), + "duplicate strictly inside dedup window must stay suppressed" + ); + + let on_window_edge = base + DESYNC_DEDUP_WINDOW; + assert!( + should_emit_full_desync(key, false, on_window_edge), + "duplicate at window boundary must re-emit and refresh" + ); +} + +#[test] +#[ignore = "Tracking for M-04: Verify desync dedup eviction behaves correctly under map-full condition"] +fn desync_dedup_eviction_under_map_full_condition() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let base = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + assert!( + should_emit_full_desync(key, false, base), + "unique key should be inserted while warming dedup cache" + ); + } + + let dedup = DESYNC_DEDUP + .get() + .expect("dedup map must exist after warm-up insertions"); + assert_eq!( + dedup.len(), + DESYNC_DEDUP_MAX_ENTRIES, + "cache warm-up must reach exact hard cap" + ); + + let before_keys: HashSet = dedup.iter().map(|entry| *entry.key()).collect(); + let newcomer_key = 0x4D04_FFFF_FFFF_0001_u64; + + assert!( + should_emit_full_desync(newcomer_key, false, base), + "first newcomer at map-full must emit under bounded full-cache gate" + ); + + let after_keys: HashSet = dedup.iter().map(|entry| *entry.key()).collect(); + assert_eq!( + dedup.len(), + DESYNC_DEDUP_MAX_ENTRIES, + "map-full insertion must preserve hard capacity bound" + ); + assert!( + after_keys.contains(&newcomer_key), + "newcomer must be present after bounded eviction path" + ); + + let removed_count = before_keys.difference(&after_keys).count(); + let added_count = after_keys.difference(&before_keys).count(); + assert_eq!( + removed_count, 
1, + "map-full insertion must evict exactly one prior key" + ); + assert_eq!( + added_count, 1, + "map-full insertion must add exactly one newcomer key" + ); + + assert!( + !should_emit_full_desync(newcomer_key, false, base), + "immediate duplicate newcomer must remain suppressed" + ); +} + +#[tokio::test] +#[ignore = "Tracking for M-05: Verify C2ME channel full path yields then sends under backpressure"] +async fn c2me_channel_full_path_yields_then_sends() { + let (tx, mut rx) = mpsc::channel::(1); + + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[0xAA]), + flags: 1, + }) + .await + .expect("priming queue with one frame must succeed"); + + let tx2 = tx.clone(); + let producer = tokio::spawn(async move { + enqueue_c2me_command( + &tx2, + C2MeCommand::Data { + payload: make_pooled_payload(&[0xBB, 0xCC]), + flags: 2, + }, + ) + .await + }); + + tokio::task::yield_now().await; + tokio::time::sleep(TokioDuration::from_millis(10)).await; + assert!( + !producer.is_finished(), + "producer should stay pending while queue is full" + ); + + let first = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .expect("receiver should observe primed frame") + .expect("first queued command must exist"); + match first { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[0xAA]); + assert_eq!(flags, 1); + } + C2MeCommand::Close => panic!("unexpected close command as first item"), + } + + producer + .await + .expect("producer task must not panic") + .expect("blocked enqueue must succeed once receiver drains capacity"); + + let second = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .expect("receiver should observe backpressure-resumed frame") + .expect("second queued command must exist"); + match second { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[0xBB, 0xCC]); + assert_eq!(flags, 2); + } + C2MeCommand::Close => panic!("unexpected close command as second item"), + } +} diff --git 
a/src/proxy/tests/relay_adversarial_tests.rs b/src/proxy/tests/relay_adversarial_tests.rs new file mode 100644 index 0000000..14754cd --- /dev/null +++ b/src/proxy/tests/relay_adversarial_tests.rs @@ -0,0 +1,217 @@ +use super::*; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +// ------------------------------------------------------------------ +// Priority 3: Async Relay HOL Blocking Prevention (OWASP ASVS 5.1.5) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn relay_hol_blocking_prevention_regression() { + let stats = Arc::new(Stats::new()); + let user = "hol-user"; + + let (client_peer, relay_client) = duplex(65536); + let (relay_server, server_peer) = duplex(65536); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); + let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 8192, + 8192, + user, + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + let payload_size = 1024 * 10; + let s2c_payload = vec![0x41; payload_size]; + let c2s_payload = vec![0x42; payload_size]; + + let s2c_handle = tokio::spawn(async move { + sp_writer.write_all(&s2c_payload).await.unwrap(); + + let mut total_read = 0; + let mut buf = [0u8; 10]; + while total_read < payload_size { + let n = cp_reader.read(&mut buf).await.unwrap(); + total_read += n; + tokio::time::sleep(Duration::from_millis(100)).await; + } + }); + + let start = Instant::now(); + cp_writer.write_all(&c2s_payload).await.unwrap(); + + let mut server_buf = vec![0u8; payload_size]; + 
sp_reader.read_exact(&mut server_buf).await.unwrap(); + let elapsed = start.elapsed(); + + assert!( + elapsed < Duration::from_millis(1000), + "C->S must not be blocked by slow S->C (HOL blocking): {:?}", + elapsed + ); + assert_eq!(server_buf, c2s_payload); + + s2c_handle.abort(); + relay_task.abort(); +} + +// ------------------------------------------------------------------ +// Priority 3: Data Quota Mid-Session Cutoff (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn relay_quota_mid_session_cutoff() { + let stats = Arc::new(Stats::new()); + let user = "quota-mid-user"; + let quota = 5000; + + let (client_peer, relay_client) = duplex(8192); + let (relay_server, server_peer) = duplex(8192); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + let (mut _cp_reader, mut cp_writer) = tokio::io::split(client_peer); + let (mut sp_reader, _sp_writer) = tokio::io::split(server_peer); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + )); + + // Send 4000 bytes (Ok) + let buf1 = vec![0x42; 4000]; + cp_writer.write_all(&buf1).await.unwrap(); + let mut server_recv = vec![0u8; 4000]; + sp_reader.read_exact(&mut server_recv).await.unwrap(); + + // Send another 2000 bytes (Total 6000 > 5000) + let buf2 = vec![0x42; 2000]; + let _ = cp_writer.write_all(&buf2).await; + + let relay_res = timeout(Duration::from_secs(1), relay_task).await.unwrap(); + + match relay_res { + Ok(Err(ProxyError::DataQuotaExceeded { .. 
})) => { + // Expected + } + other => panic!("Expected DataQuotaExceeded error, got: {:?}", other), + } + + let mut small_buf = [0u8; 1]; + let n = sp_reader.read(&mut small_buf).await.unwrap(); + assert_eq!(n, 0, "Server must see EOF after quota reached"); +} + +#[tokio::test] +async fn relay_chaos_half_close_crossfire_terminates_without_hang() { + let stats = Arc::new(Stats::new()); + + let (mut client_peer, relay_client) = duplex(8192); + let (relay_server, mut server_peer) = duplex(8192); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "half-close-crossfire", + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(b"c2s-pre-half-close").await.unwrap(); + server_peer.write_all(b"s2c-pre-half-close").await.unwrap(); + + client_peer.shutdown().await.unwrap(); + tokio::time::sleep(Duration::from_millis(10)).await; + server_peer.shutdown().await.unwrap(); + + let done = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay must terminate after bilateral half-close") + .expect("relay task must not panic"); + assert!( + done.is_ok(), + "relay must terminate cleanly under half-close crossfire" + ); +} + +#[tokio::test] +#[ignore = "heavy soak; run manually"] +async fn relay_soak_bidirectional_temporal_jitter_5k_rounds() { + let stats = Arc::new(Stats::new()); + + let (mut client_peer, relay_client) = duplex(65536); + let (relay_server, mut server_peer) = duplex(65536); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 4096, + 4096, + "soak-jitter-user", + 
Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + for i in 0..5_000u32 { + let c = [((i as u8).wrapping_mul(13)).wrapping_add(1); 17]; + client_peer.write_all(&c).await.unwrap(); + let mut c_seen = [0u8; 17]; + server_peer.read_exact(&mut c_seen).await.unwrap(); + assert_eq!(c_seen, c); + + let s = [((i as u8).wrapping_mul(7)).wrapping_add(3); 23]; + server_peer.write_all(&s).await.unwrap(); + let mut s_seen = [0u8; 23]; + client_peer.read_exact(&mut s_seen).await.unwrap(); + assert_eq!(s_seen, s); + + if i % 10 == 0 { + tokio::time::sleep(Duration::from_millis((i % 3) as u64)).await; + } + } + + drop(client_peer); + drop(server_peer); + let done = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay must stop after soak peers close") + .expect("relay task must not panic"); + assert!(done.is_ok()); +} diff --git a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs new file mode 100644 index 0000000..080240a --- /dev/null +++ b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs @@ -0,0 +1,461 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +async fn read_available(reader: &mut R, budget: Duration) -> usize { + let start = Instant::now(); + let mut total = 0usize; + let mut buf = [0u8; 256]; + + loop { + let elapsed = start.elapsed(); + if elapsed >= budget { + break; + } + let remaining = budget.saturating_sub(elapsed); + match timeout(remaining, reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + + total +} + +#[tokio::test] +async fn integration_full_duplex_exact_budget_then_hard_cutoff() { + let stats = 
Arc::new(Stats::new()); + let user = "quota-full-duplex-boundary-user"; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + Some(10), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(&[0x10, 0x11, 0x12, 0x13]) + .await + .unwrap(); + let mut c2s = [0u8; 4]; + server_peer.read_exact(&mut c2s).await.unwrap(); + assert_eq!(c2s, [0x10, 0x11, 0x12, 0x13]); + + server_peer + .write_all(&[0x20, 0x21, 0x22, 0x23, 0x24, 0x25]) + .await + .unwrap(); + let mut s2c = [0u8; 6]; + client_peer.read_exact(&mut s2c).await.unwrap(); + assert_eq!(s2c, [0x20, 0x21, 0x22, 0x23, 0x24, 0x25]); + + let _ = client_peer.write_all(&[0x99]).await; + let _ = server_peer.write_all(&[0x88]).await; + + let mut probe_server = [0u8; 1]; + let mut probe_client = [0u8; 1]; + let leaked_to_server = timeout( + Duration::from_millis(120), + server_peer.read(&mut probe_server), + ) + .await; + let leaked_to_client = timeout( + Duration::from_millis(120), + client_peer.read(&mut probe_client), + ) + .await; + + assert!( + !matches!(leaked_to_server, Ok(Ok(n)) if n > 0), + "once quota is exhausted, no extra client byte must be forwarded" + ); + assert!( + !matches!(leaked_to_client, Ok(Ok(n)) if n > 0), + "once quota is exhausted, no extra server byte must be forwarded" + ); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under quota cutoff") + .expect("relay task must not panic"); + + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-full-duplex-boundary-user" + )); + assert!(stats.get_user_total_octets(user) <= 10); +} + 
+#[tokio::test] +async fn negative_preloaded_quota_blocks_both_directions_immediately() { + let stats = Arc::new(Stats::new()); + let user = "quota-preloaded-cutoff-user"; + stats.add_user_octets_from(user, 5); + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 512, + 512, + user, + Arc::clone(&stats), + Some(5), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!( + client_peer.write_all(&[0x41, 0x42]), + server_peer.write_all(&[0x51, 0x52]), + ); + + let leaked_to_server = read_available(&mut server_peer, Duration::from_millis(120)).await; + let leaked_to_client = read_available(&mut client_peer, Duration::from_millis(120)).await; + + assert_eq!( + leaked_to_server, 0, + "preloaded limit must block C->S immediately" + ); + assert_eq!( + leaked_to_client, 0, + "preloaded limit must block S->C immediately" + ); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under preloaded cutoff") + .expect("relay task must not panic"); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. 
}) + )); + assert!(stats.get_user_total_octets(user) <= 5); +} + +#[tokio::test] +async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet() { + let stats = Arc::new(Stats::new()); + let user = "quota-one-race-user"; + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!( + client_peer.write_all(&[0xAA]), + server_peer.write_all(&[0xBB]) + ); + + let mut to_server = [0u8; 1]; + let mut to_client = [0u8; 1]; + + let delivered_server = + match timeout(Duration::from_millis(120), server_peer.read(&mut to_server)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + let delivered_client = + match timeout(Duration::from_millis(120), client_peer.read(&mut to_client)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + + assert!( + delivered_server + delivered_client <= 1, + "quota=1 must not allow >1 forwarded byte across both directions" + ); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under quota=1") + .expect("relay task must not panic"); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. 
}) + )); + assert!(stats.get_user_total_octets(user) <= 1); +} + +#[tokio::test] +async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_global_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-blackhat-jitter-user"; + let quota = 32u64; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + )); + + let mut delivered_to_server = 0usize; + let mut delivered_to_client = 0usize; + + for i in 0..256usize { + if relay.is_finished() { + break; + } + + if (i & 1) == 0 { + let _ = client_peer.write_all(&[(i as u8) ^ 0x5A]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await { + delivered_to_server = delivered_to_server.saturating_add(n); + } + } else { + let _ = server_peer.write_all(&[(i as u8) ^ 0xA5]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await { + delivered_to_client = delivered_to_client.saturating_add(n); + } + } + + tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await; + } + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under black-hat jitter attack") + .expect("relay task must not panic"); + + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. 
}) + )); + assert!( + delivered_to_server + delivered_to_client <= quota as usize, + "combined forwarded bytes must never exceed configured quota" + ); + assert!(stats.get_user_total_octets(user) <= quota); +} + +#[tokio::test] +async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invariants() { + let mut rng = StdRng::seed_from_u64(0xD15C_A11E_F00D_BAAD); + + for case in 0..48u64 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-fuzz-schedule-{case}"); + let quota = rng.random_range(1u64..=32u64); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + Arc::clone(&relay_stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut delivered_total = 0usize; + + for _ in 0..96usize { + if relay.is_finished() { + break; + } + + if rng.random::() { + let _ = client_peer.write_all(&[rng.random::()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(3), server_peer.read(&mut one)).await + { + delivered_total = delivered_total.saturating_add(n); + } + } else { + let _ = server_peer.write_all(&[rng.random::()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(3), client_peer.read(&mut one)).await + { + delivered_total = delivered_total.saturating_add(n); + } + } + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("fuzz relay must terminate") + .expect("fuzz relay task must not panic"); + + assert!( + relay_result.is_ok() + || 
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "relay must either close cleanly or terminate via typed quota error" + ); + assert!( + delivered_total <= quota as usize, + "fuzz case {case}: forwarded bytes must not exceed quota" + ); + assert!( + stats.get_user_total_octets(&user) <= quota, + "fuzz case {case}: accounted bytes must not exceed quota" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-stress-multi-relay-user"; + let quota = 64u64; + + let mut workers = Vec::new(); + + for worker_id in 0..4u8 { + let stats = Arc::clone(&stats); + let user = user.to_string(); + + workers.push(tokio::spawn(async move { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut delivered = 0usize; + + for step in 0..96u8 { + if relay.is_finished() { + break; + } + + if ((step as usize + worker_id as usize) & 1) == 0 { + let _ = client_peer.write_all(&[step ^ 0x3C]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(3), server_peer.read(&mut one)).await + { + delivered = delivered.saturating_add(n); + } + } else { + let _ = server_peer.write_all(&[step ^ 0xC3]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(3), client_peer.read(&mut one)).await + { + delivered = delivered.saturating_add(n); + } + } + + 
tokio::time::sleep(Duration::from_millis( + (((worker_id as u64) + (step as u64)) % 3) + 1, + )) + .await; + } + + drop(client_peer); + drop(server_peer); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("stress relay must terminate") + .expect("stress relay task must not panic"); + + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "stress relay must either close cleanly or terminate via typed quota error" + ); + delivered + })); + } + + let mut delivered_sum = 0usize; + for worker in workers { + delivered_sum = + delivered_sum.saturating_add(worker.await.expect("stress worker must not panic")); + } + + assert!( + stats.get_user_total_octets(user) <= quota, + "global per-user quota must hold under concurrent mixed-direction relay stress" + ); + assert!( + delivered_sum <= quota as usize, + "combined delivered bytes across relays must stay within global quota" + ); +} diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs new file mode 100644 index 0000000..e29e86e --- /dev/null +++ b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs @@ -0,0 +1,438 @@ +use super::*; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use dashmap::DashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::sync::Barrier; +use tokio::time::Instant; + +#[test] +fn quota_lock_same_user_returns_same_arc_instance() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let a = quota_user_lock("quota-lock-same-user"); + let b = quota_user_lock("quota-lock-same-user"); + assert!(Arc::ptr_eq(&a, &b)); +} + +#[test] +fn quota_lock_parallel_same_user_reuses_single_lock() { + let _guard = 
super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let user = "quota-lock-parallel-same"; + let mut handles = Vec::new(); + + for _ in 0..64 { + handles.push(std::thread::spawn(move || quota_user_lock(user))); + } + + let first = handles + .remove(0) + .join() + .expect("thread must return lock handle"); + + for handle in handles { + let got = handle.join().expect("thread must return lock handle"); + assert!(Arc::ptr_eq(&first, &got)); + } +} + +#[test] +fn quota_lock_unique_users_materialize_distinct_entries() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + + map.clear(); + + let base = format!("quota-lock-distinct-{}", std::process::id()); + let users: Vec = (0..(QUOTA_USER_LOCKS_MAX / 2)) + .map(|idx| format!("{base}-{idx}")) + .collect(); + + for user in &users { + let _ = quota_user_lock(user); + } + + for user in &users { + assert!( + map.get(user).is_some(), + "lock cache must contain entry for {user}" + ); + } +} + +#[test] +fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + + map.clear(); + + let base = format!("quota-lock-churn-{}", std::process::id()); + for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) { + let _ = quota_user_lock(&format!("{base}-{idx}")); + } + + assert!( + map.len() <= QUOTA_USER_LOCKS_MAX, + "quota lock cache must stay bounded under unique-user churn" + ); +} + +#[test] +fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let prefix = format!("quota-held-{}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + 
retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "cache must be saturated for overflow check" + ); + + let overflow_user = format!("quota-overflow-{}", std::process::id()); + let overflow_a = quota_user_lock(&overflow_user); + let overflow_b = quota_user_lock(&overflow_user); + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "overflow path must not grow lock cache" + ); + assert!( + map.get(&overflow_user).is_none(), + "overflow user lock must stay outside bounded cache under saturation" + ); + assert!( + Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user must receive stable striped overflow lock while saturated" + ); + + drop(retained); +} + +#[test] +fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + // Saturate with retained strong references first so parallel tests cannot + // reclaim our fixture entries before we validate the reclaim path. 
+ let prefix = format!("quota-reclaim-drop-{}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + drop(retained); + + let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id()); + let overflow = quota_user_lock(&overflow_user); + + assert!( + map.get(&overflow_user).is_some(), + "after reclaiming stale entries, overflow user should become cacheable" + ); + assert!( + Arc::strong_count(&overflow) >= 2, + "cacheable overflow lock should be held by both map and caller" + ); +} + +#[test] +fn quota_lock_saturated_same_user_must_not_return_distinct_locks() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!( + "quota-saturated-held-{}-{idx}", + std::process::id() + ))); + } + + let overflow_user = format!("quota-saturated-same-user-{}", std::process::id()); + let a = quota_user_lock(&overflow_user); + let b = quota_user_lock(&overflow_user); + + assert!( + Arc::ptr_eq(&a, &b), + "same user must not receive distinct locks under saturation because that enables quota race bypass" + ); + + drop(retained); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!( + "quota-saturated-race-held-{}-{idx}", + std::process::id() + ))); + } + + let stats = Arc::new(Stats::new()); + let user = format!("quota-saturated-race-user-{}", std::process::id()); + let 
gate = Arc::new(Barrier::new(2)); + + let worker = |label: u8, stats: Arc, user: String, gate: Arc| { + tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + Some(1), + quota_exceeded, + Instant::now(), + ); + gate.wait().await; + io.write_all(&[label]).await + }) + }; + + let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); + let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); + + let _ = tokio::time::timeout(Duration::from_secs(2), async { + let _ = one.await.expect("task one must not panic"); + let _ = two.await.expect("task two must not panic"); + }) + .await + .expect("quota race workers must complete"); + + assert!( + stats.get_user_total_octets(&user) <= 1, + "saturated lock path must never overshoot quota for same user" + ); + + drop(retained); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!( + "quota-saturated-stress-held-{}-{idx}", + std::process::id() + ))); + } + + for round in 0..128u32 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id()); + let gate = Arc::new(Barrier::new(2)); + + let one = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let gate = Arc::clone(&gate); + tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + 
Some(1), + quota_exceeded, + Instant::now(), + ); + gate.wait().await; + io.write_all(&[0x31]).await + }) + }; + + let two = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let gate = Arc::clone(&gate); + tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + Some(1), + quota_exceeded, + Instant::now(), + ); + gate.wait().await; + io.write_all(&[0x32]).await + }) + }; + + let _ = one.await.expect("stress task one must not panic"); + let _ = two.await.expect("stress task two must not panic"); + + assert!( + stats.get_user_total_octets(&user) <= 1, + "round {round}: saturated path must not overshoot quota" + ); + } + + drop(retained); +} + +#[test] +fn quota_error_classifier_accepts_internal_quota_sentinel_only() { + let err = quota_io_error(); + assert!(is_quota_io_error(&err)); +} + +#[test] +fn quota_error_classifier_rejects_plain_permission_denied() { + let err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied"); + assert!(!is_quota_io_error(&err)); +} + +#[test] +fn quota_lock_test_scope_recovers_after_guard_poison() { + let poison_result = std::thread::spawn(|| { + let _guard = super::quota_user_lock_test_scope(); + panic!("intentional test-only guard poison"); + }) + .join(); + assert!(poison_result.is_err(), "poison setup thread must panic"); + + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let a = quota_user_lock("quota-lock-poison-recovery-user"); + let b = quota_user_lock("quota-lock-poison-recovery-user"); + assert!(Arc::ptr_eq(&a, &b)); +} + +#[tokio::test] +async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() { + let stats = Arc::new(Stats::new()); + let user = "quota-zero-user"; + + let (mut client_peer, relay_client) = duplex(2048); + let 
(relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 512, + 512, + user, + Arc::clone(&stats), + Some(0), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(b"x") + .await + .expect("client write must succeed"); + + let mut probe = [0u8; 1]; + let forwarded = + tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await; + if let Ok(Ok(n)) = forwarded { + assert_eq!(n, 0, "zero quota path must not forward payload bytes"); + } + + let result = tokio::time::timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under zero quota") + .expect("relay task must not panic"); + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() { + let stats = Arc::new(Stats::new()); + + let (mut client_peer, relay_client) = duplex(8192); + let (relay_server, mut server_peer) = duplex(8192); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "quota-none-burst-user", + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + let c2s = vec![0xA5; 2048]; + let s2c = vec![0x5A; 1536]; + + client_peer + .write_all(&c2s) + .await + .expect("client burst write must succeed"); + let mut got_c2s = vec![0u8; c2s.len()]; + server_peer + .read_exact(&mut got_c2s) + .await + .expect("server must receive c2s burst"); + assert_eq!(got_c2s, c2s); + + server_peer + .write_all(&s2c) + .await + .expect("server burst write must succeed"); + let 
mut got_s2c = vec![0u8; s2c.len()]; + client_peer + .read_exact(&mut got_s2c) + .await + .expect("client must receive s2c burst"); + assert_eq!(got_s2c, s2c); + + drop(client_peer); + drop(server_peer); + + let done = tokio::time::timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate after peers close") + .expect("relay task must not panic"); + assert!(done.is_ok()); +} diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs new file mode 100644 index 0000000..5714f48 --- /dev/null +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -0,0 +1,315 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::sync::Barrier; +use tokio::time::{Duration, timeout}; + +fn assert_is_prefix(received: &[u8], sent: &[u8], direction: &str) { + assert!( + sent.starts_with(received), + "{direction} stream corruption: received={} sent={} (received must be prefix of sent)", + received.len(), + sent.len() + ); +} + +async fn drain_available(reader: &mut R, out: &mut Vec, rounds: usize) { + for _ in 0..rounds { + let mut buf = [0u8; 64]; + match timeout(Duration::from_millis(2), reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => out.extend_from_slice(&buf[..n]), + Ok(Err(_)) | Err(_) => break, + } + } +} + +#[tokio::test] +async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() { + let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D); + + for case in 0..64u64 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-model-fuzz-{case}"); + let quota = rng.random_range(1u64..=64u64); + + let (mut client_peer, relay_client) = duplex(8192); + let (relay_server, mut server_peer) = duplex(8192); + let 
(client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut sent_c2s = Vec::new(); + let mut sent_s2c = Vec::new(); + let mut recv_at_server = Vec::new(); + let mut recv_at_client = Vec::new(); + + for _ in 0..96usize { + if relay.is_finished() { + break; + } + + let do_c2s = rng.random::(); + let chunk_len = rng.random_range(1usize..=12usize); + let mut chunk = vec![0u8; chunk_len]; + for b in &mut chunk { + *b = rng.random::(); + } + + if do_c2s { + if client_peer.write_all(&chunk).await.is_ok() { + sent_c2s.extend_from_slice(&chunk); + } + } else if server_peer.write_all(&chunk).await.is_ok() { + sent_s2c.extend_from_slice(&chunk); + } + + drain_available(&mut server_peer, &mut recv_at_server, 2).await; + drain_available(&mut client_peer, &mut recv_at_client, 2).await; + + assert_is_prefix(&recv_at_server, &sent_c2s, "C->S"); + assert_is_prefix(&recv_at_client, &sent_s2c, "S->C"); + assert!( + recv_at_server.len() + recv_at_client.len() <= quota as usize, + "fuzz case {case}: delivered bytes exceed quota" + ); + assert!( + stats.get_user_total_octets(&user) <= quota, + "fuzz case {case}: accounted bytes exceed quota" + ); + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("fuzz relay must terminate") + .expect("fuzz relay task must not panic"); + + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
})), + "fuzz case {case}: relay must end cleanly or with typed quota error" + ); + + assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final"); + assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final"); + assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize); + assert!(stats.get_user_total_octets(&user) <= quota); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byte() { + let stats = Arc::new(Stats::new()); + let user = "quota-dual-race-user"; + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let gate = Arc::new(Barrier::new(3)); + + let writer_c2s = { + let gate = Arc::clone(&gate); + tokio::spawn(async move { + gate.wait().await; + let _ = client_peer.write_all(&[0xA1]).await; + client_peer + }) + }; + + let writer_s2c = { + let gate = Arc::clone(&gate); + tokio::spawn(async move { + gate.wait().await; + let _ = server_peer.write_all(&[0xB2]).await; + server_peer + }) + }; + + gate.wait().await; + + let mut client_peer = writer_c2s.await.expect("c2s writer must not panic"); + let mut server_peer = writer_s2c.await.expect("s2c writer must not panic"); + + let mut got_at_server = [0u8; 1]; + let mut got_at_client = [0u8; 1]; + + let n_server = match timeout( + Duration::from_millis(120), + server_peer.read(&mut got_at_server), + ) + .await + { + Ok(Ok(n)) => n, + _ => 0, + }; + let n_client = match timeout( + Duration::from_millis(120), + client_peer.read(&mut got_at_client), + ) + .await + { + Ok(Ok(n)) => n, + _ => 0, + }; + + assert!( 
+ n_server + n_client <= 1, + "quota=1 race must not forward both concurrent direction bytes" + ); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("quota race relay must terminate") + .expect("quota race relay task must not panic"); + + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); + assert!(stats.get_user_total_octets(user) <= 1); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_model_load() { + let stats = Arc::new(Stats::new()); + let user = "quota-model-stress-user"; + let quota = 96u64; + + let mut workers = Vec::new(); + for worker_id in 0..6u64 { + let stats = Arc::clone(&stats); + let user = user.to_string(); + + workers.push(tokio::spawn(async move { + let mut rng = StdRng::seed_from_u64(0x9E37_79B9_7F4A_7C15 ^ worker_id); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 192, + 192, + &relay_user, + relay_stats, + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut sent_c2s = Vec::new(); + let mut sent_s2c = Vec::new(); + let mut recv_at_server = Vec::new(); + let mut recv_at_client = Vec::new(); + + for _ in 0..64usize { + if relay.is_finished() { + break; + } + + let choose_c2s = rng.random::(); + let len = rng.random_range(1usize..=10usize); + let mut payload = vec![0u8; len]; + for b in &mut payload { + *b = rng.random::(); + } + + if choose_c2s { + if client_peer.write_all(&payload).await.is_ok() { + 
sent_c2s.extend_from_slice(&payload); + } + } else if server_peer.write_all(&payload).await.is_ok() { + sent_s2c.extend_from_slice(&payload); + } + + drain_available(&mut server_peer, &mut recv_at_server, 2).await; + drain_available(&mut client_peer, &mut recv_at_client, 2).await; + + assert_is_prefix(&recv_at_server, &sent_c2s, "stress C->S"); + assert_is_prefix(&recv_at_client, &sent_s2c, "stress S->C"); + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("stress relay must terminate") + .expect("stress relay task must not panic"); + + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "stress relay must end cleanly or with typed quota error" + ); + + recv_at_server.len() + recv_at_client.len() + })); + } + + let mut delivered_sum = 0usize; + for worker in workers { + delivered_sum = delivered_sum.saturating_add(worker.await.expect("worker must not panic")); + } + + assert!( + stats.get_user_total_octets(user) <= quota, + "global per-user quota must never overshoot under concurrent multi-relay model load" + ); + assert!( + delivered_sum <= quota as usize, + "aggregate delivered bytes across relays must remain within global quota" + ); +} diff --git a/src/proxy/tests/relay_quota_overflow_regression_tests.rs b/src/proxy/tests/relay_quota_overflow_regression_tests.rs new file mode 100644 index 0000000..dfbab85 --- /dev/null +++ b/src/proxy/tests/relay_quota_overflow_regression_tests.rs @@ -0,0 +1,207 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, timeout}; + +async fn read_available(reader: &mut R, budget_ms: u64) -> usize { + let mut total = 0usize; + loop { + let mut buf = [0u8; 64]; + match timeout(Duration::from_millis(budget_ms), 
reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + total +} + +#[tokio::test] +async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() { + let stats = Arc::new(Stats::new()); + let user = "quota-overflow-regression-client-chunk"; + + // Leave only 1 byte remaining under quota. + stats.add_user_octets_from(user, 9); + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 512, + 512, + user, + Arc::clone(&stats), + Some(10), + Arc::new(BufferPool::new()), + )); + + // Single chunk attempts to cross remaining budget (4 > 1). + client_peer + .write_all(&[0x11, 0x22, 0x33, 0x44]) + .await + .unwrap(); + client_peer.shutdown().await.unwrap(); + + let forwarded = read_available(&mut server_peer, 60).await; + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate after quota overflow attempt") + .expect("relay task must not panic"); + + assert_eq!( + forwarded, 0, + "overflowing C->S chunk must not be forwarded when it exceeds remaining quota" + ); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); + assert!( + stats.get_user_total_octets(user) <= 10, + "accounted bytes must never exceed quota after overflowing chunk" + ); +} + +#[tokio::test] +async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_off() { + let stats = Arc::new(Stats::new()); + let user = "quota-overflow-regression-boundary"; + + // Leave exactly 4 bytes remaining. 
+ stats.add_user_octets_from(user, 6); + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(10), + Arc::new(BufferPool::new()), + )); + + // Exact boundary write should pass once. + client_peer + .write_all(&[0xAA, 0xBB, 0xCC, 0xDD]) + .await + .unwrap(); + + let mut exact = [0u8; 4]; + timeout(Duration::from_secs(1), server_peer.read_exact(&mut exact)) + .await + .unwrap() + .unwrap(); + assert_eq!(exact, [0xAA, 0xBB, 0xCC, 0xDD]); + + // Any extra byte after boundary should be rejected/cut off. + let _ = client_peer.write_all(&[0xEE]).await; + client_peer.shutdown().await.unwrap(); + + let leaked_after = read_available(&mut server_peer, 60).await; + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate at quota boundary") + .expect("relay task must not panic"); + + assert_eq!( + leaked_after, 0, + "no bytes may pass after exact boundary is consumed" + ); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. 
}) + )); + assert!(stats.get_user_total_octets(user) <= 10); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { + let stats = Arc::new(Stats::new()); + let user = "quota-overflow-regression-stress"; + let quota = 12u64; + + let mut handles = Vec::new(); + for _ in 0..4usize { + let stats = Arc::clone(&stats); + let user = user.to_string(); + + handles.push(tokio::spawn(async move { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 192, + 192, + &relay_user, + relay_stats, + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + // Aggressive sender tries to overflow shared user quota. + let burst = vec![0x5Au8; 64]; + let _ = client_peer.write_all(&burst).await; + let _ = client_peer.shutdown().await; + + let mut forwarded = 0usize; + forwarded = forwarded.saturating_add(read_available(&mut server_peer, 40).await); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("stress relay must terminate") + .expect("stress relay task must not panic"); + + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
})), + "stress relay must finish cleanly or with typed quota error" + ); + forwarded + })); + } + + let mut forwarded_sum = 0usize; + for handle in handles { + forwarded_sum = forwarded_sum.saturating_add(handle.await.expect("worker must not panic")); + } + + assert!( + forwarded_sum <= quota as usize, + "aggregate forwarded bytes across relays must stay within global user quota" + ); + assert!( + stats.get_user_total_octets(user) <= quota, + "global accounted bytes must stay within quota under overflow stress" + ); +} diff --git a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs new file mode 100644 index 0000000..9f68258 --- /dev/null +++ b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs @@ -0,0 +1,294 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Barrier; +use tokio::time::{Duration, timeout}; + +fn saturate_lock_cache() -> Vec>> { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-liveness-saturated-{idx}"))); + } + retained +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn positive_writer_progresses_after_contention_release_without_external_wake() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = "quota-liveness-writer-positive"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before write"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + 
counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let writer = tokio::spawn(async move { io.write_all(&[0x11]).await }); + + // Let the initial deferred wake fire while contention is still active. + tokio::time::sleep(Duration::from_millis(4)).await; + + drop(held_guard); + + let completed = timeout(Duration::from_millis(250), writer) + .await + .expect("writer must be re-polled and complete after lock release") + .expect("writer task must not panic"); + assert!(completed.is_ok(), "writer must complete after lock release"); +} + +#[tokio::test] +async fn edge_reader_progresses_after_contention_release_without_external_wake() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = "quota-liveness-reader-edge"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before read"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::empty(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let reader = tokio::spawn(async move { + let mut one = [0u8; 1]; + io.read(&mut one).await + }); + + tokio::time::sleep(Duration::from_millis(4)).await; + drop(held_guard); + + let completed = timeout(Duration::from_millis(250), reader) + .await + .expect("reader must be re-polled and complete after lock release") + .expect("reader task must not panic"); + assert!(completed.is_ok(), "reader must complete after lock release"); +} + +#[tokio::test] +async fn adversarial_early_deferred_wake_consumption_does_not_deadlock_writer() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = "quota-liveness-adversarial"; + let stats = Arc::new(Stats::new()); + + let lock = 
quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before adversarial write"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let writer = tokio::spawn(async move { io.write_all(&[0x22]).await }); + + // Force multiple scheduler rounds while lock remains held so the first + // deferred wake has already been consumed under contention. + for _ in 0..32 { + tokio::task::yield_now().await; + } + + drop(held_guard); + + let completed = timeout(Duration::from_millis(300), writer) + .await + .expect("writer must not stay parked forever after release") + .expect("writer task must not panic"); + assert!(completed.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_parallel_waiters_resume_after_single_release_event() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = format!("quota-liveness-integration-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + let barrier = Arc::new(Barrier::new(13)); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before launching waiters"); + + let mut waiters = Vec::new(); + for _ in 0..12 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let barrier = Arc::clone(&barrier); + waiters.push(tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + stats, + user, + Some(4096), + quota_exceeded, + tokio::time::Instant::now(), + ); + barrier.wait().await; + io.write_all(&[0x33]).await + })); + } + + barrier.wait().await; + 
tokio::time::sleep(Duration::from_millis(4)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for waiter in waiters { + let outcome = waiter.await.expect("waiter must not panic"); + assert!( + outcome.is_ok(), + "waiter must resume and complete after release" + ); + } + }) + .await + .expect("all waiters must complete in bounded time"); +} + +#[tokio::test] +async fn light_fuzz_release_timing_matrix_preserves_liveness() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let stats = Arc::new(Stats::new()); + + let mut seed = 0xD1CE_F00D_0123_4567u64; + for round in 0..64u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let delay_ms = 1 + (seed & 0x7) as u64; + let user = format!("quota-liveness-fuzz-{}-{round}", std::process::id()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock in fuzz round"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + Some(2048), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let writer = tokio::spawn(async move { io.write_all(&[0x44]).await }); + + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + drop(held_guard); + + let done = timeout(Duration::from_millis(300), writer) + .await + .expect("fuzz round writer must complete") + .expect("fuzz writer task must not panic"); + assert!( + done.is_ok(), + "fuzz round writer must not stall after release" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_repeated_contention_cycles_remain_live() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let stats = Arc::new(Stats::new()); + + for cycle in 0..40u32 { + let user = format!("quota-liveness-stress-{}-{cycle}", std::process::id()); + let lock = 
quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold lock before stress cycle"); + + let mut tasks = Vec::new(); + for _ in 0..6 { + let stats = Arc::clone(&stats); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + stats, + user, + Some(2048), + quota_exceeded, + tokio::time::Instant::now(), + ); + io.write_all(&[0x55]).await + })); + } + + tokio::task::yield_now().await; + drop(held_guard); + + timeout(Duration::from_millis(700), async { + for task in tasks { + let outcome = task.await.expect("stress task must not panic"); + assert!(outcome.is_ok(), "stress writer must complete"); + } + }) + .await + .expect("stress cycle must finish in bounded time"); + } +} diff --git a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs new file mode 100644 index 0000000..fa4878a --- /dev/null +++ b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs @@ -0,0 +1,310 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::{AsyncWriteExt, ReadBuf}; +use tokio::time::{Duration, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn saturate_quota_user_locks() -> Vec>> { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 
0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-waker-saturate-{idx}"))); + } + retained +} + +#[tokio::test] +async fn positive_contended_writer_emits_deferred_wake_for_liveness() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-positive-user"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before polling writer"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); + assert!(pending.is_pending()); + + timeout(Duration::from_millis(100), async { + loop { + if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("contended writer must receive deferred wake"); + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); + assert!( + ready.is_ready(), + "writer must progress after contention release" + ); +} + +#[tokio::test] +async fn adversarial_blackhat_writer_contention_does_not_create_waker_storm() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-blackhat-writer"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before polling writer"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = 
StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + for _ in 0..512 { + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xBE]); + assert!( + poll.is_pending(), + "writer must stay pending while lock is held" + ); + tokio::task::yield_now().await; + } + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes <= 128, + "pending writer retries must not trigger wake storm; observed wakes={wakes}" + ); + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xEF]); + assert!(ready.is_ready()); +} + +#[tokio::test] +async fn edge_read_path_contention_keeps_wake_budget_bounded() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-read-edge"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before polling reader"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::empty(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + + for _ in 0..512 { + let mut buf = ReadBuf::new(&mut storage); + let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending()); + tokio::task::yield_now().await; + } + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes <= 128, + "pending reader retries must not trigger wake storm; 
observed wakes={wakes}" + ); + + drop(held_guard); + let mut buf = ReadBuf::new(&mut storage); + let ready = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(ready.is_ready()); +} + +#[tokio::test] +async fn light_fuzz_mixed_poll_schedule_under_contention_stays_bounded() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-fuzz-user"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before fuzz polling"); + + let counters_w = Arc::new(SharedCounters::new()); + let mut writer_io = StatsIo::new( + tokio::io::sink(), + counters_w, + Arc::clone(&stats), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let counters_r = Arc::new(SharedCounters::new()); + let mut reader_io = StatsIo::new( + tokio::io::empty(), + counters_r, + Arc::clone(&stats), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut seed = 0xBADC_0FFE_EE11_2211u64; + let mut storage = [0u8; 1]; + + for _ in 0..1024 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + if (seed & 1) == 0 { + let poll = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x44]); + assert!(poll.is_pending()); + } else { + let mut buf = ReadBuf::new(&mut storage); + let poll = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending()); + } + tokio::task::yield_now().await; + } + + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 192, + "mixed contention fuzz must keep deferred wake count tightly bounded" + ); + + drop(held_guard); + let ready_w = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x55]); + assert!(ready_w.is_ready()); + + let mut 
buf = ReadBuf::new(&mut storage); + let ready_r = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); + assert!(ready_r.is_ready()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[ignore = "red-team detector: reveals possible starvation if deferred wake fires before contention release"] +async fn stress_many_contended_writers_complete_after_release() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-waker-stress-user".to_string(); + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before launching contended tasks"); + + let mut tasks = Vec::new(); + for _ in 0..32 { + let stats = Arc::clone(&stats); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + stats, + user, + Some(2048), + quota_exceeded, + tokio::time::Instant::now(), + ); + + io.write_all(&[0xAA]).await + })); + } + + for _ in 0..8 { + tokio::task::yield_now().await; + } + + drop(held_guard); + + timeout(Duration::from_secs(2), async { + for task in tasks { + let result = task.await.expect("stress task must not panic"); + assert!(result.is_ok(), "task must complete after lock release"); + } + }) + .await + .expect("all contended writer tasks must finish in bounded time after release"); +} diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs new file mode 100644 index 0000000..50cdfa3 --- /dev/null +++ b/src/proxy/tests/relay_security_tests.rs @@ -0,0 +1,1284 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::future::poll_fn; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::Mutex; +use 
std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::Waker; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; +use tokio::time::{Duration, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +#[tokio::test] +async fn quota_lock_contention_does_not_self_wake_pending_writer() { + let _guard = super::quota_user_lock_test_scope(); + let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); + map.clear(); + + let stats = Arc::new(Stats::new()); + let user = "quota-lock-contention-user"; + + let lock = super::quota_user_lock(user); + let _held_lock = lock + .try_lock() + .expect("test must hold the per-user quota lock before polling writer"); + + let counters = Arc::new(super::SharedCounters::new()); + let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut io = super::StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!( + poll.is_pending(), + "writer must remain pending while lock is contended" + ); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "contended quota lock must not self-wake immediately and spin the executor" + ); +} + +#[tokio::test] +async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() { + let _guard = super::quota_user_lock_test_scope(); + let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); + 
map.clear(); + + let stats = Arc::new(Stats::new()); + let user = "quota-lock-writer-liveness-user"; + + let lock = super::quota_user_lock(user); + let held_lock = lock + .try_lock() + .expect("test must hold the per-user quota lock before polling writer"); + + let counters = Arc::new(super::SharedCounters::new()); + let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut io = super::StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!( + first.is_pending(), + "writer must remain pending while lock is contended" + ); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "deferred wake must not fire synchronously" + ); + + timeout(Duration::from_millis(50), async { + loop { + if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("contended writer must schedule a deferred wake in bounded time"); + let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes_after_first_yield >= 1, + "contended writer must schedule at least one deferred wake for liveness" + ); + + let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); + assert!( + second.is_pending(), + "writer remains pending while lock is still held" + ); + + for _ in 0..8 { + tokio::task::yield_now().await; + } + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + wakes_after_first_yield, + "writer contention should not schedule unbounded wake storms before lock acquisition" + ); + + drop(held_lock); + let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); + assert!( + released.is_ready(), + "writer must make progress once 
quota lock is released" + ); +} + +#[tokio::test] +async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() { + let _guard = super::quota_user_lock_test_scope(); + let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); + map.clear(); + + let stats = Arc::new(Stats::new()); + let user = "quota-lock-read-liveness-user"; + + let lock = super::quota_user_lock(user); + let held_lock = lock + .try_lock() + .expect("test must hold the per-user quota lock before polling reader"); + + let counters = Arc::new(super::SharedCounters::new()); + let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut io = super::StatsIo::new( + tokio::io::empty(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + let mut buf = ReadBuf::new(&mut storage); + + let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!( + first.is_pending(), + "reader must remain pending while lock is contended" + ); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "read contention wake must not fire synchronously" + ); + + timeout(Duration::from_millis(50), async { + loop { + if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("read contention must schedule a deferred wake in bounded time"); + + drop(held_lock); + let mut buf_after_release = ReadBuf::new(&mut storage); + let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release); + assert!( + released.is_ready(), + "reader must make progress once quota lock is released" + ); +} + +#[tokio::test] +async fn relay_bidirectional_enforces_live_user_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-user"; + 
stats.add_user_octets_from(user, 6); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + Some(8), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(&[0x10, 0x20, 0x30, 0x40]) + .await + .expect("client write must succeed"); + + let mut forwarded = [0u8; 4]; + let _ = timeout( + Duration::from_millis(200), + server_peer.read_exact(&mut forwarded), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"), + "relay must surface a typed quota error once live quota is exceeded" + ); +} + +#[tokio::test] +async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() { + let stats = Arc::new(Stats::new()); + let quota_user = "quota-exhausted-user"; + stats.add_user_octets_from(quota_user, 1); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + server_peer + .write_all(&[0xde, 0xad, 0xbe, 0xef]) + .await + .expect("server write must succeed"); + + let mut observed = [0u8; 4]; + let forwarded = timeout( + 
Duration::from_millis(200), + client_peer.read_exact(&mut observed), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), + "no full server payload should be forwarded once quota is already exhausted" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must still terminate with a typed quota error" + ); +} + +#[tokio::test] +async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() + { + let stats = Arc::new(Stats::new()); + let quota_user = "partial-leak-user"; + stats.add_user_octets_from(quota_user, 3); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(4), + Arc::new(BufferPool::new()), + )); + + server_peer + .write_all(&[0x11, 0x22, 0x33, 0x44]) + .await + .expect("server write must succeed"); + + let mut observed = [0u8; 8]; + let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n > 0), + "quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + 
"relay must still terminate with a typed quota error" + ); +} + +#[tokio::test] +async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() { + let stats = Arc::new(Stats::new()); + let quota_user = "zero-quota-user"; + + for payload_len in [1usize, 16, 512, 4096] { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(0), + Arc::new(BufferPool::new()), + )); + + let payload = vec![0x7f; payload_len]; + let _ = server_peer.write_all(&payload).await; + + let mut observed = vec![0u8; payload_len]; + let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under zero-quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n > 0), + "zero quota must not forward any server bytes for payload_len={payload_len}" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "zero quota must terminate with the typed quota error for payload_len={payload_len}" + ); + } +} + +#[tokio::test] +async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() { + let stats = Arc::new(Stats::new()); + let quota_user = "exact-boundary-user"; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = 
tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(4), + Arc::new(BufferPool::new()), + )); + + server_peer + .write_all(&[0x91, 0x92, 0x93, 0x94]) + .await + .expect("server write must succeed at exact quota boundary"); + + let mut observed = [0u8; 4]; + client_peer + .read_exact(&mut observed) + .await + .expect("client must receive the full payload at the exact quota boundary"); + assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]); + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish after exact boundary delivery") + .expect("relay task must not panic"); + + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must close with a typed quota error after reaching the exact boundary" + ); +} + +#[tokio::test] +async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() { + let stats = Arc::new(Stats::new()); + let quota_user = "client-exhausted-user"; + stats.add_user_octets_from(quota_user, 1); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(&[0x51, 0x52, 0x53, 0x54]) + .await + .expect("client write must succeed even when quota is already exhausted"); + + let mut observed = [0u8; 4]; + let forwarded = timeout( + Duration::from_millis(200), + server_peer.read_exact(&mut observed), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await 
+ .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), + "client payload must not be fully forwarded once quota is already exhausted" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must still terminate with a typed quota error" + ); +} + +#[tokio::test] +async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() { + let stats = Arc::new(Stats::new()); + let quota_user = "quota-fuzz-user"; + stats.add_user_octets_from(quota_user, 2); + + for payload_len in [1usize, 32, 1024, 8192] { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(2), + Arc::new(BufferPool::new()), + )); + + let payload = vec![0xaa; payload_len]; + let _ = server_peer.write_all(&payload).await; + + let mut observed = vec![0u8; payload_len]; + let forwarded = timeout( + Duration::from_millis(200), + client_peer.read_exact(&mut observed), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n == payload_len), + "quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must keep returning the typed quota error for payload_len={payload_len}" + ); + } +} + +#[tokio::test] +async 
fn relay_bidirectional_terminates_on_activity_timeout() { + tokio::time::pause(); + let stats = Arc::new(Stats::new()); + let user = "timeout-user"; + + let (client_peer, relay_client) = duplex(4096); + let (relay_server, server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + None, // No quota + Arc::new(BufferPool::new()), + )); + + // Wait past the activity timeout threshold (1800 seconds) + buffer + tokio::time::sleep(Duration::from_secs(1805)).await; + + // Resume time to process timeouts + tokio::time::resume(); + + let relay_result = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay task must finish inside bounded timeout due to inactivity cutoff") + .expect("relay task must not panic"); + + assert!( + relay_result.is_ok(), + "relay should complete successfully on scheduled inactivity timeout" + ); + + // Verify client/server sockets are closed + drop(client_peer); + drop(server_peer); +} + +#[tokio::test] +async fn relay_bidirectional_watchdog_resists_premature_execution() { + tokio::time::pause(); + let stats = Arc::new(Stats::new()); + let user = "activity-user"; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let mut relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + // Advance by half the timeout + tokio::time::sleep(Duration::from_secs(900)).await; + + // Provide activity + client_peer + 
.write_all(&[0xaa, 0xbb]) + .await + .expect("client write must succeed"); + client_peer.flush().await.unwrap(); + + // Advance by another half (total time since start is 1800, but since last activity is 900) + tokio::time::sleep(Duration::from_secs(900)).await; + + tokio::time::resume(); + + // Re-evaluating the task, it should NOT have timed out and still be pending + let relay_result = timeout(Duration::from_millis(100), &mut relay_task).await; + assert!( + relay_result.is_err(), + "Relay must not exit prematurely as long as activity was received before timeout" + ); + + // Explicitly drop sockets to cleanly shut down relay loop + drop(client_peer); + drop(server_peer); + + let completion = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay task must complete securely after client disconnection") + .expect("relay task must not panic"); + assert!(completion.is_ok(), "relay exits clean"); +} + +#[tokio::test] +async fn relay_bidirectional_half_closure_terminates_cleanly() { + let stats = Arc::new(Stats::new()); + let (client_peer, relay_client) = duplex(4096); + let (relay_server, server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "half-close", + stats, + None, + Arc::new(BufferPool::new()), + )); + + // Half closure: drop the client completely but leave the server active. + drop(client_peer); + + // Check that we don't immediately crash. Bidirectional relay stays open for the server -> client flush. + // Eventually dropping the server cleanly closes the task. 
+ drop(server_peer); + timeout(Duration::from_secs(1), relay_task) + .await + .unwrap() + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn relay_bidirectional_zero_length_noise_fuzzing() { + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "fuzz", + stats, + None, + Arc::new(BufferPool::new()), + )); + + // Flood with zero-length payloads (edge cases in stream framing logic sometimes loop) + for _ in 0..100 { + client_peer.write_all(&[]).await.unwrap(); + } + client_peer.write_all(&[1, 2, 3]).await.unwrap(); + client_peer.flush().await.unwrap(); + + let mut buf = [0u8; 3]; + server_peer.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, &[1, 2, 3]); + + drop(client_peer); + drop(server_peer); + timeout(Duration::from_secs(1), relay_task) + .await + .unwrap() + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn relay_bidirectional_asymmetric_backpressure() { + let stats = Arc::new(Stats::new()); + // Give the client stream an extremely narrow throughput limit explicitly + let (client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "slowloris", + stats, + None, + Arc::new(BufferPool::new()), + )); + + let payload = vec![0xba; 65536]; // 64k payload + + // Server attempts to shove 64KB into a relay whose client pipe only holds 1KB! 
+ let write_res = + tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await; + + assert!( + write_res.is_err(), + "Relay backpressure MUST halt the server writer from unbounded buffering when client stream is full!" + ); + + drop(client_peer); + drop(server_peer); + + let completion = timeout(Duration::from_secs(1), relay_task) + .await + .unwrap() + .unwrap(); + assert!( + completion.is_ok() || completion.is_err(), + "Task must unwind reliably (either Ok or BrokenPipe Err) when dropped despite active backpressure locks" + ); +} + +use rand::{RngExt, SeedableRng, rngs::StdRng}; + +#[tokio::test] +async fn relay_bidirectional_light_fuzzing_temporal_jitter() { + tokio::time::pause(); + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let mut relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "fuzz-user", + stats, + None, + Arc::new(BufferPool::new()), + )); + + let mut rng = StdRng::seed_from_u64(0xDEADBEEF); + + for _ in 0..10 { + // Vary timing significantly up to 1600 seconds (limit is 1800s) + let jitter = rng.random_range(100..1600); + tokio::time::sleep(Duration::from_secs(jitter)).await; + + client_peer.write_all(&[0x11]).await.unwrap(); + client_peer.flush().await.unwrap(); + + // Ensure task has not died + let res = timeout(Duration::from_millis(10), &mut relay_task).await; + assert!( + res.is_err(), + "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses" + ); + } + + drop(client_peer); + drop(server_peer); + timeout(Duration::from_secs(1), relay_task) + .await + .unwrap() + .unwrap() + .unwrap(); +} + +struct FaultyReader { + error_once: Option, +} + +struct TwoPartyGate { + 
arrivals: AtomicUsize, + total_bytes: AtomicUsize, + wakers: Mutex>, +} + +impl TwoPartyGate { + fn new() -> Self { + Self { + arrivals: AtomicUsize::new(0), + total_bytes: AtomicUsize::new(0), + wakers: Mutex::new(Vec::new()), + } + } + + fn arrive_or_park(&self, cx: &mut Context<'_>) -> bool { + if self.arrivals.load(Ordering::Relaxed) >= 2 { + return true; + } + + let prev = self.arrivals.fetch_add(1, Ordering::AcqRel); + if prev + 1 >= 2 { + let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); + for waker in wakers.drain(..) { + waker.wake(); + } + true + } else { + let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); + wakers.push(cx.waker().clone()); + false + } + } + + fn total_bytes(&self) -> usize { + self.total_bytes.load(Ordering::Relaxed) + } +} + +struct GateWriter { + gate: Arc, + entered: bool, +} + +impl GateWriter { + fn new(gate: Arc) -> Self { + Self { + gate, + entered: false, + } + } +} + +impl AsyncWrite for GateWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if !self.entered { + self.entered = true; + } + + if !self.gate.arrive_or_park(cx) { + return Poll::Pending; + } + + self.gate + .total_bytes + .fetch_add(buf.len(), Ordering::Relaxed); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +struct GateReader { + gate: Arc, + entered: bool, + emitted: bool, +} + +impl GateReader { + fn new(gate: Arc) -> Self { + Self { + gate, + entered: false, + emitted: false, + } + } +} + +impl AsyncRead for GateReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.emitted { + return Poll::Ready(Ok(())); + } + + if !self.entered { + self.entered = true; + } + + if !self.gate.arrive_or_park(cx) { + 
return Poll::Pending; + } + + buf.put_slice(&[0x42]); + self.gate.total_bytes.fetch_add(1, Ordering::Relaxed); + self.emitted = true; + Poll::Ready(Ok(())) + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { + let stats = Arc::new(Stats::new()); + let gate = Arc::new(TwoPartyGate::new()); + let user = "concurrent-quota-write".to_string(); + + let writer_a = super::StatsIo::new( + GateWriter::new(Arc::clone(&gate)), + Arc::new(super::SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1), + Arc::new(std::sync::atomic::AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let writer_b = super::StatsIo::new( + GateWriter::new(Arc::clone(&gate)), + Arc::new(super::SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1), + Arc::new(std::sync::atomic::AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let task_a = tokio::spawn(async move { + let mut w = writer_a; + AsyncWriteExt::write_all(&mut w, &[0x01]).await + }); + let task_b = tokio::spawn(async move { + let mut w = writer_b; + AsyncWriteExt::write_all(&mut w, &[0x02]).await + }); + + let (res_a, res_b) = tokio::join!(task_a, task_b); + let _ = res_a.expect("task a must join"); + let _ = res_b.expect("task b must join"); + + assert!( + gate.total_bytes() <= 1, + "concurrent same-user writes must not forward more than one byte under quota=1" + ); + assert!( + stats.get_user_total_octets(&user) <= 1, + "concurrent same-user writes must not account over limit" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { + let stats = Arc::new(Stats::new()); + let gate = Arc::new(TwoPartyGate::new()); + let user = "concurrent-quota-read".to_string(); + + let reader_a = super::StatsIo::new( + GateReader::new(Arc::clone(&gate)), + Arc::new(super::SharedCounters::new()), + 
Arc::clone(&stats), + user.clone(), + Some(1), + Arc::new(std::sync::atomic::AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let reader_b = super::StatsIo::new( + GateReader::new(Arc::clone(&gate)), + Arc::new(super::SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1), + Arc::new(std::sync::atomic::AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let task_a = tokio::spawn(async move { + let mut r = reader_a; + let mut one = [0u8; 1]; + AsyncReadExt::read_exact(&mut r, &mut one).await + }); + let task_b = tokio::spawn(async move { + let mut r = reader_b; + let mut one = [0u8; 1]; + AsyncReadExt::read_exact(&mut r, &mut one).await + }); + + let (res_a, res_b) = tokio::join!(task_a, task_b); + let _ = res_a.expect("task a must join"); + let _ = res_b.expect("task b must join"); + + assert!( + gate.total_bytes() <= 1, + "concurrent same-user reads must not consume more than one byte under quota=1" + ); + assert!( + stats.get_user_total_octets(&user) <= 1, + "concurrent same-user reads must not account over limit" + ); +} + +#[tokio::test] +async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { + let stats = Arc::new(Stats::new()); + let user = "parallel-quota-user"; + + for _ in 0..128 { + let (mut client_peer_a, relay_client_a) = duplex(256); + let (relay_server_a, mut server_peer_a) = duplex(256); + let (mut client_peer_b, relay_client_b) = duplex(256); + let (relay_server_b, mut server_peer_b) = duplex(256); + + let (client_reader_a, client_writer_a) = tokio::io::split(relay_client_a); + let (server_reader_a, server_writer_a) = tokio::io::split(relay_server_a); + let (client_reader_b, client_writer_b) = tokio::io::split(relay_client_b); + let (server_reader_b, server_writer_b) = tokio::io::split(relay_server_b); + + let relay_a = tokio::spawn(relay_bidirectional( + client_reader_a, + client_writer_a, + server_reader_a, + server_writer_a, + 64, + 64, + user, + Arc::clone(&stats), + Some(1), + 
Arc::new(BufferPool::new()), + )); + + let relay_b = tokio::spawn(relay_bidirectional( + client_reader_b, + client_writer_b, + server_reader_b, + server_writer_b, + 64, + 64, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!( + client_peer_a.write_all(&[0x01]), + server_peer_a.write_all(&[0x02]), + client_peer_b.write_all(&[0x03]), + server_peer_b.write_all(&[0x04]), + ); + + let _ = timeout( + Duration::from_millis(50), + poll_fn(|cx| { + let mut one = [0u8; 1]; + let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one)); + Poll::Ready(()) + }), + ) + .await; + + drop(client_peer_a); + drop(server_peer_a); + drop(client_peer_b); + drop(server_peer_b); + + let _ = timeout(Duration::from_secs(1), relay_a).await; + let _ = timeout(Duration::from_secs(1), relay_b).await; + + assert!( + stats.get_user_total_octets(user) <= 1, + "parallel relays must not exceed configured quota" + ); + } +} + +impl FaultyReader { + fn permission_denied_with_message(message: impl Into) -> Self { + Self { + error_once: Some(io::Error::new( + io::ErrorKind::PermissionDenied, + message.into(), + )), + } + } +} + +impl AsyncRead for FaultyReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(err) = self.error_once.take() { + return Poll::Ready(Err(err)); + } + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn relay_bidirectional_does_not_misclassify_transport_permission_denied_as_quota() { + let stats = Arc::new(Stats::new()); + let (client_peer, relay_client) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + + let relay_result = relay_bidirectional( + client_reader, + client_writer, + FaultyReader::permission_denied_with_message("user data quota exceeded"), + tokio::io::sink(), + 1024, + 1024, + "non-quota-permission-denied", + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + ) + 
.await; + + drop(client_peer); + + assert!( + matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), + "non-quota transport PermissionDenied errors must remain IO errors" + ); +} + +#[tokio::test] +async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_errors() { + let mut rng = StdRng::seed_from_u64(0xA11CE0B5); + + for i in 0..128u64 { + let stats = Arc::new(Stats::new()); + let (client_peer, relay_client) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + + let random_len = rng.random_range(1..=48); + let mut msg = String::with_capacity(random_len); + for _ in 0..random_len { + let ch = (b'a' + (rng.random::() % 26)) as char; + msg.push(ch); + } + // Include the legacy quota string in a subset of fuzz cases to validate + // collision resistance against message-based classification. + if i % 7 == 0 { + msg = "user data quota exceeded".to_string(); + } + + let relay_result = relay_bidirectional( + client_reader, + client_writer, + FaultyReader::permission_denied_with_message(msg), + tokio::io::sink(), + 1024, + 1024, + "fuzz-perm-denied", + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + ) + .await; + + drop(client_peer); + + assert!( + matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), + "transport PermissionDenied case must stay typed as IO regardless of message content" + ); + } +} + +#[tokio::test] +async fn relay_half_close_keeps_reverse_direction_progressing() { + let stats = Arc::new(Stats::new()); + let user = "half-close-user"; + + let (client_peer, relay_client) = duplex(1024); + let (relay_server, server_peer) = duplex(1024); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); + let (mut sp_reader, mut sp_writer) = 
tokio::io::split(server_peer); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 8192, + 8192, + user, + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + sp_writer + .write_all(&[0x10, 0x20, 0x30, 0x40]) + .await + .unwrap(); + sp_writer.shutdown().await.unwrap(); + + let mut inbound = [0u8; 4]; + cp_reader.read_exact(&mut inbound).await.unwrap(); + assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]); + + cp_writer + .write_all(&[0xaa, 0xbb, 0xcc, 0xdd]) + .await + .unwrap(); + let mut outbound = [0u8; 4]; + sp_reader.read_exact(&mut outbound).await.unwrap(); + assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted relay task must return join error"); +} diff --git a/src/proxy/tests/relay_watchdog_delta_security_tests.rs b/src/proxy/tests/relay_watchdog_delta_security_tests.rs new file mode 100644 index 0000000..8b9b209 --- /dev/null +++ b/src/proxy/tests/relay_watchdog_delta_security_tests.rs @@ -0,0 +1,64 @@ +use super::watchdog_delta; + +#[test] +fn positive_monotonic_growth_returns_exact_delta() { + assert_eq!(watchdog_delta(42, 40), 2); + assert_eq!(watchdog_delta(4096, 1024), 3072); +} + +#[test] +fn edge_equal_values_return_zero_delta() { + assert_eq!(watchdog_delta(0, 0), 0); + assert_eq!(watchdog_delta(777, 777), 0); +} + +#[test] +fn adversarial_wrap_like_regression_saturates_to_zero() { + // Simulates a wrapped or reset counter observation where current < previous. + assert_eq!(watchdog_delta(0, 1), 0); + assert_eq!(watchdog_delta(12, 4096), 0); +} + +#[test] +fn adversarial_blackhat_large_previous_value_never_underflows() { + let current = 3u64; + let previous = u64::MAX - 1; + assert_eq!(watchdog_delta(current, previous), 0); +} + +#[test] +fn light_fuzz_mixed_pairs_match_saturating_sub_contract() { + // Deterministic xorshift64* generator for reproducible pseudo-fuzzing. 
+ let mut seed = 0xA51C_ED42_D00D_F00Du64; + + for _ in 0..10_000 { + seed ^= seed >> 12; + seed ^= seed << 25; + seed ^= seed >> 27; + let current = seed.wrapping_mul(0x2545_F491_4F6C_DD1D); + + seed ^= seed >> 12; + seed ^= seed << 25; + seed ^= seed >> 27; + let previous = seed.wrapping_mul(0x2545_F491_4F6C_DD1D); + + let expected = current.saturating_sub(previous); + let actual = watchdog_delta(current, previous); + assert_eq!( + actual, expected, + "delta mismatch for ({current}, {previous})" + ); + } +} + +#[test] +fn stress_long_running_monotonic_sequence_remains_exact() { + let mut prev = 0u64; + + for step in 1u64..=200_000 { + let curr = prev.saturating_add(step & 0x7); + let delta = watchdog_delta(curr, prev); + assert_eq!(delta, curr - prev); + prev = curr; + } +} diff --git a/src/proxy/tests/route_mode_coherence_adversarial_tests.rs b/src/proxy/tests/route_mode_coherence_adversarial_tests.rs new file mode 100644 index 0000000..b7f816e --- /dev/null +++ b/src/proxy/tests/route_mode_coherence_adversarial_tests.rs @@ -0,0 +1,234 @@ +use super::*; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::sync::Arc; + +#[test] +fn positive_direct_cutover_sets_timestamp_and_snapshot_coherently() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let rx = runtime.subscribe(); + + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle startup must not expose direct-since timestamp" + ); + + let emitted = runtime + .set_mode(RelayRouteMode::Direct) + .expect("middle->direct must emit cutover"); + let observed = *rx.borrow(); + + assert_eq!( + observed, emitted, + "watch snapshot must match emitted cutover" + ); + assert_eq!(observed.mode, RelayRouteMode::Direct); + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct cutover must publish a non-empty direct-since timestamp" + ); +} + +#[test] +fn negative_idempotent_set_mode_does_not_mutate_timestamp_or_generation() { + let runtime = 
RouteRuntimeController::new(RelayRouteMode::Direct); + + let before_state = runtime.snapshot(); + let before_ts = runtime.direct_since_epoch_secs(); + + let changed = runtime.set_mode(RelayRouteMode::Direct); + + let after_state = runtime.snapshot(); + let after_ts = runtime.direct_since_epoch_secs(); + + assert!(changed.is_none(), "idempotent set_mode must return None"); + assert_eq!( + after_state.generation, before_state.generation, + "idempotent set_mode must not advance generation" + ); + assert_eq!( + after_ts, before_ts, + "idempotent set_mode must not alter direct-since timestamp" + ); +} + +#[test] +fn edge_middle_cutover_clears_timestamp() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct startup must expose direct-since timestamp" + ); + + let emitted = runtime + .set_mode(RelayRouteMode::Middle) + .expect("direct->middle must emit cutover"); + let observed = *rx.borrow(); + + assert_eq!( + observed, emitted, + "watch snapshot must match emitted cutover" + ); + assert_eq!(observed.mode, RelayRouteMode::Middle); + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle cutover must clear direct-since timestamp" + ); +} + +#[test] +fn adversarial_blackhat_probe_sequence_observes_consistent_mode_timestamp_pairs() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let rx = runtime.subscribe(); + + for _ in 0..2048usize { + let emitted_direct = runtime + .set_mode(RelayRouteMode::Direct) + .expect("middle->direct must emit"); + let observed_direct = *rx.borrow(); + assert_eq!(observed_direct, emitted_direct); + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct observation must never expose empty timestamp" + ); + + let emitted_middle = runtime + .set_mode(RelayRouteMode::Middle) + .expect("direct->middle must emit"); + let observed_middle = *rx.borrow(); + assert_eq!(observed_middle, 
emitted_middle); + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle observation must never expose direct timestamp" + ); + } +} + +#[test] +fn integration_subscriber_and_runtime_gates_stay_coherent_across_cutovers() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let rx = runtime.subscribe(); + + let plan = [ + RelayRouteMode::Direct, + RelayRouteMode::Middle, + RelayRouteMode::Direct, + RelayRouteMode::Middle, + RelayRouteMode::Direct, + ]; + + let mut expected_generation = 0u64; + + for mode in plan { + let emitted = runtime + .set_mode(mode) + .expect("each planned transition toggles mode and must emit"); + expected_generation = expected_generation.saturating_add(1); + + let watched = *rx.borrow(); + let snapshot = runtime.snapshot(); + + assert_eq!(emitted.mode, mode); + assert_eq!(emitted.generation, expected_generation); + assert_eq!(watched, emitted); + assert_eq!(snapshot, emitted); + + if matches!(mode, RelayRouteMode::Direct) { + assert!(runtime.direct_since_epoch_secs().is_some()); + } else { + assert!(runtime.direct_since_epoch_secs().is_none()); + } + } +} + +#[test] +fn light_fuzz_random_mode_plan_preserves_timestamp_and_generation_invariants() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let mut rng = StdRng::seed_from_u64(0x5EED_CAFE_D15C_A11E); + + let mut expected_mode = RelayRouteMode::Middle; + let mut expected_generation = 0u64; + + for _ in 0..25_000usize { + let candidate = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + + let changed = runtime.set_mode(candidate); + if candidate == expected_mode { + assert!(changed.is_none(), "idempotent fuzz step must not emit"); + continue; + } + + expected_mode = candidate; + expected_generation = expected_generation.saturating_add(1); + + let emitted = changed.expect("non-idempotent fuzz step must emit"); + assert_eq!(emitted.mode, expected_mode); + assert_eq!(emitted.generation, 
expected_generation); + + let snapshot = runtime.snapshot(); + assert_eq!(snapshot, emitted, "snapshot must match emitted cutover"); + + if matches!(snapshot.mode, RelayRouteMode::Direct) { + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct fuzz state must expose timestamp" + ); + } else { + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle fuzz state must clear timestamp" + ); + } + } +} + +#[test] +fn stress_parallel_subscribers_never_observe_generation_regression() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + + let mut readers = Vec::new(); + for _ in 0..4usize { + let runtime = Arc::clone(&runtime); + readers.push(std::thread::spawn(move || { + let rx = runtime.subscribe(); + let mut last = rx.borrow().generation; + for _ in 0..10_000usize { + let current = rx.borrow().generation; + assert!( + current >= last, + "watch generation must be monotonic for every subscriber" + ); + last = current; + std::thread::yield_now(); + } + })); + } + + for step in 0..20_000usize { + let mode = if (step & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = runtime.set_mode(mode); + } + + for reader in readers { + reader + .join() + .expect("parallel subscriber reader must not panic"); + } + + let final_state = runtime.snapshot(); + if matches!(final_state.mode, RelayRouteMode::Direct) { + assert!(runtime.direct_since_epoch_secs().is_some()); + } else { + assert!(runtime.direct_since_epoch_secs().is_none()); + } +} diff --git a/src/proxy/tests/route_mode_security_tests.rs b/src/proxy/tests/route_mode_security_tests.rs new file mode 100644 index 0000000..e5925fc --- /dev/null +++ b/src/proxy/tests/route_mode_security_tests.rs @@ -0,0 +1,404 @@ +use super::*; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +#[test] +fn cutover_stagger_delay_is_deterministic_for_same_inputs() { + let d1 = 
cutover_stagger_delay(0x0123_4567_89ab_cdef, 42); + let d2 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42); + assert_eq!( + d1, d2, + "stagger delay must be deterministic for identical session/generation inputs" + ); +} + +#[test] +fn cutover_stagger_delay_stays_within_budget_bounds() { + // Black-hat model: censors trigger many cutovers and correlate disconnect timing. + // Keep delay inside a narrow coarse window to avoid long-tail spikes. + for generation in [0u64, 1, 2, 3, 16, 128, u32::MAX as u64, u64::MAX] { + for session_id in [0u64, 1, 2, 0xdead_beef, 0xfeed_face_cafe_babe, u64::MAX] { + let delay = cutover_stagger_delay(session_id, generation); + assert!( + (1000..=1999).contains(&delay.as_millis()), + "stagger delay must remain in fixed 1000..=1999ms budget" + ); + } + } +} + +#[test] +fn cutover_stagger_delay_changes_with_generation_for_same_session() { + let session_id = 0x0123_4567_89ab_cdef; + let first = cutover_stagger_delay(session_id, 100); + let second = cutover_stagger_delay(session_id, 101); + assert_ne!( + first, second, + "adjacent cutover generations should decorrelate disconnect delays" + ); +} + +#[test] +fn route_runtime_set_mode_is_idempotent_for_same_mode() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let first = runtime.snapshot(); + let changed = runtime.set_mode(RelayRouteMode::Direct); + let second = runtime.snapshot(); + + assert!( + changed.is_none(), + "setting already-active mode must not produce a cutover event" + ); + assert_eq!( + first.generation, second.generation, + "idempotent mode set must not bump generation" + ); +} + +#[test] +fn affected_cutover_state_triggers_only_for_newer_generation() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + let initial = runtime.snapshot(); + + assert!( + affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation).is_none(), + "current generation must not be considered a cutover for existing 
session" + ); + + let next = runtime + .set_mode(RelayRouteMode::Middle) + .expect("mode change must produce cutover state"); + let seen = affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation) + .expect("newer generation must be observed as cutover"); + + assert_eq!(seen.generation, next.generation); + assert_eq!(seen.mode, RelayRouteMode::Middle); +} + +#[test] +fn integration_watch_and_snapshot_follow_same_transition_sequence() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + + let sequence = [ + RelayRouteMode::Middle, + RelayRouteMode::Middle, + RelayRouteMode::Direct, + RelayRouteMode::Direct, + RelayRouteMode::Middle, + ]; + + let mut expected_generation = 0u64; + let mut expected_mode = RelayRouteMode::Direct; + + for target in sequence { + let changed = runtime.set_mode(target); + if target == expected_mode { + assert!(changed.is_none(), "idempotent transition must return none"); + } else { + expected_mode = target; + expected_generation = expected_generation.saturating_add(1); + let emitted = changed.expect("real transition must emit cutover state"); + assert_eq!(emitted.mode, expected_mode); + assert_eq!(emitted.generation, expected_generation); + } + + let snap = runtime.snapshot(); + let watched = *rx.borrow(); + assert_eq!(snap, watched, "snapshot and watch state must stay aligned"); + assert_eq!(snap.mode, expected_mode); + assert_eq!(snap.generation, expected_generation); + } +} + +#[test] +fn session_is_not_affected_when_mode_matches_even_if_generation_advanced() { + let session_mode = RelayRouteMode::Direct; + let current = RouteCutoverState { + mode: RelayRouteMode::Direct, + generation: 2, + }; + let session_generation = 0; + + assert!( + !is_session_affected_by_cutover(current, session_mode, session_generation), + "session on matching final route mode should not be force-cut over on intermediate generation bumps" + ); +} + +#[test] +fn 
cutover_predicate_rejects_equal_generation_even_if_mode_differs() { + let current = RouteCutoverState { + mode: RelayRouteMode::Middle, + generation: 77, + }; + assert!( + !is_session_affected_by_cutover(current, RelayRouteMode::Direct, 77), + "equal generation must never trigger cutover regardless of mode mismatch" + ); +} + +#[test] +fn adversarial_route_oscillation_only_cuts_over_sessions_with_different_final_mode() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + let session_generation = runtime.snapshot().generation; + + runtime + .set_mode(RelayRouteMode::Middle) + .expect("direct->middle must transition"); + runtime + .set_mode(RelayRouteMode::Direct) + .expect("middle->direct must transition"); + + assert!( + affected_cutover_state(&rx, RelayRouteMode::Direct, session_generation).is_none(), + "direct session should survive when final mode returns to direct" + ); + assert!( + affected_cutover_state(&rx, RelayRouteMode::Middle, session_generation).is_some(), + "middle session should be cut over when final mode is direct" + ); +} + +#[test] +fn light_fuzz_cutover_predicate_matches_reference_oracle() { + let mut rng = StdRng::seed_from_u64(0xC0DEC0DE5EED); + for _ in 0..20_000 { + let current = RouteCutoverState { + mode: if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }, + generation: rng.random_range(0u64..1_000_000), + }; + let session_mode = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let session_generation = rng.random_range(0u64..1_000_000); + + let expected = current.generation > session_generation && current.mode != session_mode; + let actual = is_session_affected_by_cutover(current, session_mode, session_generation); + assert_eq!( + actual, expected, + "cutover predicate must match mode-aware generation oracle" + ); + } +} + +#[test] +fn light_fuzz_set_mode_generation_tracks_only_real_transitions() { + let runtime = 
RouteRuntimeController::new(RelayRouteMode::Direct); + let mut rng = StdRng::seed_from_u64(0x0DDC0FFE); + + let mut expected_mode = RelayRouteMode::Direct; + let mut expected_generation = 0u64; + + for _ in 0..10_000 { + let candidate = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let changed = runtime.set_mode(candidate); + + if candidate == expected_mode { + assert!( + changed.is_none(), + "idempotent set_mode must not emit cutover state" + ); + } else { + expected_mode = candidate; + expected_generation = expected_generation.saturating_add(1); + let next = changed.expect("mode transition must emit cutover state"); + assert_eq!(next.mode, expected_mode); + assert_eq!(next.generation, expected_generation); + } + } + + let final_state = runtime.snapshot(); + assert_eq!(final_state.mode, expected_mode); + assert_eq!(final_state.generation, expected_generation); +} + +#[test] +fn stress_snapshot_and_watch_state_remain_consistent_under_concurrent_switch_storm() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + std::thread::scope(|scope| { + let mut writers = Vec::new(); + for worker in 0..4usize { + let runtime = Arc::clone(&runtime); + writers.push(scope.spawn(move || { + for step in 0..20_000usize { + let mode = if (worker + step) % 2 == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = runtime.set_mode(mode); + } + })); + } + + for writer in writers { + writer + .join() + .expect("route mode writer thread must not panic"); + } + + let rx = runtime.subscribe(); + for _ in 0..128 { + assert_eq!( + runtime.snapshot(), + *rx.borrow(), + "snapshot and watch state must converge after concurrent set_mode churn" + ); + std::thread::yield_now(); + } + }); +} + +#[test] +fn stress_concurrent_transition_count_matches_final_generation() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let successful_transitions = 
Arc::new(AtomicU64::new(0)); + + std::thread::scope(|scope| { + let mut workers = Vec::new(); + for worker in 0..6usize { + let runtime = Arc::clone(&runtime); + let successful_transitions = Arc::clone(&successful_transitions); + workers.push(scope.spawn(move || { + let mut state = (worker as u64 + 1).wrapping_mul(0x9E37_79B9_7F4A_7C15); + for _ in 0..25_000usize { + state ^= state << 7; + state ^= state >> 9; + state ^= state << 8; + let mode = if (state & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + if runtime.set_mode(mode).is_some() { + successful_transitions.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker + .join() + .expect("route mode transition worker must not panic"); + } + }); + + let final_state = runtime.snapshot(); + assert_eq!( + final_state.generation, + successful_transitions.load(Ordering::Relaxed), + "final generation must equal number of accepted mode transitions" + ); + assert_eq!( + final_state, + *runtime.subscribe().borrow(), + "watch and snapshot state must match after concurrent transition accounting" + ); +} + +#[test] +fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() { + // Deterministic xorshift fuzzing keeps this test stable across runs. 
+ let mut s: u64 = 0x9E37_79B9_7F4A_7C15; + + for _ in 0..20_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let session_id = s; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let generation = s; + + let delay = cutover_stagger_delay(session_id, generation); + assert!( + (1000..=1999).contains(&delay.as_millis()), + "fuzzed inputs must always map into fixed stagger window" + ); + } +} + +#[test] +fn cutover_stagger_delay_distribution_has_no_empty_buckets_under_sequential_sessions() { + let mut buckets = [0usize; 1000]; + let generation = 4242u64; + + for session_id in 0..250_000u64 { + let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize; + let idx = delay_ms - 1000; + buckets[idx] += 1; + } + + let empty = buckets.iter().filter(|&&count| count == 0).count(); + assert_eq!( + empty, 0, + "all 1000 delay buckets must be exercised to avoid cutover herd clustering" + ); +} + +#[test] +fn light_fuzz_cutover_stagger_delay_distribution_stays_reasonably_uniform() { + let mut buckets = [0usize; 1000]; + let mut s: u64 = 0x1BAD_B002_CAFE_F00D; + + for _ in 0..300_000usize { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let session_id = s; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let generation = s; + + let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize; + buckets[delay_ms - 1000] += 1; + } + + let min = *buckets.iter().min().unwrap_or(&0); + let max = *buckets.iter().max().unwrap_or(&0); + assert!(min > 0, "fuzzed distribution must not leave empty buckets"); + assert!( + max <= min.saturating_mul(3), + "bucket skew is too high for anti-herd staggering (max={max}, min={min})" + ); +} + +#[test] +fn stress_cutover_stagger_delay_distribution_remains_stable_across_generations() { + for generation in [0u64, 1, 7, 31, 255, 1024, u32::MAX as u64, u64::MAX - 1] { + let mut buckets = [0usize; 1000]; + for session_id in 0..100_000u64 { + let delay_ms = + cutover_stagger_delay(session_id ^ 0x9E37_79B9, 
generation).as_millis() as usize; + buckets[delay_ms - 1000] += 1; + } + + let min = *buckets.iter().min().unwrap_or(&0); + let max = *buckets.iter().max().unwrap_or(&0); + assert!( + max <= min.saturating_mul(4).max(1), + "generation={generation}: distribution collapsed (max={max}, min={min})" + ); + } +} diff --git a/src/startup.rs b/src/startup.rs index f6f857c..36b1506 100644 --- a/src/startup.rs +++ b/src/startup.rs @@ -175,7 +175,11 @@ impl StartupTracker { pub async fn start_component(&self, id: &'static str, details: Option) { let mut guard = self.state.write().await; guard.current_stage = id.to_string(); - if let Some(component) = guard.components.iter_mut().find(|component| component.id == id) { + if let Some(component) = guard + .components + .iter_mut() + .find(|component| component.id == id) + { if component.started_at_epoch_ms.is_none() { component.started_at_epoch_ms = Some(now_epoch_ms()); } @@ -208,7 +212,11 @@ impl StartupTracker { ) { let mut guard = self.state.write().await; let finished_at = now_epoch_ms(); - if let Some(component) = guard.components.iter_mut().find(|component| component.id == id) { + if let Some(component) = guard + .components + .iter_mut() + .find(|component| component.id == id) + { if component.started_at_epoch_ms.is_none() { component.started_at_epoch_ms = Some(finished_at); component.attempts = component.attempts.saturating_add(1); diff --git a/src/stats/beobachten.rs b/src/stats/beobachten.rs index 2e87fcc..3d3a2da 100644 --- a/src/stats/beobachten.rs +++ b/src/stats/beobachten.rs @@ -110,8 +110,8 @@ impl BeobachtenStore { } fn cleanup(inner: &mut BeobachtenInner, now: Instant, ttl: Duration) { - inner.entries.retain(|_, entry| { - now.saturating_duration_since(entry.last_seen) <= ttl - }); + inner + .entries + .retain(|_, entry| now.saturating_duration_since(entry.last_seen) <= ttl); } } diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 0df4dc0..d13d834 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -5,149 
+5,81 @@ pub mod beobachten; pub mod telemetry; +use dashmap::DashMap; +use lru::LruCache; +use parking_lot::Mutex; +use std::collections::VecDeque; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::num::NonZeroUsize; +use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; -use dashmap::DashMap; -use parking_lot::Mutex; -use lru::LruCache; -use std::num::NonZeroUsize; -use std::hash::{Hash, Hasher}; -use std::collections::hash_map::DefaultHasher; -use std::collections::VecDeque; use tracing::debug; -use crate::config::{MeTelemetryLevel, MeWriterPickMode}; use self::telemetry::TelemetryPolicy; +use crate::config::{MeTelemetryLevel, MeWriterPickMode}; -const ME_WRITER_TEARDOWN_MODE_COUNT: usize = 2; -const ME_WRITER_TEARDOWN_REASON_COUNT: usize = 11; -const ME_WRITER_CLEANUP_SIDE_EFFECT_STEP_COUNT: usize = 2; -const ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT: usize = 12; -const ME_WRITER_TEARDOWN_DURATION_BUCKET_BOUNDS_MICROS: [u64; ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT] = [ - 1_000, - 5_000, - 10_000, - 25_000, - 50_000, - 100_000, - 250_000, - 500_000, - 1_000_000, - 2_500_000, - 5_000_000, - 10_000_000, -]; -const ME_WRITER_TEARDOWN_DURATION_BUCKET_LABELS: [&str; ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT] = [ - "0.001", - "0.005", - "0.01", - "0.025", - "0.05", - "0.1", - "0.25", - "0.5", - "1", - "2.5", - "5", - "10", -]; - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -#[repr(u8)] -pub enum MeWriterTeardownMode { - Normal = 0, - HardDetach = 1, +#[derive(Clone, Copy)] +enum RouteConnectionGauge { + Direct, + Middle, } -impl MeWriterTeardownMode { - pub const ALL: [Self; ME_WRITER_TEARDOWN_MODE_COUNT] = - [Self::Normal, Self::HardDetach]; +#[derive(Debug, Clone, Copy)] +pub enum MeD2cFlushReason { + QueueDrain, + BatchFrames, + BatchBytes, + MaxDelay, + AckImmediate, + Close, +} - pub const fn as_str(self) -> &'static str { - 
match self { - Self::Normal => "normal", - Self::HardDetach => "hard_detach", +#[derive(Debug, Clone, Copy)] +pub enum MeD2cWriteMode { + Coalesced, + Split, +} + +#[derive(Debug, Clone, Copy)] +pub enum MeD2cQuotaRejectStage { + PreWrite, + PostWrite, +} + +#[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"] +pub struct RouteConnectionLease { + stats: Arc, + gauge: RouteConnectionGauge, + active: bool, +} + +impl RouteConnectionLease { + fn new(stats: Arc, gauge: RouteConnectionGauge) -> Self { + Self { + stats, + gauge, + active: true, } } - const fn idx(self) -> usize { - self as usize + #[cfg(test)] + fn disarm(&mut self) { + self.active = false; } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -#[repr(u8)] -pub enum MeWriterTeardownReason { - ReaderExit = 0, - WriterTaskExit = 1, - PingSendFail = 2, - SignalSendFail = 3, - RouteChannelClosed = 4, - CloseRpcChannelClosed = 5, - PruneClosedWriter = 6, - ReapTimeoutExpired = 7, - ReapThresholdForce = 8, - ReapEmpty = 9, - WatchdogStuckDraining = 10, -} - -impl MeWriterTeardownReason { - pub const ALL: [Self; ME_WRITER_TEARDOWN_REASON_COUNT] = [ - Self::ReaderExit, - Self::WriterTaskExit, - Self::PingSendFail, - Self::SignalSendFail, - Self::RouteChannelClosed, - Self::CloseRpcChannelClosed, - Self::PruneClosedWriter, - Self::ReapTimeoutExpired, - Self::ReapThresholdForce, - Self::ReapEmpty, - Self::WatchdogStuckDraining, - ]; - - pub const fn as_str(self) -> &'static str { - match self { - Self::ReaderExit => "reader_exit", - Self::WriterTaskExit => "writer_task_exit", - Self::PingSendFail => "ping_send_fail", - Self::SignalSendFail => "signal_send_fail", - Self::RouteChannelClosed => "route_channel_closed", - Self::CloseRpcChannelClosed => "close_rpc_channel_closed", - Self::PruneClosedWriter => "prune_closed_writer", - Self::ReapTimeoutExpired => "reap_timeout_expired", - Self::ReapThresholdForce => "reap_threshold_force", - Self::ReapEmpty => "reap_empty", - 
Self::WatchdogStuckDraining => "watchdog_stuck_draining", +impl Drop for RouteConnectionLease { + fn drop(&mut self) { + if !self.active { + return; } - } - - const fn idx(self) -> usize { - self as usize - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -#[repr(u8)] -pub enum MeWriterCleanupSideEffectStep { - CloseSignalChannelFull = 0, - CloseSignalChannelClosed = 1, -} - -impl MeWriterCleanupSideEffectStep { - pub const ALL: [Self; ME_WRITER_CLEANUP_SIDE_EFFECT_STEP_COUNT] = - [Self::CloseSignalChannelFull, Self::CloseSignalChannelClosed]; - - pub const fn as_str(self) -> &'static str { - match self { - Self::CloseSignalChannelFull => "close_signal_channel_full", - Self::CloseSignalChannelClosed => "close_signal_channel_closed", + match self.gauge { + RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(), + RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(), } } - - const fn idx(self) -> usize { - self as usize - } } // ============= Stats ============= @@ -189,6 +121,10 @@ pub struct Stats { me_handshake_reject_total: AtomicU64, me_reader_eof_total: AtomicU64, me_idle_close_by_peer_total: AtomicU64, + relay_idle_soft_mark_total: AtomicU64, + relay_idle_hard_close_total: AtomicU64, + relay_pressure_evict_total: AtomicU64, + relay_protocol_desync_close_total: AtomicU64, me_crc_mismatch: AtomicU64, me_seq_mismatch: AtomicU64, me_endpoint_quarantine_total: AtomicU64, @@ -226,6 +162,44 @@ pub struct Stats { me_route_drop_queue_full: AtomicU64, me_route_drop_queue_full_base: AtomicU64, me_route_drop_queue_full_high: AtomicU64, + me_d2c_batches_total: AtomicU64, + me_d2c_batch_frames_total: AtomicU64, + me_d2c_batch_bytes_total: AtomicU64, + me_d2c_flush_reason_queue_drain_total: AtomicU64, + me_d2c_flush_reason_batch_frames_total: AtomicU64, + me_d2c_flush_reason_batch_bytes_total: AtomicU64, + me_d2c_flush_reason_max_delay_total: AtomicU64, + me_d2c_flush_reason_ack_immediate_total: AtomicU64, + 
me_d2c_flush_reason_close_total: AtomicU64, + me_d2c_data_frames_total: AtomicU64, + me_d2c_ack_frames_total: AtomicU64, + me_d2c_payload_bytes_total: AtomicU64, + me_d2c_write_mode_coalesced_total: AtomicU64, + me_d2c_write_mode_split_total: AtomicU64, + me_d2c_quota_reject_pre_write_total: AtomicU64, + me_d2c_quota_reject_post_write_total: AtomicU64, + me_d2c_frame_buf_shrink_total: AtomicU64, + me_d2c_frame_buf_shrink_bytes_total: AtomicU64, + me_d2c_batch_frames_bucket_1: AtomicU64, + me_d2c_batch_frames_bucket_2_4: AtomicU64, + me_d2c_batch_frames_bucket_5_8: AtomicU64, + me_d2c_batch_frames_bucket_9_16: AtomicU64, + me_d2c_batch_frames_bucket_17_32: AtomicU64, + me_d2c_batch_frames_bucket_gt_32: AtomicU64, + me_d2c_batch_bytes_bucket_0_1k: AtomicU64, + me_d2c_batch_bytes_bucket_1k_4k: AtomicU64, + me_d2c_batch_bytes_bucket_4k_16k: AtomicU64, + me_d2c_batch_bytes_bucket_16k_64k: AtomicU64, + me_d2c_batch_bytes_bucket_64k_128k: AtomicU64, + me_d2c_batch_bytes_bucket_gt_128k: AtomicU64, + me_d2c_flush_duration_us_bucket_0_50: AtomicU64, + me_d2c_flush_duration_us_bucket_51_200: AtomicU64, + me_d2c_flush_duration_us_bucket_201_1000: AtomicU64, + me_d2c_flush_duration_us_bucket_1001_5000: AtomicU64, + me_d2c_flush_duration_us_bucket_5001_20000: AtomicU64, + me_d2c_flush_duration_us_bucket_gt_20000: AtomicU64, + me_d2c_batch_timeout_armed_total: AtomicU64, + me_d2c_batch_timeout_fired_total: AtomicU64, me_writer_pick_sorted_rr_success_try_total: AtomicU64, me_writer_pick_sorted_rr_success_fallback_total: AtomicU64, me_writer_pick_sorted_rr_full_total: AtomicU64, @@ -251,26 +225,9 @@ pub struct Stats { pool_swap_total: AtomicU64, pool_drain_active: AtomicU64, pool_force_close_total: AtomicU64, - pool_drain_soft_evict_total: AtomicU64, - pool_drain_soft_evict_writer_total: AtomicU64, pool_stale_pick_total: AtomicU64, - me_writer_close_signal_drop_total: AtomicU64, - me_writer_close_signal_channel_full_total: AtomicU64, - me_draining_writers_reap_progress_total: 
AtomicU64, me_writer_removed_total: AtomicU64, me_writer_removed_unexpected_total: AtomicU64, - me_writer_teardown_attempt_total: - [[AtomicU64; ME_WRITER_TEARDOWN_MODE_COUNT]; ME_WRITER_TEARDOWN_REASON_COUNT], - me_writer_teardown_success_total: [AtomicU64; ME_WRITER_TEARDOWN_MODE_COUNT], - me_writer_teardown_timeout_total: AtomicU64, - me_writer_teardown_escalation_total: AtomicU64, - me_writer_teardown_noop_total: AtomicU64, - me_writer_cleanup_side_effect_failures_total: - [AtomicU64; ME_WRITER_CLEANUP_SIDE_EFFECT_STEP_COUNT], - me_writer_teardown_duration_bucket_hits: - [[AtomicU64; ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT + 1]; ME_WRITER_TEARDOWN_MODE_COUNT], - me_writer_teardown_duration_sum_micros: [AtomicU64; ME_WRITER_TEARDOWN_MODE_COUNT], - me_writer_teardown_duration_count: [AtomicU64; ME_WRITER_TEARDOWN_MODE_COUNT], me_refill_triggered_total: AtomicU64, me_refill_skipped_inflight_total: AtomicU64, me_refill_failed_total: AtomicU64, @@ -281,11 +238,6 @@ pub struct Stats { me_inline_recovery_total: AtomicU64, ip_reservation_rollback_tcp_limit_total: AtomicU64, ip_reservation_rollback_quota_limit_total: AtomicU64, - relay_adaptive_promotions_total: AtomicU64, - relay_adaptive_demotions_total: AtomicU64, - relay_adaptive_hard_promotions_total: AtomicU64, - reconnect_evict_total: AtomicU64, - reconnect_stale_close_total: AtomicU64, telemetry_core_enabled: AtomicBool, telemetry_user_enabled: AtomicBool, telemetry_me_level: AtomicU8, @@ -372,8 +324,7 @@ impl Stats { let last_cleanup_epoch_secs = self .user_stats_last_cleanup_epoch_secs .load(Ordering::Relaxed); - if now_epoch_secs.saturating_sub(last_cleanup_epoch_secs) - < USER_STATS_CLEANUP_INTERVAL_SECS + if now_epoch_secs.saturating_sub(last_cleanup_epoch_secs) < USER_STATS_CLEANUP_INTERVAL_SECS { return; } @@ -415,7 +366,7 @@ impl Stats { me_level: self.telemetry_me_level(), } } - + pub fn increment_connects_all(&self) { if self.telemetry_core_enabled() { self.connects_all.fetch_add(1, 
Ordering::Relaxed); @@ -427,7 +378,8 @@ impl Stats { } } pub fn increment_current_connections_direct(&self) { - self.current_connections_direct.fetch_add(1, Ordering::Relaxed); + self.current_connections_direct + .fetch_add(1, Ordering::Relaxed); } pub fn decrement_current_connections_direct(&self) { Self::decrement_atomic_saturating(&self.current_connections_direct); @@ -438,35 +390,15 @@ impl Stats { pub fn decrement_current_connections_me(&self) { Self::decrement_atomic_saturating(&self.current_connections_me); } - pub fn increment_relay_adaptive_promotions_total(&self) { - if self.telemetry_core_enabled() { - self.relay_adaptive_promotions_total - .fetch_add(1, Ordering::Relaxed); - } + + pub fn acquire_direct_connection_lease(self: &Arc<Self>) -> RouteConnectionLease { + self.increment_current_connections_direct(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct) } - pub fn increment_relay_adaptive_demotions_total(&self) { - if self.telemetry_core_enabled() { - self.relay_adaptive_demotions_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_relay_adaptive_hard_promotions_total(&self) { - if self.telemetry_core_enabled() { - self.relay_adaptive_hard_promotions_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_reconnect_evict_total(&self) { - if self.telemetry_core_enabled() { - self.reconnect_evict_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_reconnect_stale_close_total(&self) { - if self.telemetry_core_enabled() { - self.reconnect_stale_close_total - .fetch_add(1, Ordering::Relaxed); - } + + pub fn acquire_me_connection_lease(self: &Arc<Self>) -> RouteConnectionLease { + self.increment_current_connections_me(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle) } pub fn increment_handshake_timeouts(&self) { if self.telemetry_core_enabled() { @@ -588,7 +520,8 @@ impl Stats { } pub fn increment_me_keepalive_timeout_by(&self, value: u64) { if
self.telemetry_me_allows_normal() { - self.me_keepalive_timeout.fetch_add(value, Ordering::Relaxed); + self.me_keepalive_timeout + .fetch_add(value, Ordering::Relaxed); } } pub fn increment_me_rpc_proxy_req_signal_sent_total(&self) { @@ -633,7 +566,8 @@ impl Stats { } pub fn increment_me_handshake_reject_total(&self) { if self.telemetry_me_allows_normal() { - self.me_handshake_reject_total.fetch_add(1, Ordering::Relaxed); + self.me_handshake_reject_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_handshake_error_code(&self, code: i32) { @@ -657,6 +591,30 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn increment_relay_idle_soft_mark_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_idle_soft_mark_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_idle_hard_close_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_idle_hard_close_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_pressure_evict_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_pressure_evict_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_protocol_desync_close_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_protocol_desync_close_total + .fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_crc_mismatch(&self) { if self.telemetry_me_allows_normal() { self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); @@ -674,22 +632,235 @@ impl Stats { } pub fn increment_me_route_drop_channel_closed(&self) { if self.telemetry_me_allows_normal() { - self.me_route_drop_channel_closed.fetch_add(1, Ordering::Relaxed); + self.me_route_drop_channel_closed + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_route_drop_queue_full(&self) { if self.telemetry_me_allows_normal() { - self.me_route_drop_queue_full.fetch_add(1, Ordering::Relaxed); + self.me_route_drop_queue_full + .fetch_add(1, Ordering::Relaxed); } } pub fn 
increment_me_route_drop_queue_full_base(&self) { if self.telemetry_me_allows_normal() { - self.me_route_drop_queue_full_base.fetch_add(1, Ordering::Relaxed); + self.me_route_drop_queue_full_base + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_route_drop_queue_full_high(&self) { if self.telemetry_me_allows_normal() { - self.me_route_drop_queue_full_high.fetch_add(1, Ordering::Relaxed); + self.me_route_drop_queue_full_high + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_batches_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_d2c_batches_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn add_me_d2c_batch_frames_total(&self, frames: u64) { + if self.telemetry_me_allows_normal() { + self.me_d2c_batch_frames_total + .fetch_add(frames, Ordering::Relaxed); + } + } + pub fn add_me_d2c_batch_bytes_total(&self, bytes: u64) { + if self.telemetry_me_allows_normal() { + self.me_d2c_batch_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_flush_reason(&self, reason: MeD2cFlushReason) { + if !self.telemetry_me_allows_normal() { + return; + } + match reason { + MeD2cFlushReason::QueueDrain => { + self.me_d2c_flush_reason_queue_drain_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::BatchFrames => { + self.me_d2c_flush_reason_batch_frames_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::BatchBytes => { + self.me_d2c_flush_reason_batch_bytes_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::MaxDelay => { + self.me_d2c_flush_reason_max_delay_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::AckImmediate => { + self.me_d2c_flush_reason_ack_immediate_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::Close => { + self.me_d2c_flush_reason_close_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_d2c_data_frames_total(&self) { + if self.telemetry_me_allows_normal() { + 
self.me_d2c_data_frames_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_ack_frames_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_d2c_ack_frames_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn add_me_d2c_payload_bytes_total(&self, bytes: u64) { + if self.telemetry_me_allows_normal() { + self.me_d2c_payload_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_write_mode(&self, mode: MeD2cWriteMode) { + if !self.telemetry_me_allows_normal() { + return; + } + match mode { + MeD2cWriteMode::Coalesced => { + self.me_d2c_write_mode_coalesced_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cWriteMode::Split => { + self.me_d2c_write_mode_split_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_d2c_quota_reject_total(&self, stage: MeD2cQuotaRejectStage) { + if !self.telemetry_me_allows_normal() { + return; + } + match stage { + MeD2cQuotaRejectStage::PreWrite => { + self.me_d2c_quota_reject_pre_write_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cQuotaRejectStage::PostWrite => { + self.me_d2c_quota_reject_post_write_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn observe_me_d2c_frame_buf_shrink(&self, bytes_freed: u64) { + if !self.telemetry_me_allows_normal() { + return; + } + self.me_d2c_frame_buf_shrink_total + .fetch_add(1, Ordering::Relaxed); + self.me_d2c_frame_buf_shrink_bytes_total + .fetch_add(bytes_freed, Ordering::Relaxed); + } + pub fn observe_me_d2c_batch_frames(&self, frames: u64) { + if !self.telemetry_me_allows_debug() { + return; + } + match frames { + 0 => {} + 1 => { + self.me_d2c_batch_frames_bucket_1 + .fetch_add(1, Ordering::Relaxed); + } + 2..=4 => { + self.me_d2c_batch_frames_bucket_2_4 + .fetch_add(1, Ordering::Relaxed); + } + 5..=8 => { + self.me_d2c_batch_frames_bucket_5_8 + .fetch_add(1, Ordering::Relaxed); + } + 9..=16 => { + self.me_d2c_batch_frames_bucket_9_16 + .fetch_add(1, Ordering::Relaxed); + } + 17..=32 
=> { + self.me_d2c_batch_frames_bucket_17_32 + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.me_d2c_batch_frames_bucket_gt_32 + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn observe_me_d2c_batch_bytes(&self, bytes: u64) { + if !self.telemetry_me_allows_debug() { + return; + } + match bytes { + 0..=1024 => { + self.me_d2c_batch_bytes_bucket_0_1k + .fetch_add(1, Ordering::Relaxed); + } + 1025..=4096 => { + self.me_d2c_batch_bytes_bucket_1k_4k + .fetch_add(1, Ordering::Relaxed); + } + 4097..=16_384 => { + self.me_d2c_batch_bytes_bucket_4k_16k + .fetch_add(1, Ordering::Relaxed); + } + 16_385..=65_536 => { + self.me_d2c_batch_bytes_bucket_16k_64k + .fetch_add(1, Ordering::Relaxed); + } + 65_537..=131_072 => { + self.me_d2c_batch_bytes_bucket_64k_128k + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.me_d2c_batch_bytes_bucket_gt_128k + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn observe_me_d2c_flush_duration_us(&self, duration_us: u64) { + if !self.telemetry_me_allows_debug() { + return; + } + match duration_us { + 0..=50 => { + self.me_d2c_flush_duration_us_bucket_0_50 + .fetch_add(1, Ordering::Relaxed); + } + 51..=200 => { + self.me_d2c_flush_duration_us_bucket_51_200 + .fetch_add(1, Ordering::Relaxed); + } + 201..=1000 => { + self.me_d2c_flush_duration_us_bucket_201_1000 + .fetch_add(1, Ordering::Relaxed); + } + 1001..=5000 => { + self.me_d2c_flush_duration_us_bucket_1001_5000 + .fetch_add(1, Ordering::Relaxed); + } + 5001..=20_000 => { + self.me_d2c_flush_duration_us_bucket_5001_20000 + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.me_d2c_flush_duration_us_bucket_gt_20000 + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_d2c_batch_timeout_armed_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_d2c_batch_timeout_armed_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_batch_timeout_fired_total(&self) { + if self.telemetry_me_allows_debug() { + 
self.me_d2c_batch_timeout_fired_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_writer_pick_success_try_total(&self, mode: MeWriterPickMode) { @@ -781,12 +952,14 @@ impl Stats { } pub fn increment_me_socks_kdf_strict_reject(&self) { if self.telemetry_me_allows_normal() { - self.me_socks_kdf_strict_reject.fetch_add(1, Ordering::Relaxed); + self.me_socks_kdf_strict_reject + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_socks_kdf_compat_fallback(&self) { if self.telemetry_me_allows_debug() { - self.me_socks_kdf_compat_fallback.fetch_add(1, Ordering::Relaxed); + self.me_socks_kdf_compat_fallback + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_secure_padding_invalid(&self) { @@ -818,13 +991,16 @@ impl Stats { self.desync_frames_bucket_0.fetch_add(1, Ordering::Relaxed); } 1..=2 => { - self.desync_frames_bucket_1_2.fetch_add(1, Ordering::Relaxed); + self.desync_frames_bucket_1_2 + .fetch_add(1, Ordering::Relaxed); } 3..=10 => { - self.desync_frames_bucket_3_10.fetch_add(1, Ordering::Relaxed); + self.desync_frames_bucket_3_10 + .fetch_add(1, Ordering::Relaxed); } _ => { - self.desync_frames_bucket_gt_10.fetch_add(1, Ordering::Relaxed); + self.desync_frames_bucket_gt_10 + .fetch_add(1, Ordering::Relaxed); } } } @@ -863,41 +1039,11 @@ impl Stats { self.pool_force_close_total.fetch_add(1, Ordering::Relaxed); } } - pub fn increment_pool_drain_soft_evict_total(&self) { - if self.telemetry_me_allows_normal() { - self.pool_drain_soft_evict_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_pool_drain_soft_evict_writer_total(&self) { - if self.telemetry_me_allows_normal() { - self.pool_drain_soft_evict_writer_total - .fetch_add(1, Ordering::Relaxed); - } - } pub fn increment_pool_stale_pick_total(&self) { if self.telemetry_me_allows_normal() { self.pool_stale_pick_total.fetch_add(1, Ordering::Relaxed); } } - pub fn increment_me_writer_close_signal_drop_total(&self) { - if self.telemetry_me_allows_normal() { - 
self.me_writer_close_signal_drop_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_close_signal_channel_full_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_close_signal_channel_full_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_draining_writers_reap_progress_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_draining_writers_reap_progress_total - .fetch_add(1, Ordering::Relaxed); - } - } pub fn increment_me_writer_removed_total(&self) { if self.telemetry_me_allows_debug() { self.me_writer_removed_total.fetch_add(1, Ordering::Relaxed); @@ -905,85 +1051,20 @@ impl Stats { } pub fn increment_me_writer_removed_unexpected_total(&self) { if self.telemetry_me_allows_normal() { - self.me_writer_removed_unexpected_total.fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_teardown_attempt_total( - &self, - reason: MeWriterTeardownReason, - mode: MeWriterTeardownMode, - ) { - if self.telemetry_me_allows_normal() { - self.me_writer_teardown_attempt_total[reason.idx()][mode.idx()] + self.me_writer_removed_unexpected_total .fetch_add(1, Ordering::Relaxed); } } - pub fn increment_me_writer_teardown_success_total(&self, mode: MeWriterTeardownMode) { - if self.telemetry_me_allows_normal() { - self.me_writer_teardown_success_total[mode.idx()].fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_teardown_timeout_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_teardown_timeout_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_teardown_escalation_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_teardown_escalation_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_teardown_noop_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_teardown_noop_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn 
increment_me_writer_cleanup_side_effect_failures_total( - &self, - step: MeWriterCleanupSideEffectStep, - ) { - if self.telemetry_me_allows_normal() { - self.me_writer_cleanup_side_effect_failures_total[step.idx()] - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn observe_me_writer_teardown_duration( - &self, - mode: MeWriterTeardownMode, - duration: Duration, - ) { - if !self.telemetry_me_allows_normal() { - return; - } - let duration_micros = duration.as_micros().min(u64::MAX as u128) as u64; - let mut bucket_idx = ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT; - for (idx, upper_bound_micros) in ME_WRITER_TEARDOWN_DURATION_BUCKET_BOUNDS_MICROS - .iter() - .copied() - .enumerate() - { - if duration_micros <= upper_bound_micros { - bucket_idx = idx; - break; - } - } - self.me_writer_teardown_duration_bucket_hits[mode.idx()][bucket_idx] - .fetch_add(1, Ordering::Relaxed); - self.me_writer_teardown_duration_sum_micros[mode.idx()] - .fetch_add(duration_micros, Ordering::Relaxed); - self.me_writer_teardown_duration_count[mode.idx()].fetch_add(1, Ordering::Relaxed); - } pub fn increment_me_refill_triggered_total(&self) { if self.telemetry_me_allows_debug() { - self.me_refill_triggered_total.fetch_add(1, Ordering::Relaxed); + self.me_refill_triggered_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_refill_skipped_inflight_total(&self) { if self.telemetry_me_allows_debug() { - self.me_refill_skipped_inflight_total.fetch_add(1, Ordering::Relaxed); + self.me_refill_skipped_inflight_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_refill_failed_total(&self) { @@ -1005,7 +1086,8 @@ impl Stats { } pub fn increment_me_no_writer_failfast_total(&self) { if self.telemetry_me_allows_normal() { - self.me_no_writer_failfast_total.fetch_add(1, Ordering::Relaxed); + self.me_no_writer_failfast_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_async_recovery_trigger_total(&self) { @@ -1016,7 +1098,8 @@ impl Stats { } pub fn 
increment_me_inline_recovery_total(&self) { if self.telemetry_me_allows_normal() { - self.me_inline_recovery_total.fetch_add(1, Ordering::Relaxed); + self.me_inline_recovery_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_ip_reservation_rollback_tcp_limit_total(&self) { @@ -1188,12 +1271,14 @@ impl Stats { } pub fn increment_me_floor_cap_block_total(&self) { if self.telemetry_me_allows_normal() { - self.me_floor_cap_block_total.fetch_add(1, Ordering::Relaxed); + self.me_floor_cap_block_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_floor_swap_idle_total(&self) { if self.telemetry_me_allows_normal() { - self.me_floor_swap_idle_total.fetch_add(1, Ordering::Relaxed); + self.me_floor_swap_idle_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_floor_swap_idle_failed_total(&self) { @@ -1202,8 +1287,12 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } - pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } - pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } + pub fn get_connects_all(&self) -> u64 { + self.connects_all.load(Ordering::Relaxed) + } + pub fn get_connects_bad(&self) -> u64 { + self.connects_bad.load(Ordering::Relaxed) + } pub fn get_current_connections_direct(&self) -> u64 { self.current_connections_direct.load(Ordering::Relaxed) } @@ -1214,26 +1303,18 @@ impl Stats { self.get_current_connections_direct() .saturating_add(self.get_current_connections_me()) } - pub fn get_relay_adaptive_promotions_total(&self) -> u64 { - self.relay_adaptive_promotions_total.load(Ordering::Relaxed) + pub fn get_me_keepalive_sent(&self) -> u64 { + self.me_keepalive_sent.load(Ordering::Relaxed) } - pub fn get_relay_adaptive_demotions_total(&self) -> u64 { - self.relay_adaptive_demotions_total.load(Ordering::Relaxed) + pub fn get_me_keepalive_failed(&self) -> u64 { + self.me_keepalive_failed.load(Ordering::Relaxed) } - pub fn 
get_relay_adaptive_hard_promotions_total(&self) -> u64 { - self.relay_adaptive_hard_promotions_total - .load(Ordering::Relaxed) + pub fn get_me_keepalive_pong(&self) -> u64 { + self.me_keepalive_pong.load(Ordering::Relaxed) } - pub fn get_reconnect_evict_total(&self) -> u64 { - self.reconnect_evict_total.load(Ordering::Relaxed) + pub fn get_me_keepalive_timeout(&self) -> u64 { + self.me_keepalive_timeout.load(Ordering::Relaxed) } - pub fn get_reconnect_stale_close_total(&self) -> u64 { - self.reconnect_stale_close_total.load(Ordering::Relaxed) - } - pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) } - pub fn get_me_keepalive_failed(&self) -> u64 { self.me_keepalive_failed.load(Ordering::Relaxed) } - pub fn get_me_keepalive_pong(&self) -> u64 { self.me_keepalive_pong.load(Ordering::Relaxed) } - pub fn get_me_keepalive_timeout(&self) -> u64 { self.me_keepalive_timeout.load(Ordering::Relaxed) } pub fn get_me_rpc_proxy_req_signal_sent_total(&self) -> u64 { self.me_rpc_proxy_req_signal_sent_total .load(Ordering::Relaxed) @@ -1254,8 +1335,12 @@ impl Stats { self.me_rpc_proxy_req_signal_close_sent_total .load(Ordering::Relaxed) } - pub fn get_me_reconnect_attempts(&self) -> u64 { self.me_reconnect_attempts.load(Ordering::Relaxed) } - pub fn get_me_reconnect_success(&self) -> u64 { self.me_reconnect_success.load(Ordering::Relaxed) } + pub fn get_me_reconnect_attempts(&self) -> u64 { + self.me_reconnect_attempts.load(Ordering::Relaxed) + } + pub fn get_me_reconnect_success(&self) -> u64 { + self.me_reconnect_success.load(Ordering::Relaxed) + } pub fn get_me_handshake_reject_total(&self) -> u64 { self.me_handshake_reject_total.load(Ordering::Relaxed) } @@ -1265,8 +1350,25 @@ impl Stats { pub fn get_me_idle_close_by_peer_total(&self) -> u64 { self.me_idle_close_by_peer_total.load(Ordering::Relaxed) } - pub fn get_me_crc_mismatch(&self) -> u64 { self.me_crc_mismatch.load(Ordering::Relaxed) } - pub fn get_me_seq_mismatch(&self) -> u64 
{ self.me_seq_mismatch.load(Ordering::Relaxed) } + pub fn get_relay_idle_soft_mark_total(&self) -> u64 { + self.relay_idle_soft_mark_total.load(Ordering::Relaxed) + } + pub fn get_relay_idle_hard_close_total(&self) -> u64 { + self.relay_idle_hard_close_total.load(Ordering::Relaxed) + } + pub fn get_relay_pressure_evict_total(&self) -> u64 { + self.relay_pressure_evict_total.load(Ordering::Relaxed) + } + pub fn get_relay_protocol_desync_close_total(&self) -> u64 { + self.relay_protocol_desync_close_total + .load(Ordering::Relaxed) + } + pub fn get_me_crc_mismatch(&self) -> u64 { + self.me_crc_mismatch.load(Ordering::Relaxed) + } + pub fn get_me_seq_mismatch(&self) -> u64 { + self.me_seq_mismatch.load(Ordering::Relaxed) + } pub fn get_me_endpoint_quarantine_total(&self) -> u64 { self.me_endpoint_quarantine_total.load(Ordering::Relaxed) } @@ -1277,8 +1379,7 @@ impl Stats { self.me_kdf_port_only_drift_total.load(Ordering::Relaxed) } pub fn get_me_hardswap_pending_reuse_total(&self) -> u64 { - self.me_hardswap_pending_reuse_total - .load(Ordering::Relaxed) + self.me_hardswap_pending_reuse_total.load(Ordering::Relaxed) } pub fn get_me_hardswap_pending_ttl_expired_total(&self) -> u64 { self.me_hardswap_pending_ttl_expired_total @@ -1359,12 +1460,10 @@ impl Stats { .load(Ordering::Relaxed) } pub fn get_me_writers_active_current_gauge(&self) -> u64 { - self.me_writers_active_current_gauge - .load(Ordering::Relaxed) + self.me_writers_active_current_gauge.load(Ordering::Relaxed) } pub fn get_me_writers_warm_current_gauge(&self) -> u64 { - self.me_writers_warm_current_gauge - .load(Ordering::Relaxed) + self.me_writers_warm_current_gauge.load(Ordering::Relaxed) } pub fn get_me_floor_cap_block_total(&self) -> u64 { self.me_floor_cap_block_total.load(Ordering::Relaxed) @@ -1384,7 +1483,9 @@ impl Stats { out.sort_by_key(|(code, _)| *code); out } - pub fn get_me_route_drop_no_conn(&self) -> u64 { self.me_route_drop_no_conn.load(Ordering::Relaxed) } + pub fn 
get_me_route_drop_no_conn(&self) -> u64 { + self.me_route_drop_no_conn.load(Ordering::Relaxed) + } pub fn get_me_route_drop_channel_closed(&self) -> u64 { self.me_route_drop_channel_closed.load(Ordering::Relaxed) } @@ -1397,6 +1498,142 @@ impl Stats { pub fn get_me_route_drop_queue_full_high(&self) -> u64 { self.me_route_drop_queue_full_high.load(Ordering::Relaxed) } + pub fn get_me_d2c_batches_total(&self) -> u64 { + self.me_d2c_batches_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_total(&self) -> u64 { + self.me_d2c_batch_frames_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_total(&self) -> u64 { + self.me_d2c_batch_bytes_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_queue_drain_total(&self) -> u64 { + self.me_d2c_flush_reason_queue_drain_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_batch_frames_total(&self) -> u64 { + self.me_d2c_flush_reason_batch_frames_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_batch_bytes_total(&self) -> u64 { + self.me_d2c_flush_reason_batch_bytes_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_max_delay_total(&self) -> u64 { + self.me_d2c_flush_reason_max_delay_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_ack_immediate_total(&self) -> u64 { + self.me_d2c_flush_reason_ack_immediate_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_reason_close_total(&self) -> u64 { + self.me_d2c_flush_reason_close_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_data_frames_total(&self) -> u64 { + self.me_d2c_data_frames_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_ack_frames_total(&self) -> u64 { + self.me_d2c_ack_frames_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_payload_bytes_total(&self) -> u64 { + self.me_d2c_payload_bytes_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_write_mode_coalesced_total(&self) -> u64 { + self.me_d2c_write_mode_coalesced_total 
+ .load(Ordering::Relaxed) + } + pub fn get_me_d2c_write_mode_split_total(&self) -> u64 { + self.me_d2c_write_mode_split_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_quota_reject_pre_write_total(&self) -> u64 { + self.me_d2c_quota_reject_pre_write_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_quota_reject_post_write_total(&self) -> u64 { + self.me_d2c_quota_reject_post_write_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_frame_buf_shrink_total(&self) -> u64 { + self.me_d2c_frame_buf_shrink_total.load(Ordering::Relaxed) + } + pub fn get_me_d2c_frame_buf_shrink_bytes_total(&self) -> u64 { + self.me_d2c_frame_buf_shrink_bytes_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_1(&self) -> u64 { + self.me_d2c_batch_frames_bucket_1.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_2_4(&self) -> u64 { + self.me_d2c_batch_frames_bucket_2_4.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_5_8(&self) -> u64 { + self.me_d2c_batch_frames_bucket_5_8.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_9_16(&self) -> u64 { + self.me_d2c_batch_frames_bucket_9_16.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_17_32(&self) -> u64 { + self.me_d2c_batch_frames_bucket_17_32 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_frames_bucket_gt_32(&self) -> u64 { + self.me_d2c_batch_frames_bucket_gt_32 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_0_1k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_0_1k.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_1k_4k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_1k_4k.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_4k_16k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_4k_16k.load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_16k_64k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_16k_64k + .load(Ordering::Relaxed) + } + pub fn 
get_me_d2c_batch_bytes_bucket_64k_128k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_64k_128k + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_bytes_bucket_gt_128k(&self) -> u64 { + self.me_d2c_batch_bytes_bucket_gt_128k + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_0_50(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_0_50 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_51_200(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_51_200 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_201_1000(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_201_1000 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_1001_5000(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_1001_5000 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_5001_20000(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_5001_20000 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_flush_duration_us_bucket_gt_20000(&self) -> u64 { + self.me_d2c_flush_duration_us_bucket_gt_20000 + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_timeout_armed_total(&self) -> u64 { + self.me_d2c_batch_timeout_armed_total + .load(Ordering::Relaxed) + } + pub fn get_me_d2c_batch_timeout_fired_total(&self) -> u64 { + self.me_d2c_batch_timeout_fired_total + .load(Ordering::Relaxed) + } pub fn get_me_writer_pick_sorted_rr_success_try_total(&self) -> u64 { self.me_writer_pick_sorted_rr_success_try_total .load(Ordering::Relaxed) @@ -1482,119 +1719,33 @@ impl Stats { pub fn get_pool_force_close_total(&self) -> u64 { self.pool_force_close_total.load(Ordering::Relaxed) } - pub fn get_pool_drain_soft_evict_total(&self) -> u64 { - self.pool_drain_soft_evict_total.load(Ordering::Relaxed) - } - pub fn get_pool_drain_soft_evict_writer_total(&self) -> u64 { - self.pool_drain_soft_evict_writer_total.load(Ordering::Relaxed) - } pub fn get_pool_stale_pick_total(&self) -> 
u64 { self.pool_stale_pick_total.load(Ordering::Relaxed) } - pub fn get_me_writer_close_signal_drop_total(&self) -> u64 { - self.me_writer_close_signal_drop_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_close_signal_channel_full_total(&self) -> u64 { - self.me_writer_close_signal_channel_full_total - .load(Ordering::Relaxed) - } - pub fn get_me_draining_writers_reap_progress_total(&self) -> u64 { - self.me_draining_writers_reap_progress_total - .load(Ordering::Relaxed) - } pub fn get_me_writer_removed_total(&self) -> u64 { self.me_writer_removed_total.load(Ordering::Relaxed) } pub fn get_me_writer_removed_unexpected_total(&self) -> u64 { - self.me_writer_removed_unexpected_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_attempt_total( - &self, - reason: MeWriterTeardownReason, - mode: MeWriterTeardownMode, - ) -> u64 { - self.me_writer_teardown_attempt_total[reason.idx()][mode.idx()] + self.me_writer_removed_unexpected_total .load(Ordering::Relaxed) } - pub fn get_me_writer_teardown_attempt_total_by_mode(&self, mode: MeWriterTeardownMode) -> u64 { - MeWriterTeardownReason::ALL - .iter() - .copied() - .map(|reason| self.get_me_writer_teardown_attempt_total(reason, mode)) - .sum() - } - pub fn get_me_writer_teardown_success_total(&self, mode: MeWriterTeardownMode) -> u64 { - self.me_writer_teardown_success_total[mode.idx()].load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_timeout_total(&self) -> u64 { - self.me_writer_teardown_timeout_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_escalation_total(&self) -> u64 { - self.me_writer_teardown_escalation_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_noop_total(&self) -> u64 { - self.me_writer_teardown_noop_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_cleanup_side_effect_failures_total( - &self, - step: MeWriterCleanupSideEffectStep, - ) -> u64 { - self.me_writer_cleanup_side_effect_failures_total[step.idx()] - 
.load(Ordering::Relaxed) - } - pub fn get_me_writer_cleanup_side_effect_failures_total_all(&self) -> u64 { - MeWriterCleanupSideEffectStep::ALL - .iter() - .copied() - .map(|step| self.get_me_writer_cleanup_side_effect_failures_total(step)) - .sum() - } - pub fn me_writer_teardown_duration_bucket_labels( - ) -> &'static [&'static str; ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT] { - &ME_WRITER_TEARDOWN_DURATION_BUCKET_LABELS - } - pub fn get_me_writer_teardown_duration_bucket_hits( - &self, - mode: MeWriterTeardownMode, - bucket_idx: usize, - ) -> u64 { - self.me_writer_teardown_duration_bucket_hits[mode.idx()][bucket_idx] - .load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_duration_bucket_total( - &self, - mode: MeWriterTeardownMode, - bucket_idx: usize, - ) -> u64 { - let capped_idx = bucket_idx.min(ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT); - let mut total = 0u64; - for idx in 0..=capped_idx { - total = total.saturating_add(self.get_me_writer_teardown_duration_bucket_hits(mode, idx)); - } - total - } - pub fn get_me_writer_teardown_duration_count(&self, mode: MeWriterTeardownMode) -> u64 { - self.me_writer_teardown_duration_count[mode.idx()].load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_duration_sum_seconds(&self, mode: MeWriterTeardownMode) -> f64 { - self.me_writer_teardown_duration_sum_micros[mode.idx()].load(Ordering::Relaxed) as f64 - / 1_000_000.0 - } pub fn get_me_refill_triggered_total(&self) -> u64 { self.me_refill_triggered_total.load(Ordering::Relaxed) } pub fn get_me_refill_skipped_inflight_total(&self) -> u64 { - self.me_refill_skipped_inflight_total.load(Ordering::Relaxed) + self.me_refill_skipped_inflight_total + .load(Ordering::Relaxed) } pub fn get_me_refill_failed_total(&self) -> u64 { self.me_refill_failed_total.load(Ordering::Relaxed) } pub fn get_me_writer_restored_same_endpoint_total(&self) -> u64 { - self.me_writer_restored_same_endpoint_total.load(Ordering::Relaxed) + self.me_writer_restored_same_endpoint_total + 
.load(Ordering::Relaxed) } pub fn get_me_writer_restored_fallback_total(&self) -> u64 { - self.me_writer_restored_fallback_total.load(Ordering::Relaxed) + self.me_writer_restored_fallback_total + .load(Ordering::Relaxed) } pub fn get_me_no_writer_failfast_total(&self) -> u64 { self.me_no_writer_failfast_total.load(Ordering::Relaxed) @@ -1613,7 +1764,7 @@ impl Stats { self.ip_reservation_rollback_quota_limit_total .load(Ordering::Relaxed) } - + pub fn increment_user_connects(&self, user: &str) { if !self.telemetry_user_enabled() { return; @@ -1628,7 +1779,7 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.connects.fetch_add(1, Ordering::Relaxed); } - + pub fn increment_user_curr_connects(&self, user: &str) { if !self.telemetry_user_enabled() { return; @@ -1643,11 +1794,37 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.curr_connects.fetch_add(1, Ordering::Relaxed); } - - pub fn decrement_user_curr_connects(&self, user: &str) { + + pub fn try_acquire_user_curr_connects(&self, user: &str, limit: Option) -> bool { if !self.telemetry_user_enabled() { - return; + return true; } + + self.maybe_cleanup_user_stats(); + let stats = self.user_stats.entry(user.to_string()).or_default(); + Self::touch_user_stats(stats.value()); + + let counter = &stats.curr_connects; + let mut current = counter.load(Ordering::Relaxed); + loop { + if let Some(max) = limit + && current >= max + { + return false; + } + match counter.compare_exchange_weak( + current, + current.saturating_add(1), + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => return true, + Err(actual) => current = actual, + } + } + } + + pub fn decrement_user_curr_connects(&self, user: &str) { self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { Self::touch_user_stats(stats.value()); @@ -1669,13 +1846,14 @@ impl Stats { } } } - + pub fn get_user_curr_connects(&self, user: &str) -> u64 { - self.user_stats.get(user) + self.user_stats + .get(user) .map(|s| 
s.curr_connects.load(Ordering::Relaxed)) .unwrap_or(0) } - + pub fn add_user_octets_from(&self, user: &str, bytes: u64) { if !self.telemetry_user_enabled() { return; @@ -1690,7 +1868,7 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); } - + pub fn add_user_octets_to(&self, user: &str, bytes: u64) { if !self.telemetry_user_enabled() { return; @@ -1705,7 +1883,7 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); } - + pub fn increment_user_msgs_from(&self, user: &str) { if !self.telemetry_user_enabled() { return; @@ -1720,7 +1898,7 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); } - + pub fn increment_user_msgs_to(&self, user: &str) { if !self.telemetry_user_enabled() { return; @@ -1735,17 +1913,20 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); } - + pub fn get_user_total_octets(&self, user: &str) -> u64 { - self.user_stats.get(user) + self.user_stats + .get(user) .map(|s| { - s.octets_from_client.load(Ordering::Relaxed) + - s.octets_to_client.load(Ordering::Relaxed) + s.octets_from_client.load(Ordering::Relaxed) + + s.octets_to_client.load(Ordering::Relaxed) }) .unwrap_or(0) } - - pub fn get_handshake_timeouts(&self) -> u64 { self.handshake_timeouts.load(Ordering::Relaxed) } + + pub fn get_handshake_timeouts(&self) -> u64 { + self.handshake_timeouts.load(Ordering::Relaxed) + } pub fn get_upstream_connect_attempt_total(&self) -> u64 { self.upstream_connect_attempt_total.load(Ordering::Relaxed) } @@ -1760,10 +1941,12 @@ impl Stats { .load(Ordering::Relaxed) } pub fn get_upstream_connect_attempts_bucket_1(&self) -> u64 { - self.upstream_connect_attempts_bucket_1.load(Ordering::Relaxed) + self.upstream_connect_attempts_bucket_1 + .load(Ordering::Relaxed) } pub fn get_upstream_connect_attempts_bucket_2(&self) -> u64 { - 
self.upstream_connect_attempts_bucket_2.load(Ordering::Relaxed) + self.upstream_connect_attempts_bucket_2 + .load(Ordering::Relaxed) } pub fn get_upstream_connect_attempts_bucket_3_4(&self) -> u64 { self.upstream_connect_attempts_bucket_3_4 @@ -1811,7 +1994,8 @@ impl Stats { } pub fn uptime_secs(&self) -> f64 { - self.start_time.read() + self.start_time + .read() .map(|t| t.elapsed().as_secs_f64()) .unwrap_or(0.0) } @@ -1820,9 +2004,11 @@ impl Stats { // ============= Replay Checker ============= pub struct ReplayChecker { - shards: Vec>, + handshake_shards: Vec>, + tls_shards: Vec>, shard_mask: usize, window: Duration, + tls_window: Duration, checks: AtomicU64, hits: AtomicU64, additions: AtomicU64, @@ -1848,7 +2034,7 @@ impl ReplayShard { seq_counter: 0, } } - + fn next_seq(&mut self) -> u64 { self.seq_counter += 1; self.seq_counter @@ -1859,13 +2045,13 @@ impl ReplayShard { return; } let cutoff = now.checked_sub(window).unwrap_or(now); - + while let Some((ts, _, _)) = self.queue.front() { if *ts >= cutoff { break; } let (_, key, queue_seq) = self.queue.pop_front().unwrap(); - + // Use key.as_ref() to get &[u8] — avoids Borrow ambiguity // between Borrow<[u8]> and Borrow> if let Some(entry) = self.cache.peek(key.as_ref()) @@ -1875,23 +2061,24 @@ impl ReplayShard { } } } - + fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool { self.cleanup(now, window); // key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]> self.cache.get(key).is_some() } - + fn add(&mut self, key: &[u8], now: Instant, window: Duration) { self.cleanup(now, window); - + let seq = self.next_seq(); let boxed_key: Box<[u8]> = key.into(); - - self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq }); + + self.cache + .put(boxed_key.clone(), ReplayEntry { seen_at: now, seq }); self.queue.push_back((now, boxed_key, seq)); } - + fn len(&self) -> usize { self.cache.len() } @@ -1899,19 +2086,24 @@ impl ReplayShard { impl ReplayChecker { pub fn new(total_capacity: usize, 
window: Duration) -> Self { + const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120); let num_shards = 64; let shard_capacity = (total_capacity / num_shards).max(1); let cap = NonZeroUsize::new(shard_capacity).unwrap(); - let mut shards = Vec::with_capacity(num_shards); + let mut handshake_shards = Vec::with_capacity(num_shards); + let mut tls_shards = Vec::with_capacity(num_shards); for _ in 0..num_shards { - shards.push(Mutex::new(ReplayShard::new(cap))); + handshake_shards.push(Mutex::new(ReplayShard::new(cap))); + tls_shards.push(Mutex::new(ReplayShard::new(cap))); } Self { - shards, + handshake_shards, + tls_shards, shard_mask: num_shards - 1, window, + tls_window: window.max(MIN_TLS_REPLAY_WINDOW), checks: AtomicU64::new(0), hits: AtomicU64::new(0), additions: AtomicU64::new(0), @@ -1925,51 +2117,69 @@ impl ReplayChecker { (hasher.finish() as usize) & self.shard_mask } - fn check_and_add_internal(&self, data: &[u8]) -> bool { + fn check_and_add_internal( + &self, + data: &[u8], + shards: &[Mutex], + window: Duration, + ) -> bool { self.checks.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); - let mut shard = self.shards[idx].lock(); + let mut shard = shards[idx].lock(); let now = Instant::now(); - let found = shard.check(data, now, self.window); + let found = shard.check(data, now, window); if found { self.hits.fetch_add(1, Ordering::Relaxed); } else { - shard.add(data, now, self.window); + shard.add(data, now, window); self.additions.fetch_add(1, Ordering::Relaxed); } found } - fn add_only(&self, data: &[u8]) { + fn add_only(&self, data: &[u8], shards: &[Mutex], window: Duration) { self.additions.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); - let mut shard = self.shards[idx].lock(); - shard.add(data, Instant::now(), self.window); + let mut shard = shards[idx].lock(); + shard.add(data, Instant::now(), window); } pub fn check_and_add_handshake(&self, data: &[u8]) -> bool { - 
self.check_and_add_internal(data) + self.check_and_add_internal(data, &self.handshake_shards, self.window) } pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool { - self.check_and_add_internal(data) + self.check_and_add_internal(data, &self.tls_shards, self.tls_window) } // Compatibility helpers (non-atomic split operations) — prefer check_and_add_*. - pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) } - pub fn add_handshake(&self, data: &[u8]) { self.add_only(data) } - pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) } - pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data) } - + pub fn check_handshake(&self, data: &[u8]) -> bool { + self.check_and_add_handshake(data) + } + pub fn add_handshake(&self, data: &[u8]) { + self.add_only(data, &self.handshake_shards, self.window) + } + pub fn check_tls_digest(&self, data: &[u8]) -> bool { + self.check_and_add_tls_digest(data) + } + pub fn add_tls_digest(&self, data: &[u8]) { + self.add_only(data, &self.tls_shards, self.tls_window) + } + pub fn stats(&self) -> ReplayStats { let mut total_entries = 0; let mut total_queue_len = 0; - for shard in &self.shards { + for shard in &self.handshake_shards { let s = shard.lock(); total_entries += s.cache.len(); total_queue_len += s.queue.len(); } - + for shard in &self.tls_shards { + let s = shard.lock(); + total_entries += s.cache.len(); + total_queue_len += s.queue.len(); + } + ReplayStats { total_entries, total_queue_len, @@ -1977,34 +2187,41 @@ impl ReplayChecker { total_hits: self.hits.load(Ordering::Relaxed), total_additions: self.additions.load(Ordering::Relaxed), total_cleanups: self.cleanups.load(Ordering::Relaxed), - num_shards: self.shards.len(), + num_shards: self.handshake_shards.len() + self.tls_shards.len(), window_secs: self.window.as_secs(), } } - + pub async fn run_periodic_cleanup(&self) { let interval = if self.window.as_secs() > 60 { Duration::from_secs(30) } else 
{ Duration::from_secs(self.window.as_secs().max(1) / 2) }; - + loop { tokio::time::sleep(interval).await; - + let now = Instant::now(); let mut cleaned = 0usize; - - for shard_mutex in &self.shards { + + for shard_mutex in &self.handshake_shards { let mut shard = shard_mutex.lock(); let before = shard.len(); shard.cleanup(now, self.window); let after = shard.len(); cleaned += before.saturating_sub(after); } - + for shard_mutex in &self.tls_shards { + let mut shard = shard_mutex.lock(); + let before = shard.len(); + shard.cleanup(now, self.tls_window); + let after = shard.len(); + cleaned += before.saturating_sub(after); + } + self.cleanups.fetch_add(1, Ordering::Relaxed); - + if cleaned > 0 { debug!(cleaned = cleaned, "Replay checker: periodic cleanup"); } @@ -2026,13 +2243,19 @@ pub struct ReplayStats { impl ReplayStats { pub fn hit_rate(&self) -> f64 { - if self.total_checks == 0 { 0.0 } - else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 } + if self.total_checks == 0 { + 0.0 + } else { + (self.total_hits as f64 / self.total_checks as f64) * 100.0 + } } - + pub fn ghost_ratio(&self) -> f64 { - if self.total_entries == 0 { 0.0 } - else { self.total_queue_len as f64 / self.total_entries as f64 } + if self.total_entries == 0 { + 0.0 + } else { + self.total_queue_len as f64 / self.total_entries as f64 + } } } @@ -2041,7 +2264,7 @@ mod tests { use super::*; use crate::config::MeTelemetryLevel; use std::sync::Arc; - + #[test] fn test_stats_shared_counters() { let stats = Arc::new(Stats::new()); @@ -2080,92 +2303,93 @@ mod tests { stats.increment_me_crc_mismatch(); stats.increment_me_keepalive_sent(); stats.increment_me_route_drop_queue_full(); + stats.increment_me_d2c_batches_total(); + stats.add_me_d2c_batch_frames_total(4); + stats.add_me_d2c_batch_bytes_total(4096); + stats.increment_me_d2c_flush_reason(MeD2cFlushReason::BatchBytes); + stats.increment_me_d2c_write_mode(MeD2cWriteMode::Coalesced); + 
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + stats.observe_me_d2c_frame_buf_shrink(1024); + stats.observe_me_d2c_batch_frames(4); + stats.observe_me_d2c_batch_bytes(4096); + stats.observe_me_d2c_flush_duration_us(120); + stats.increment_me_d2c_batch_timeout_armed_total(); + stats.increment_me_d2c_batch_timeout_fired_total(); assert_eq!(stats.get_me_crc_mismatch(), 0); assert_eq!(stats.get_me_keepalive_sent(), 0); assert_eq!(stats.get_me_route_drop_queue_full(), 0); + assert_eq!(stats.get_me_d2c_batches_total(), 0); + assert_eq!(stats.get_me_d2c_flush_reason_batch_bytes_total(), 0); + assert_eq!(stats.get_me_d2c_write_mode_coalesced_total(), 0); + assert_eq!(stats.get_me_d2c_quota_reject_pre_write_total(), 0); + assert_eq!(stats.get_me_d2c_frame_buf_shrink_total(), 0); + assert_eq!(stats.get_me_d2c_batch_frames_bucket_2_4(), 0); + assert_eq!(stats.get_me_d2c_batch_bytes_bucket_1k_4k(), 0); + assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_51_200(), 0); + assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 0); + assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 0); } #[test] - fn test_teardown_counters_and_duration() { - let stats = Stats::new(); - stats.increment_me_writer_teardown_attempt_total( - MeWriterTeardownReason::ReaderExit, - MeWriterTeardownMode::Normal, - ); - stats.increment_me_writer_teardown_success_total(MeWriterTeardownMode::Normal); - stats.observe_me_writer_teardown_duration( - MeWriterTeardownMode::Normal, - Duration::from_millis(3), - ); - stats.increment_me_writer_cleanup_side_effect_failures_total( - MeWriterCleanupSideEffectStep::CloseSignalChannelFull, - ); - - assert_eq!( - stats.get_me_writer_teardown_attempt_total( - MeWriterTeardownReason::ReaderExit, - MeWriterTeardownMode::Normal - ), - 1 - ); - assert_eq!( - stats.get_me_writer_teardown_success_total(MeWriterTeardownMode::Normal), - 1 - ); - assert_eq!( - stats.get_me_writer_teardown_duration_count(MeWriterTeardownMode::Normal), - 1 - 
); - assert!( - stats.get_me_writer_teardown_duration_sum_seconds(MeWriterTeardownMode::Normal) > 0.0 - ); - assert_eq!( - stats.get_me_writer_cleanup_side_effect_failures_total( - MeWriterCleanupSideEffectStep::CloseSignalChannelFull - ), - 1 - ); - } - - #[test] - fn test_teardown_counters_respect_me_silent() { + fn test_telemetry_policy_me_normal_blocks_d2c_debug_metrics() { let stats = Stats::new(); stats.apply_telemetry_policy(TelemetryPolicy { core_enabled: true, user_enabled: true, - me_level: MeTelemetryLevel::Silent, + me_level: MeTelemetryLevel::Normal, }); - stats.increment_me_writer_teardown_attempt_total( - MeWriterTeardownReason::ReaderExit, - MeWriterTeardownMode::Normal, - ); - stats.increment_me_writer_teardown_timeout_total(); - stats.observe_me_writer_teardown_duration( - MeWriterTeardownMode::Normal, - Duration::from_millis(1), - ); - assert_eq!( - stats.get_me_writer_teardown_attempt_total( - MeWriterTeardownReason::ReaderExit, - MeWriterTeardownMode::Normal - ), - 0 - ); - assert_eq!(stats.get_me_writer_teardown_timeout_total(), 0); - assert_eq!( - stats.get_me_writer_teardown_duration_count(MeWriterTeardownMode::Normal), - 0 - ); + + stats.increment_me_d2c_batches_total(); + stats.add_me_d2c_batch_frames_total(2); + stats.add_me_d2c_batch_bytes_total(2048); + stats.increment_me_d2c_flush_reason(MeD2cFlushReason::QueueDrain); + stats.observe_me_d2c_batch_frames(2); + stats.observe_me_d2c_batch_bytes(2048); + stats.observe_me_d2c_flush_duration_us(100); + stats.increment_me_d2c_batch_timeout_armed_total(); + stats.increment_me_d2c_batch_timeout_fired_total(); + + assert_eq!(stats.get_me_d2c_batches_total(), 1); + assert_eq!(stats.get_me_d2c_batch_frames_total(), 2); + assert_eq!(stats.get_me_d2c_batch_bytes_total(), 2048); + assert_eq!(stats.get_me_d2c_flush_reason_queue_drain_total(), 1); + assert_eq!(stats.get_me_d2c_batch_frames_bucket_2_4(), 0); + assert_eq!(stats.get_me_d2c_batch_bytes_bucket_1k_4k(), 0); + 
assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_51_200(), 0); + assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 0); + assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 0); } - + + #[test] + fn test_telemetry_policy_me_debug_enables_d2c_debug_metrics() { + let stats = Stats::new(); + stats.apply_telemetry_policy(TelemetryPolicy { + core_enabled: true, + user_enabled: true, + me_level: MeTelemetryLevel::Debug, + }); + + stats.observe_me_d2c_batch_frames(7); + stats.observe_me_d2c_batch_bytes(70_000); + stats.observe_me_d2c_flush_duration_us(1400); + stats.increment_me_d2c_batch_timeout_armed_total(); + stats.increment_me_d2c_batch_timeout_fired_total(); + + assert_eq!(stats.get_me_d2c_batch_frames_bucket_5_8(), 1); + assert_eq!(stats.get_me_d2c_batch_bytes_bucket_64k_128k(), 1); + assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_1001_5000(), 1); + assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 1); + assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 1); + } + #[test] fn test_replay_checker_basic() { let checker = ReplayChecker::new(100, Duration::from_secs(60)); assert!(!checker.check_handshake(b"test1")); // first time, inserts - assert!(checker.check_handshake(b"test1")); // duplicate + assert!(checker.check_handshake(b"test1")); // duplicate assert!(!checker.check_handshake(b"test2")); // new key inserts } - + #[test] fn test_replay_checker_duplicate_add() { let checker = ReplayChecker::new(100, Duration::from_secs(60)); @@ -2173,7 +2397,7 @@ mod tests { checker.add_handshake(b"dup"); assert!(checker.check_handshake(b"dup")); } - + #[test] fn test_replay_checker_expiration() { let checker = ReplayChecker::new(100, Duration::from_millis(50)); @@ -2182,7 +2406,7 @@ mod tests { std::thread::sleep(Duration::from_millis(100)); assert!(!checker.check_handshake(b"expire")); } - + #[test] fn test_replay_checker_stats() { let checker = ReplayChecker::new(100, Duration::from_secs(60)); @@ -2195,12 +2419,12 @@ mod tests { 
assert_eq!(stats.total_checks, 4); assert_eq!(stats.total_hits, 1); } - + #[test] fn test_replay_checker_many_keys() { let checker = ReplayChecker::new(10_000, Duration::from_secs(60)); for i in 0..500u32 { - checker.add_only(&i.to_le_bytes()); + checker.add_handshake(&i.to_le_bytes()); } for i in 0..500u32 { assert!(checker.check_handshake(&i.to_le_bytes())); @@ -2208,3 +2432,11 @@ mod tests { assert_eq!(checker.stats().total_entries, 500); } } + +#[cfg(test)] +#[path = "tests/connection_lease_security_tests.rs"] +mod connection_lease_security_tests; + +#[cfg(test)] +#[path = "tests/replay_checker_security_tests.rs"] +mod replay_checker_security_tests; diff --git a/src/stats/tests/connection_lease_security_tests.rs b/src/stats/tests/connection_lease_security_tests.rs new file mode 100644 index 0000000..1d15773 --- /dev/null +++ b/src/stats/tests/connection_lease_security_tests.rs @@ -0,0 +1,269 @@ +use super::*; +use std::panic::{self, AssertUnwindSafe}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Barrier; + +#[test] +fn direct_connection_lease_balances_on_drop() { + let stats = Arc::new(Stats::new()); + assert_eq!(stats.get_current_connections_direct(), 0); + + { + let _lease = stats.acquire_direct_connection_lease(); + assert_eq!(stats.get_current_connections_direct(), 1); + } + + assert_eq!(stats.get_current_connections_direct(), 0); +} + +#[test] +fn middle_connection_lease_balances_on_drop() { + let stats = Arc::new(Stats::new()); + assert_eq!(stats.get_current_connections_me(), 0); + + { + let _lease = stats.acquire_me_connection_lease(); + assert_eq!(stats.get_current_connections_me(), 1); + } + + assert_eq!(stats.get_current_connections_me(), 0); +} + +#[test] +fn connection_lease_disarm_prevents_double_release() { + let stats = Arc::new(Stats::new()); + + let mut lease = stats.acquire_direct_connection_lease(); + assert_eq!(stats.get_current_connections_direct(), 1); + + stats.decrement_current_connections_direct(); + 
assert_eq!(stats.get_current_connections_direct(), 0); + + lease.disarm(); + drop(lease); + + assert_eq!(stats.get_current_connections_direct(), 0); +} + +#[test] +fn direct_connection_lease_balances_on_panic_unwind() { + let stats = Arc::new(Stats::new()); + let stats_for_panic = stats.clone(); + + let panic_result = panic::catch_unwind(AssertUnwindSafe(move || { + let _lease = stats_for_panic.acquire_direct_connection_lease(); + panic!("intentional panic to verify lease drop path"); + })); + + assert!( + panic_result.is_err(), + "panic must propagate from test closure" + ); + assert_eq!( + stats.get_current_connections_direct(), + 0, + "panic unwind must release direct route gauge" + ); +} + +#[test] +fn middle_connection_lease_balances_on_panic_unwind() { + let stats = Arc::new(Stats::new()); + let stats_for_panic = stats.clone(); + + let panic_result = panic::catch_unwind(AssertUnwindSafe(move || { + let _lease = stats_for_panic.acquire_me_connection_lease(); + panic!("intentional panic to verify middle lease drop path"); + })); + + assert!( + panic_result.is_err(), + "panic must propagate from test closure" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "panic unwind must release middle route gauge" + ); +} + +#[tokio::test] +async fn concurrent_mixed_route_lease_churn_balances_to_zero() { + const TASKS: usize = 48; + const ITERATIONS_PER_TASK: usize = 256; + + let stats = Arc::new(Stats::new()); + let barrier = Arc::new(Barrier::new(TASKS)); + let mut workers = Vec::with_capacity(TASKS); + + for task_idx in 0..TASKS { + let stats_for_task = stats.clone(); + let barrier_for_task = barrier.clone(); + workers.push(tokio::spawn(async move { + barrier_for_task.wait().await; + for iter in 0..ITERATIONS_PER_TASK { + if (task_idx + iter) % 2 == 0 { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::task::yield_now().await; + } else { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::task::yield_now().await; 
+ } + } + })); + } + + for worker in workers { + worker.await.expect("lease churn worker must not panic"); + } + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route gauge must return to zero after concurrent lease churn" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route gauge must return to zero after concurrent lease churn" + ); +} + +#[tokio::test] +async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() { + const TASKS: usize = 64; + + let stats = Arc::new(Stats::new()); + let mut workers = Vec::with_capacity(TASKS); + + for task_idx in 0..TASKS { + let stats_for_task = stats.clone(); + workers.push(tokio::spawn(async move { + if task_idx % 2 == 0 { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + } else { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + } + })); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + let total = stats.get_current_connections_direct() + stats.get_current_connections_me(); + if total == TASKS as u64 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all storm tasks must acquire route leases before abort"); + + for worker in &workers { + worker.abort(); + } + for worker in workers { + let joined = worker.await; + assert!(joined.is_err(), "aborted worker must return join error"); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 0 + && stats.get_current_connections_me() == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all route gauges must drain to zero after abort storm"); +} + +#[test] +fn saturating_route_decrements_do_not_underflow_under_race() { + const THREADS: usize = 16; + const DECREMENTS_PER_THREAD: usize = 4096; + + let stats = 
Arc::new(Stats::new()); + let mut workers = Vec::with_capacity(THREADS); + + for _ in 0..THREADS { + let stats_for_thread = stats.clone(); + workers.push(std::thread::spawn(move || { + for _ in 0..DECREMENTS_PER_THREAD { + stats_for_thread.decrement_current_connections_direct(); + stats_for_thread.decrement_current_connections_me(); + } + })); + } + + for worker in workers { + worker.join().expect("decrement race worker must not panic"); + } + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route decrement races must never underflow" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route decrement races must never underflow" + ); +} + +#[tokio::test] +async fn direct_connection_lease_balances_on_task_abort() { + let stats = Arc::new(Stats::new()); + let stats_for_task = stats.clone(); + + let task = tokio::spawn(async move { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!(stats.get_current_connections_direct(), 1); + + task.abort(); + let joined = task.await; + assert!(joined.is_err(), "aborted task must return a join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_direct(), + 0, + "aborted task must release direct route gauge" + ); +} + +#[tokio::test] +async fn middle_connection_lease_balances_on_task_abort() { + let stats = Arc::new(Stats::new()); + let stats_for_task = stats.clone(); + + let task = tokio::spawn(async move { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!(stats.get_current_connections_me(), 1); + + task.abort(); + let joined = task.await; + assert!(joined.is_err(), "aborted task must return a join error"); + + 
tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_me(), + 0, + "aborted task must release middle route gauge" + ); +} diff --git a/src/stats/tests/replay_checker_security_tests.rs b/src/stats/tests/replay_checker_security_tests.rs new file mode 100644 index 0000000..8e73204 --- /dev/null +++ b/src/stats/tests/replay_checker_security_tests.rs @@ -0,0 +1,80 @@ +use super::*; +use std::time::Duration; + +#[test] +fn replay_checker_keeps_tls_and_handshake_domains_isolated_for_same_key() { + let checker = ReplayChecker::new(128, Duration::from_millis(20)); + let key = b"same-key-domain-separation"; + + assert!( + !checker.check_and_add_handshake(key), + "first handshake use should be fresh" + ); + assert!( + !checker.check_and_add_tls_digest(key), + "same bytes in TLS domain should still be fresh" + ); + + assert!( + checker.check_and_add_handshake(key), + "second handshake use should be replay-hit" + ); + assert!( + checker.check_and_add_tls_digest(key), + "second TLS use should be replay-hit independently" + ); +} + +#[test] +fn replay_checker_tls_window_is_clamped_beyond_small_handshake_window() { + let checker = ReplayChecker::new(128, Duration::from_millis(20)); + let handshake_key = b"short-window-handshake"; + let tls_key = b"short-window-tls"; + + assert!(!checker.check_and_add_handshake(handshake_key)); + assert!(!checker.check_and_add_tls_digest(tls_key)); + + std::thread::sleep(Duration::from_millis(80)); + + assert!( + !checker.check_and_add_handshake(handshake_key), + "handshake key should expire under short configured window" + ); + assert!( + checker.check_and_add_tls_digest(tls_key), + "TLS key should still be replay-hit because TLS window is clamped to a secure minimum" + ); +} + +#[test] +fn replay_checker_compat_add_paths_do_not_cross_pollute_domains() { + let checker = ReplayChecker::new(128, Duration::from_secs(1)); + let key = b"compat-domain-separation"; + + checker.add_handshake(key); + assert!( + 
checker.check_and_add_handshake(key), + "handshake add helper must populate handshake domain" + ); + assert!( + !checker.check_and_add_tls_digest(key), + "handshake add helper must not pollute TLS domain" + ); + + checker.add_tls_digest(key); + assert!( + checker.check_and_add_tls_digest(key), + "TLS add helper must populate TLS domain" + ); +} + +#[test] +fn replay_checker_stats_reflect_dual_shard_domains() { + let checker = ReplayChecker::new(128, Duration::from_secs(1)); + let stats = checker.stats(); + + assert_eq!( + stats.num_shards, 128, + "stats should expose both shard domains (handshake + TLS)" + ); +} diff --git a/src/stream/buffer_pool.rs b/src/stream/buffer_pool.rs index dac0fb5..6cdac60 100644 --- a/src/stream/buffer_pool.rs +++ b/src/stream/buffer_pool.rs @@ -8,8 +8,8 @@ use bytes::BytesMut; use crossbeam_queue::ArrayQueue; use std::ops::{Deref, DerefMut}; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; // ============= Configuration ============= @@ -42,7 +42,7 @@ impl BufferPool { pub fn new() -> Self { Self::with_config(DEFAULT_BUFFER_SIZE, DEFAULT_MAX_BUFFERS) } - + /// Create a buffer pool with custom configuration pub fn with_config(buffer_size: usize, max_buffers: usize) -> Self { Self { @@ -54,7 +54,7 @@ impl BufferPool { hits: AtomicUsize::new(0), } } - + /// Get a buffer from the pool, or create a new one if empty pub fn get(self: &Arc) -> PooledBuffer { match self.buffers.pop() { @@ -76,7 +76,7 @@ impl BufferPool { } } } - + /// Try to get a buffer, returns None if pool is empty pub fn try_get(self: &Arc) -> Option { self.buffers.pop().map(|mut buffer| { @@ -88,12 +88,12 @@ impl BufferPool { } }) } - + /// Return a buffer to the pool fn return_buffer(&self, mut buffer: BytesMut) { // Clear the buffer but keep capacity buffer.clear(); - + // Only return if we haven't exceeded max and buffer is right size if buffer.capacity() >= self.buffer_size { // Try to push to pool, 
if full just drop @@ -103,7 +103,7 @@ impl BufferPool { // Actually we don't decrement here because the buffer might have been // grown beyond our size - we just let it go } - + /// Get pool statistics pub fn stats(&self) -> PoolStats { PoolStats { @@ -115,17 +115,21 @@ impl BufferPool { misses: self.misses.load(Ordering::Relaxed), } } - + /// Get buffer size pub fn buffer_size(&self) -> usize { self.buffer_size } - + /// Preallocate buffers to fill the pool pub fn preallocate(&self, count: usize) { let to_alloc = count.min(self.max_buffers); for _ in 0..to_alloc { - if self.buffers.push(BytesMut::with_capacity(self.buffer_size)).is_err() { + if self + .buffers + .push(BytesMut::with_capacity(self.buffer_size)) + .is_err() + { break; } self.allocated.fetch_add(1, Ordering::Relaxed); @@ -183,22 +187,22 @@ impl PooledBuffer { pub fn take(mut self) -> BytesMut { self.buffer.take().unwrap() } - + /// Get the capacity of the buffer pub fn capacity(&self) -> usize { self.buffer.as_ref().map(|b| b.capacity()).unwrap_or(0) } - + /// Check if buffer is empty pub fn is_empty(&self) -> bool { self.buffer.as_ref().map(|b| b.is_empty()).unwrap_or(true) } - + /// Get the length of data in buffer pub fn len(&self) -> usize { self.buffer.as_ref().map(|b| b.len()).unwrap_or(0) } - + /// Clear the buffer pub fn clear(&mut self) { if let Some(ref mut b) = self.buffer { @@ -209,7 +213,7 @@ impl PooledBuffer { impl Deref for PooledBuffer { type Target = BytesMut; - + fn deref(&self) -> &Self::Target { self.buffer.as_ref().expect("buffer taken") } @@ -259,7 +263,7 @@ impl<'a> ScopedBuffer<'a> { impl<'a> Deref for ScopedBuffer<'a> { type Target = BytesMut; - + fn deref(&self) -> &Self::Target { self.buffer.deref() } @@ -280,108 +284,108 @@ impl<'a> Drop for ScopedBuffer<'a> { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_pool_basic() { let pool = Arc::new(BufferPool::with_config(1024, 10)); - + // Get a buffer let mut buf1 = pool.get(); buf1.extend_from_slice(b"hello"); 
assert_eq!(&buf1[..], b"hello"); - + // Drop returns to pool drop(buf1); - + let stats = pool.stats(); assert_eq!(stats.pooled, 1); assert_eq!(stats.hits, 0); assert_eq!(stats.misses, 1); - + // Get again - should reuse let buf2 = pool.get(); assert!(buf2.is_empty()); // Buffer was cleared - + let stats = pool.stats(); assert_eq!(stats.pooled, 0); assert_eq!(stats.hits, 1); } - + #[test] fn test_pool_multiple_buffers() { let pool = Arc::new(BufferPool::with_config(1024, 10)); - + // Get multiple buffers let buf1 = pool.get(); let buf2 = pool.get(); let buf3 = pool.get(); - + let stats = pool.stats(); assert_eq!(stats.allocated, 3); assert_eq!(stats.pooled, 0); - + // Return all drop(buf1); drop(buf2); drop(buf3); - + let stats = pool.stats(); assert_eq!(stats.pooled, 3); } - + #[test] fn test_pool_overflow() { let pool = Arc::new(BufferPool::with_config(1024, 2)); - + // Get 3 buffers (more than max) let buf1 = pool.get(); let buf2 = pool.get(); let buf3 = pool.get(); - + // Return all - only 2 should be pooled drop(buf1); drop(buf2); drop(buf3); - + let stats = pool.stats(); assert_eq!(stats.pooled, 2); } - + #[test] fn test_pool_take() { let pool = Arc::new(BufferPool::with_config(1024, 10)); - + let mut buf = pool.get(); buf.extend_from_slice(b"data"); - + // Take ownership, buffer should not return to pool let taken = buf.take(); assert_eq!(&taken[..], b"data"); - + let stats = pool.stats(); assert_eq!(stats.pooled, 0); } - + #[test] fn test_pool_preallocate() { let pool = Arc::new(BufferPool::with_config(1024, 10)); pool.preallocate(5); - + let stats = pool.stats(); assert_eq!(stats.pooled, 5); assert_eq!(stats.allocated, 5); } - + #[test] fn test_pool_try_get() { let pool = Arc::new(BufferPool::with_config(1024, 10)); - + // Pool is empty, try_get returns None assert!(pool.try_get().is_none()); - + // Add a buffer to pool pool.preallocate(1); - + // Now try_get should succeed once while the buffer is held let buf = pool.try_get(); assert!(buf.is_some()); @@ 
-391,50 +395,50 @@ mod tests { drop(buf); assert!(pool.try_get().is_some()); } - + #[test] fn test_hit_rate() { let pool = Arc::new(BufferPool::with_config(1024, 10)); - + // First get is a miss let buf1 = pool.get(); drop(buf1); - + // Second get is a hit let buf2 = pool.get(); drop(buf2); - + // Third get is a hit let _buf3 = pool.get(); - + let stats = pool.stats(); assert_eq!(stats.hits, 2); assert_eq!(stats.misses, 1); assert!((stats.hit_rate() - 66.67).abs() < 1.0); } - + #[test] fn test_scoped_buffer() { let pool = Arc::new(BufferPool::with_config(1024, 10)); let mut buf = pool.get(); - + { let mut scoped = ScopedBuffer::new(&mut buf); scoped.extend_from_slice(b"scoped data"); assert_eq!(&scoped[..], b"scoped data"); } - + // After scoped is dropped, buffer is cleared assert!(buf.is_empty()); } - + #[test] fn test_concurrent_access() { use std::thread; - + let pool = Arc::new(BufferPool::with_config(1024, 100)); let mut handles = vec![]; - + for _ in 0..10 { let pool_clone = Arc::clone(&pool); handles.push(thread::spawn(move || { @@ -445,11 +449,11 @@ mod tests { } })); } - + for handle in handles { handle.join().unwrap(); } - + let stats = pool.stats(); // All buffers should be returned assert!(stats.pooled > 0); diff --git a/src/stream/crypto_stream.rs b/src/stream/crypto_stream.rs index 744b186..d962321 100644 --- a/src/stream/crypto_stream.rs +++ b/src/stream/crypto_stream.rs @@ -37,7 +37,7 @@ //! //! Backpressure //! - pending ciphertext buffer is bounded (configurable per connection) -//! - pending is full and upstream is pending +//! - pending is full and upstream is pending //! -> poll_write returns Poll::Pending //! -> do not accept any plaintext //! 
@@ -59,8 +59,8 @@ use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{debug, trace}; -use crate::crypto::AesCtr; use super::state::{StreamState, YieldBuffer}; +use crate::crypto::AesCtr; // ============= Constants ============= @@ -152,9 +152,9 @@ impl CryptoReader { fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - CryptoReaderState::Poisoned { error } => error.take().unwrap_or_else(|| { - io::Error::other("stream previously poisoned") - }), + CryptoReaderState::Poisoned { error } => error + .take() + .unwrap_or_else(|| io::Error::other("stream previously poisoned")), _ => io::Error::other("stream not poisoned"), } } @@ -221,7 +221,11 @@ impl AsyncRead for CryptoReader { let filled = buf.filled_mut(); this.decryptor.apply(&mut filled[before..after]); - trace!(bytes_read, state = this.state_name(), "CryptoReader decrypted chunk"); + trace!( + bytes_read, + state = this.state_name(), + "CryptoReader decrypted chunk" + ); return Poll::Ready(Ok(())); } @@ -503,9 +507,9 @@ impl CryptoWriter { fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - CryptoWriterState::Poisoned { error } => error.take().unwrap_or_else(|| { - io::Error::other("stream previously poisoned") - }), + CryptoWriterState::Poisoned { error } => error + .take() + .unwrap_or_else(|| io::Error::other("stream previously poisoned")), _ => io::Error::other("stream not poisoned"), } } @@ -525,7 +529,11 @@ impl CryptoWriter { } /// Select how many plaintext bytes can be accepted in buffering path - fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize, max_pending: usize) -> usize { + fn select_to_accept_for_buffering( + state: &CryptoWriterState, + buf_len: usize, + max_pending: usize, + ) -> usize { if buf_len == 0 { return 0; } @@ -602,11 +610,7 @@ impl CryptoWriter { } impl AsyncWrite for CryptoWriter { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { + 
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); // Poisoned? @@ -629,8 +633,11 @@ impl AsyncWrite for CryptoWriter { Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => { // Upstream blocked. Apply ideal backpressure - let to_accept = - Self::select_to_accept_for_buffering(&this.state, buf.len(), this.max_pending_write); + let to_accept = Self::select_to_accept_for_buffering( + &this.state, + buf.len(), + this.max_pending_write, + ); if to_accept == 0 { trace!( diff --git a/src/stream/frame.rs b/src/stream/frame.rs index 5c93ea7..08baf4c 100644 --- a/src/stream/frame.rs +++ b/src/stream/frame.rs @@ -9,8 +9,8 @@ use bytes::{Bytes, BytesMut}; use std::io::Result; use std::sync::Arc; -use crate::protocol::constants::ProtoTag; use crate::crypto::SecureRandom; +use crate::protocol::constants::ProtoTag; // ============= Frame Types ============= @@ -31,27 +31,27 @@ impl Frame { meta: FrameMeta::default(), } } - + /// Create a new frame with data and metadata pub fn with_meta(data: Bytes, meta: FrameMeta) -> Self { Self { data, meta } } - + /// Create an empty frame pub fn empty() -> Self { Self::new(Bytes::new()) } - + /// Check if frame is empty pub fn is_empty(&self) -> bool { self.data.is_empty() } - + /// Get frame length pub fn len(&self) -> usize { self.data.len() } - + /// Create a QuickAck request frame pub fn quickack(data: Bytes) -> Self { Self { @@ -62,7 +62,7 @@ impl Frame { }, } } - + /// Create a simple ACK frame pub fn simple_ack(data: Bytes) -> Self { Self { @@ -91,25 +91,25 @@ impl FrameMeta { pub fn new() -> Self { Self::default() } - + /// Create with quickack flag pub fn with_quickack(mut self) -> Self { self.quickack = true; self } - + /// Create with simple_ack flag pub fn with_simple_ack(mut self) -> Self { self.simple_ack = true; self } - + /// Create with padding length pub fn with_padding(mut self, len: u8) -> Self { self.padding_len = len; self } - + /// Check if any 
special flags are set pub fn has_flags(&self) -> bool { self.quickack || self.simple_ack @@ -122,12 +122,12 @@ impl FrameMeta { pub trait FrameCodec: Send + Sync { /// Get the protocol tag for this codec fn proto_tag(&self) -> ProtoTag; - + /// Encode a frame into the destination buffer /// /// Returns the number of bytes written. fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> Result; - + /// Try to decode a frame from the source buffer /// /// Returns: @@ -137,10 +137,10 @@ pub trait FrameCodec: Send + Sync { /// /// On success, the consumed bytes are removed from `src`. fn decode(&self, src: &mut BytesMut) -> Result>; - + /// Get the minimum bytes needed to determine frame length fn min_header_size(&self) -> usize; - + /// Get the maximum allowed frame size fn max_frame_size(&self) -> usize { // Default: 16MB @@ -162,30 +162,28 @@ pub fn create_codec(proto_tag: ProtoTag, rng: Arc) -> Box Self { self.max_frame_size = size; self } - + /// Get protocol tag pub fn proto_tag(&self) -> ProtoTag { self.proto_tag @@ -56,7 +56,7 @@ impl FrameCodec { impl Decoder for FrameCodec { type Item = Frame; type Error = io::Error; - + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match self.proto_tag { ProtoTag::Abridged => decode_abridged(src, self.max_frame_size), @@ -68,7 +68,7 @@ impl Decoder for FrameCodec { impl Encoder for FrameCodec { type Error = io::Error; - + fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { match self.proto_tag { ProtoTag::Abridged => encode_abridged(&frame, dst), @@ -84,18 +84,18 @@ fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result= 0x80 { meta.quickack = true; } - + let header_len; - + if len_words == 0x7f { // Extended length (3 more bytes needed) if src.len() < 4 { @@ -106,46 +106,49 @@ fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result max_size { return Err(Error::new( ErrorKind::InvalidData, - format!("frame too large: {} bytes (max {})", 
byte_len, max_size) + format!("frame too large: {} bytes (max {})", byte_len, max_size), )); } - + let total_len = header_len + byte_len; - + if src.len() < total_len { // Reserve space for the rest of the frame src.reserve(total_len - src.len()); return Ok(None); } - + // Extract data let _ = src.split_to(header_len); let data = src.split_to(byte_len).freeze(); - + Ok(Some(Frame::with_meta(data, meta))) } fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { let data = &frame.data; - + // Validate alignment if !data.len().is_multiple_of(4) { return Err(Error::new( ErrorKind::InvalidInput, - format!("abridged frame must be 4-byte aligned, got {} bytes", data.len()) + format!( + "abridged frame must be 4-byte aligned, got {} bytes", + data.len() + ), )); } - + // Simple ACK: send reversed data without header if frame.meta.simple_ack { dst.reserve(data.len()); @@ -154,9 +157,9 @@ fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { } return Ok(()); } - + let len_words = data.len() / 4; - + if len_words < 0x7f { // Short header dst.reserve(1 + data.len()); @@ -178,10 +181,10 @@ fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { } else { return Err(Error::new( ErrorKind::InvalidInput, - format!("frame too large: {} bytes", data.len()) + format!("frame too large: {} bytes", data.len()), )); } - + dst.extend_from_slice(data); Ok(()) } @@ -192,58 +195,58 @@ fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result