Compare commits

..

65 Commits

Author SHA1 Message Date
Alexey 840713a359 Merge pull request #847 from AndreyOsipuk/feat/client-mss-relay
feat(server): client_mss_bulk — fragment only the handshake, restore MSS for bulk data (cuts pps)
2026-06-20 22:10:04 +03:00
Andrey Osipuk 50b67a93d6 feat(server): client_mss_bulk — raise MSS after handshake to cut pps
client_mss (e.g. "tspu", MSS=92) fragments the whole connection to evade
DPI on the ServerHello, but it also fragments bulk payload, multiplying
outgoing packets-per-second ~10x. On hosts whose abuse detection counts
pps (not bandwidth) this trips packet-flood limits.

Add an optional [server].client_mss_bulk: keep the low client_mss for the
handshake (ServerHello stays fragmented => DPI bypass intact), then raise
the client socket MSS to client_mss_bulk once the connection enters the
post-handshake (bulk transfer) phase, so bulk data uses normal-size
segments and pps drops back to normal. Same preset/int grammar as
client_mss. Opt-in: when unset, the handshake MSS is kept for the whole
connection (unchanged behavior).

Linux-only (setsockopt TCP_MAXSEG via raw fd, mirroring TCP_USER_TIMEOUT);
no-op on other unix. Documented in CONFIG_PARAMS.{en,ru}.
2026-06-19 11:11:01 +03:00
Alexey 72800e4aa7 Harden masking fallback and frame readers after flow sync
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-17 21:48:57 +03:00
Alexey 49742d38a7 Merge pull request #843 from amirotin/fix/config-api-section-corruption
Fix config API corrupting nested sub-tables on save
2026-06-15 20:55:56 +03:00
Mirotin Artem 869d8517a0 Rustfmt 2026-06-15 10:40:45 +03:00
Mirotin Artem e82ce634d6 Use tokio::fs for I/O in config API tests
The save and patch paths under test are async, so the tests now use tokio::fs instead of blocking std::fs. The config_store tests also switch to tempfile::tempdir() for panic-safe cleanup instead of manual remove_dir_all.
2026-06-15 10:05:09 +03:00
Mirotin Artem f1f46fac42 Fix config API corrupting nested sub-tables on save
render_top_level_section serialized a section in isolation, so nested sub-tables ([general.links], [general.modes]) were emitted as bare [links]/[modes] top-level headers and duplicated on load. Serialize the section inside a wrapper keyed by its name to keep dotted headers.

find_toml_table_bounds only spanned the first contiguous block, leaving scattered sub-tables behind as duplicates on repeated saves. Replace it with find_all_table_blocks and drop every block belonging to the section during upsert.

show_link is a legacy top-level scalar/array, not a [table]; the upsert machinery appended a bare key at EOF (landing inside the previous table) and duplicated it on repeat. Remove it from EDITABLE_SECTIONS; the editable general.links.show sub-table covers the case.

Add tests for dotted sub-tables, idempotent saves, non-contiguous layouts, show_link rejection, and integer/float/string coercion of public_port.
2026-06-15 09:49:47 +03:00
Alexey 37d0184a0b Implement shared MTProto framing and ME address role separation
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-15 08:50:08 +03:00
Alexey d81d7dba62 Rustfmt 2026-06-14 19:59:06 +03:00
Alexey 04b8d8365c Account for full-word paddings in roundtrip tests 2026-06-14 19:38:54 +03:00
Alexey 2e26bfb86e Updated secure padding expectations for VersionD
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-14 16:33:41 +03:00
Alexey d414c73c9b Hardened KDF-Tuple + NAT Probing + Paddings
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-14 16:15:41 +03:00
Alexey d1a97fe10f Update README.md 2026-06-14 12:03:55 +03:00
Alexey b153782597 More efficient Relay Mode
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-13 23:22:50 +03:00
Alexey 9dc67727b0 Merge pull request #840 from telemt/flow
Restore single-record TLS-F primary application flight + Fix SYN limiter lifecycle and default burst
2026-06-12 15:23:23 +03:00
Alexey 2d02fbe548 Bump 2026-06-12 15:06:14 +03:00
Alexey 2675779915 Fix SYN limiter lifecycle and default burst
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-12 14:40:26 +03:00
Alexey c4954f745f Restore single-record TLS-F primary application flight
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-12 12:44:22 +03:00
Alexey f33abfb09e Merge pull request #838 from telemt/flow
SYN limiter for Netfilter control + Syntactic key shares for TLS-F
2026-06-12 10:08:25 +03:00
Alexey 9904da737a Rustfmt 2026-06-12 01:28:41 +03:00
Alexey 9a3ff726b2 Use token-bucket SYN limiter backends
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-12 01:27:03 +03:00
Alexey 942882f9de SYN Limiter interval and hitcount in Config
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-12 00:29:23 +03:00
Alexey eeff16c3fd Rustfmt 2026-06-12 00:01:01 +03:00
Alexey c86dc2f65e Docs for SYN Limiter
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 23:59:47 +03:00
Alexey 1cbde70a14 Add per-listener SYN limiter for Netfilter control
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 23:58:48 +03:00
Alexey 26cd4734de Update tls.rs
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 23:29:10 +03:00
Alexey 52a1b66ad7 Syntactic key shares for TLS-F
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 23:13:21 +03:00
Alexey 9ff48c2028 Merge pull request #836 from telemt/flow
API + TLS-F Advanced tuning
2026-06-11 21:08:11 +03:00
Alexey b43c683615 Rustfmt 2026-06-11 19:59:48 +03:00
Alexey e41470fb4c Update fetcher.rs
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 19:52:23 +03:00
Alexey 09dc0cb76c Update handshake_security_tests.rs
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 19:44:39 +03:00
Alexey c36eb81808 Fix for TLS-F, ALPN и SNI/ALPN helpers
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 19:17:06 +03:00
Alexey 0f8aca56d9 Fix fallback test record iterator lifetime
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 17:56:21 +03:00
Alexey 4e66933a35 Fix TLS masking test ClientHello fixtures and tail write ordering 2026-06-11 17:51:05 +03:00
Alexey 7cf00db242 Update client_masking_budget_security_tests.rs 2026-06-11 17:32:26 +03:00
Alexey 8bc1ac06d6 Update client_masking_budget_security_tests.rs 2026-06-11 17:31:23 +03:00
Alexey 59cfcf05d3 Update client_masking_blackhat_campaign_tests.rs 2026-06-11 17:23:35 +03:00
Alexey fcbedf66ea Update client_masking_blackhat_campaign_tests.rs 2026-06-11 17:21:54 +03:00
Alexey f5c402d9fc Update metrics.rs 2026-06-11 16:43:24 +03:00
Alexey 118d53239a Merge pull request #835 from telemt/flow-ey
TLS Fixes escalating
2026-06-11 16:38:10 +03:00
Alexey 607f5442ad Merge pull request #834 from telemt/flow-11ec
TLS Fixes
2026-06-11 16:37:15 +03:00
Alexey 1edd63bfb1 Rustfmt + Bump 2026-06-11 16:36:33 +03:00
Alexey a808dc2815 Fix TLS fetch test constants scope
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 16:34:58 +03:00
Alexey 6dc9f8c27a Replay-safe TLS-F ServerHello profile consistency
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 16:11:41 +03:00
Alexey 409b0ef5ee Expose TLS Fetcher Profile Quality for ServerHello fidelity
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 14:53:21 +03:00
Alexey 3d0560d583 Select ServerHello key share from TLS Fetcher Profile
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 14:43:03 +03:00
Alexey 62af515504 Generate Valid X25519MLKEM768 ServerHello key shares
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 14:14:09 +03:00
Alexey eba55e755d Preserve TLS-F Origin Record Choreography
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 13:51:58 +03:00
Alexey c4b58ad374 Hardened TLS-F ServerHello selection
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 13:07:40 +03:00
Alexey db7ff8737c Add dynamic SNI mask target mode
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-11 10:36:37 +03:00
Alexey cd2bb9c8cd Alles muss man selber machen
Co-Authored-By: Mikhail I. Izmestev <355023+izmmisha@users.noreply.github.com>
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
Co-Authored-By: Dietmar Schreiber <376736+dginorg@users.noreply.github.com>
2026-06-11 10:13:17 +03:00
Alexey 8d3f8a8215 Merge pull request #828 from amirotin/feat/config-edit-api
Add config-edit HTTP API: PATCH/GET /v1/config
2026-06-10 10:30:52 +03:00
Mirotin Artem ff7a12d5f8 fix(api): GET /v1/config returns only editable sections; tolerate commented TOML headers; doc fixes 2026-06-09 12:13:32 +03:00
Mirotin Artem 27ee634f4a docs(api): document PATCH/GET /v1/config 2026-06-09 12:03:35 +03:00
Mirotin Artem d7e16f5b26 feat(api): config-edit endpoints PATCH/GET /v1/config 2026-06-09 12:03:28 +03:00
Mirotin Artem e39aaeb5c5 feat(config): classify_config_changes (hot vs restart) via overlay_hot_fields 2026-06-09 12:03:10 +03:00
Mirotin Artem 1628a7d822 feat(api): generic config section writer + array-table bounds 2026-06-09 12:03:01 +03:00
Alexey e9c62b6d8d Merge pull request #827 from Rightarion/fix-rate-limits-document-bits-per-second
Document rate limits as bits per second
2026-06-08 20:04:10 +03:00
Alexey 36cf3b035c Merge pull request #825 from groozchique/main
[docs] change fingerprint for xray double hop instruction
2026-06-08 20:01:20 +03:00
Samat Gilmanov 8491f5183c Document rate limits as bits per second 2026-06-08 12:39:32 -04:00
Nick Parfyonov 357852cc59 [docs] change fingerprint for xray double hop 2026-06-08 11:14:15 +03:00
Alexey 504cafb129 Merge pull request #824 from telemt/flow
MSS Tuning
2026-06-06 12:25:33 +03:00
Alexey 1096e38854 Docs for MSS Tuning
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-06 12:24:27 +03:00
Alexey 9bbdf796d8 Rustfmt 2026-06-06 12:17:19 +03:00
Alexey 27a5f5a4ec MSS Tuning with config
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-06-06 12:11:05 +03:00
79 changed files with 6680 additions and 992 deletions
Generated
+166 -16
View File
@@ -8,7 +8,7 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"crypto-common 0.1.7",
"generic-array",
]
@@ -249,6 +249,15 @@ dependencies = [
"generic-array",
]
[[package]]
name = "block-buffer"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdd35008169921d80bc60d3d0ab416eecb028c4cd653352907921d95084790be"
dependencies = [
"hybrid-array",
]
[[package]]
name = "block-padding"
version = "0.3.3"
@@ -397,7 +406,7 @@ version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"crypto-common 0.1.7",
"inout",
"zeroize",
]
@@ -436,6 +445,12 @@ dependencies = [
"cc",
]
[[package]]
name = "cmov"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c9ea0ac24bc397ab3c98583a3c9ba74fa56b09a4449bbe172b9b1ddb016027a"
[[package]]
name = "combine"
version = "4.6.7"
@@ -452,6 +467,12 @@ version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8"
[[package]]
name = "const-oid"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6ef517f0926dd24a1582492c791b6a4818a4d94e789a334894aa15b0d12f55c"
[[package]]
name = "constant_time_eq"
version = "0.4.2"
@@ -611,6 +632,16 @@ dependencies = [
"typenum",
]
[[package]]
name = "crypto-common"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce6e4c961d6cd6c9a86db418387425e8bdeaf05b3c8bc1411e6dca4c252f1453"
dependencies = [
"hybrid-array",
"rand_core 0.10.1",
]
[[package]]
name = "ctr"
version = "0.9.2"
@@ -620,6 +651,15 @@ dependencies = [
"cipher",
]
[[package]]
name = "ctutils"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d5515a3834141de9eafb9717ad39eea8247b5674e6066c404e8c4b365d2a29e"
dependencies = [
"cmov",
]
[[package]]
name = "curve25519-dalek"
version = "4.1.3"
@@ -672,7 +712,17 @@ version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb"
dependencies = [
"const-oid",
"const-oid 0.9.6",
"zeroize",
]
[[package]]
name = "der"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b"
dependencies = [
"const-oid 0.10.2",
"zeroize",
]
@@ -705,11 +755,21 @@ version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
"block-buffer 0.10.4",
"crypto-common 0.1.7",
"subtle",
]
[[package]]
name = "digest"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1dd6dbb5841937940781866fa1281a1ff7bd3bf827091440879f9994983d5c2"
dependencies = [
"block-buffer 0.12.0",
"crypto-common 0.2.2",
]
[[package]]
name = "displaydoc"
version = "0.2.6"
@@ -753,7 +813,7 @@ version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53"
dependencies = [
"pkcs8",
"pkcs8 0.10.2",
"signature",
]
@@ -1135,7 +1195,7 @@ version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
"digest 0.10.7",
]
[[package]]
@@ -1183,6 +1243,17 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hybrid-array"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9155a582abd142abc056962c29e3ce5ff2ad5469f4246b537ed42c5deba857da"
dependencies = [
"ctutils",
"typenum",
"zeroize",
]
[[package]]
name = "hyper"
version = "1.10.0"
@@ -1532,6 +1603,26 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "keccak"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e24a010dd405bd7ed803e5253182815b41bf2e6a80cc3bfc066658e03a198aa"
dependencies = [
"cfg-if",
"cpufeatures 0.3.0",
]
[[package]]
name = "kem"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01737161ba802849cfd486b5bd209d38ba4943494c249a8126005170c7621edd"
dependencies = [
"crypto-common 0.2.2",
"rand_core 0.10.1",
]
[[package]]
name = "kqueue"
version = "1.1.1"
@@ -1634,7 +1725,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf"
dependencies = [
"cfg-if",
"digest",
"digest 0.10.7",
]
[[package]]
@@ -1670,6 +1761,33 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "ml-kem"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e15f3e5b957493873e396a66914e83e616b6afe335cdef7efe5c6e1216aba66"
dependencies = [
"hybrid-array",
"kem",
"module-lattice",
"pkcs8 0.11.0",
"rand_core 0.10.1",
"sha3",
"zeroize",
]
[[package]]
name = "module-lattice"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c61b87c9683ab7cb1c6871d261ad5479b6b10ceb52c4352aaca3b5d35a8febe"
dependencies = [
"ctutils",
"hybrid-array",
"num-traits",
"zeroize",
]
[[package]]
name = "moka"
version = "0.12.15"
@@ -1888,8 +2006,18 @@ version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7"
dependencies = [
"der",
"spki",
"der 0.7.10",
"spki 0.7.3",
]
[[package]]
name = "pkcs8"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "451913da69c775a56034ea8d9003d27ee8948e12443eae7c038ba100a4f21cb7"
dependencies = [
"der 0.8.0",
"spki 0.8.0",
]
[[package]]
@@ -2280,7 +2408,7 @@ dependencies = [
"aead",
"ed25519",
"generic-array",
"pkcs8",
"pkcs8 0.10.2",
"ring",
]
@@ -2567,7 +2695,7 @@ checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures 0.2.17",
"digest",
"digest 0.10.7",
]
[[package]]
@@ -2578,7 +2706,17 @@ checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
dependencies = [
"cfg-if",
"cpufeatures 0.2.17",
"digest",
"digest 0.10.7",
]
[[package]]
name = "sha3"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be176f1a57ce4e3d31c1a166222d9768de5954f811601fb7ca06fc8203905ce1"
dependencies = [
"digest 0.11.3",
"keccak",
]
[[package]]
@@ -2724,7 +2862,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d"
dependencies = [
"base64ct",
"der",
"der 0.7.10",
]
[[package]]
name = "spki"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d9efca8738c78ee9484207732f728b1ef517bbb1833d6fc0879ca898a522f6f"
dependencies = [
"base64ct",
"der 0.8.0",
]
[[package]]
@@ -2790,7 +2938,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]]
name = "telemt"
version = "3.4.14"
version = "3.4.18"
dependencies = [
"aes",
"anyhow",
@@ -2816,6 +2964,7 @@ dependencies = [
"libc",
"lru",
"md-5",
"ml-kem",
"nix",
"notify",
"num-bigint",
@@ -2834,6 +2983,7 @@ dependencies = [
"socket2",
"static_assertions",
"subtle",
"tempfile",
"thiserror",
"tokio",
"tokio-rustls",
@@ -3258,7 +3408,7 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"crypto-common 0.1.7",
"subtle",
]
+3 -1
View File
@@ -1,6 +1,6 @@
[package]
name = "telemt"
version = "3.4.14"
version = "3.4.18"
edition = "2024"
[features]
@@ -27,6 +27,7 @@ crc32c = "0.6"
zeroize = { version = "1.8", features = ["derive"] }
subtle = "2.6"
static_assertions = "1.1"
ml-kem = { version = "0.3.2", default-features = false, features = ["alloc", "zeroize"] }
# Network
socket2 = { version = "0.6", features = ["all"] }
@@ -90,6 +91,7 @@ tokio-test = "0.4"
criterion = "0.8"
proptest = "1.4"
futures = "0.3"
tempfile = "3.27.0"
[[bench]]
name = "crypto_bench"
+1 -1
View File
@@ -8,7 +8,7 @@
>
> From June 5th, 2026: we are already analyzing the causes of a new wave of "malfunctions"
>
> Telegram Clients TLS ClientHello has been banned by JA3 Fingerprint: we are already looking for ways to solve this problem
> Telegram Clients TLS ClientHello has been banned by JA4/JA4+ Fingerprint: we are already looking for ways to solve this problem
>
> You can try build your client with our Telegram Devlibrary - [tdlib-obf](https://github.com/telemt/tdlib-obf)
+143 -6
View File
@@ -106,6 +106,8 @@ Notes:
| `GET` | `/v1/runtime/tls-fingerprints` | optional `limit=1..1000` | `200` | `RuntimeEdgeTlsFingerprintsData` |
| `GET` | `/v1/stats/users/active-ips` | none | `200` | `UserActiveIps[]` |
| `GET` | `/v1/stats/users` | none | `200` | `UserInfo[]` |
| `GET` | `/v1/config` | none | `200` | `ConfigData` |
| `PATCH` | `/v1/config` | sparse JSON object | `200` | `PatchConfigResponse` |
| `GET` | `/v1/users` | none | `200` | `UserInfo[]` |
| `POST` | `/v1/users` | `CreateUserRequest` | `201` or `202` | `CreateUserResponse` |
| `GET` | `/v1/users/{username}` | none | `200` | `UserInfo` |
@@ -143,6 +145,8 @@ Notes:
| `GET /v1/runtime/events/recent` | Returns recent API/runtime event records with optional `limit` query. |
| `GET /v1/stats/users/active-ips` | Returns users that currently have non-empty active source-IP lists. |
| `GET /v1/stats/users` | Alias of `GET /v1/users`; returns disk-first user views with runtime lag flag. |
| `GET /v1/config` | Returns the current editable config sections as JSON (no `access.*`) plus the revision. |
| `PATCH /v1/config` | Applies a sparse patch to editable config sections; validates, writes, and reports restart impact. |
| `GET /v1/users` | Returns disk-first user views sorted by username. |
| `POST /v1/users` | Creates a user and returns the effective user view plus secret. |
| `GET /v1/users/{username}` | Returns one disk-first user view or `404` when absent. |
@@ -158,6 +162,8 @@ Notes:
| HTTP | `error.code` | Trigger |
| --- | --- | --- |
| `400` | `bad_request` | Invalid JSON, validation failures, malformed request body. |
| `400` | `access_not_editable` | `PATCH /v1/config` body contains an `access` key (managed via users API). |
| `400` | `section_not_editable` | `PATCH /v1/config` body contains `server`, `network`, or an unknown top-level key. |
| `401` | `unauthorized` | Missing/invalid `Authorization` when `auth_header` is configured. |
| `403` | `forbidden` | Source IP is not allowed by whitelist. |
| `403` | `read_only` | Mutating endpoint called while `read_only=true`. |
@@ -177,6 +183,7 @@ Notes:
| Path matching | Exact match on `req.uri().path()`. Query string does not affect route matching. |
| Trailing slash | Trimmed for route matching when path length is greater than 1. Example: `/v1/users/` matches `/v1/users`. |
| Username route with extra slash | `/v1/users/{username}/...` is not treated as user route and returns `404`. |
| `DELETE /v1/config` (or any method not in `GET`, `PATCH`) | `405 method_not_allowed` with `Allow: GET, PATCH`. |
| `PUT /v1/users/{username}` | `405 method_not_allowed`. |
| `POST /v1/users/{username}` | `404 not_found`. |
| `POST /v1/users/{username}/rotate-secret/` | Trailing slash is trimmed and the route matches `rotate-secret`. |
@@ -212,8 +219,8 @@ Notes:
| `max_tcp_conns` | `usize` | no | Per-user concurrent TCP limit. |
| `expiration_rfc3339` | `string` | no | RFC3339 expiration timestamp. |
| `data_quota_bytes` | `u64` | no | Per-user traffic quota. |
| `rate_limit_up_bps` | `u64` | no | Per-user upload rate limit in bytes per second. |
| `rate_limit_down_bps` | `u64` | no | Per-user download rate limit in bytes per second. |
| `rate_limit_up_bps` | `u64` | no | Per-user upload rate limit in bits per second. |
| `rate_limit_down_bps` | `u64` | no | Per-user download rate limit in bits per second. |
| `max_unique_ips` | `usize` | no | Per-user unique source IP limit. |
| `enabled` | `bool` | no | User enable flag. Missing means enabled. `false` persists a disabled override. |
@@ -225,8 +232,8 @@ Notes:
| `max_tcp_conns` | `usize|null` | no | Per-user concurrent TCP limit; `null` removes the per-user override. |
| `expiration_rfc3339` | `string|null` | no | RFC3339 expiration timestamp; `null` removes the expiration. |
| `data_quota_bytes` | `u64|null` | no | Per-user traffic quota; `null` removes the per-user quota. |
| `rate_limit_up_bps` | `u64|null` | no | Per-user upload rate limit in bytes per second; `null` removes the upload direction limit. |
| `rate_limit_down_bps` | `u64|null` | no | Per-user download rate limit in bytes per second; `null` removes the download direction limit. |
| `rate_limit_up_bps` | `u64|null` | no | Per-user upload rate limit in bits per second; `null` removes the upload direction limit. |
| `rate_limit_down_bps` | `u64|null` | no | Per-user download rate limit in bits per second; `null` removes the download direction limit. |
| `max_unique_ips` | `usize|null` | no | Per-user unique source IP limit; `null` removes the per-user override. |
| `enabled` | `bool|null` | no | `false` disables the user. `true` or `null` removes the disabled override, so the user is enabled. |
@@ -245,6 +252,20 @@ alice = ["203.0.113.0/24", "2001:db8:abcd::/48"]
bob = ["198.51.100.42/32"]
```
### `PatchConfigRequest`
A sparse JSON object containing only the top-level config sections to modify. Each key must be one of the editable sections (`general`, `timeouts`, `censorship`, `upstreams`, `show_link`, `dc_overrides`). Tables within a section are deep-merged field-by-field into the existing config; arrays and scalar values replace the existing value wholesale. Untouched sections and file comments are preserved.
**Rejected keys:**
- `access``400 access_not_editable` (users/secrets are managed via `POST/PATCH /v1/users`).
- `server`, `network`, or any unknown top-level key → `400 section_not_editable`.
- An object with no editable keys → `400 bad_request` (empty patch).
Example — patch only the SNI domain:
```json
{"censorship": {"tls_domain": "front.example.com"}}
```
### `RotateSecretRequest`
| Field | Type | Required | Description |
| --- | --- | --- | --- |
@@ -254,6 +275,31 @@ An empty request body is accepted and generates a new secret automatically.
## Response Data Contracts
### `ConfigData`
Returned by `GET /v1/config` as the envelope `data`. The fields are exactly the editable TOML sections. The current revision is returned in the envelope `revision` field (same value as `config_hash` in `SystemInfoData`), **not** inside `data`.
| Field | Type | Description |
| --- | --- | --- |
| `general` | `object?` | `[general]` section, if present in config. |
| `timeouts` | `object?` | `[timeouts]` section, if present. |
| `censorship` | `object?` | `[censorship]` section, if present. |
| `upstreams` | `object?` | `[upstreams]` section, if present. |
| `show_link` | `object?` | `[show_link]` section, if present. |
| `dc_overrides` | `object?` | `[dc_overrides]` section, if present. |
Sections absent from the config file are absent from the response (not `null`). Only the editable sections above are returned; `access` (users/secrets), `server` (carries the API `auth_header` and per-node identity), and `network` (per-node addresses) are always excluded.
### `PatchConfigResponse`
Returned by `PATCH /v1/config` on success (`200`).
| Field | Type | Description |
| --- | --- | --- |
| `revision` | `string` | SHA-256 hex of the config file after the patch was written. |
| `restart_required` | `bool` | `true` when one or more changed fields require a process restart to take effect. Hot-reloadable fields (e.g. `general.log_level`) are applied automatically by the config file watcher; restart-required fields (e.g. any `censorship.*`, `timeouts.*`, `upstreams`, or `general.modes` change) are written to disk but only take effect after the Telemt process is restarted. The caller is responsible for triggering a restart when this flag is `true`. |
| `changed` | `string[]` | Top-level section names that differed between the old and new config (e.g. `["censorship"]`). |
### `HealthData`
| Field | Type | Description |
| --- | --- | --- |
@@ -1217,8 +1263,8 @@ JA3 follows the Salesforce ClientHello field order. JA4 follows the FoxIO TLS-cl
| `max_tcp_conns` | `usize?` | Optional max concurrent TCP limit. |
| `expiration_rfc3339` | `string?` | Optional expiration timestamp. |
| `data_quota_bytes` | `u64?` | Optional data quota. |
| `rate_limit_up_bps` | `u64?` | Optional upload rate limit in bytes per second. |
| `rate_limit_down_bps` | `u64?` | Optional download rate limit in bytes per second. |
| `rate_limit_up_bps` | `u64?` | Optional upload rate limit in bits per second. |
| `rate_limit_down_bps` | `u64?` | Optional download rate limit in bits per second. |
| `max_unique_ips` | `usize?` | Optional unique IP limit. |
| `current_connections` | `u64` | Current live connections. |
| `active_unique_ips` | `usize` | Current active unique source IPs. |
@@ -1279,10 +1325,101 @@ Link generation uses active config and enabled modes:
| `used_bytes` | `u64` | Current used bytes after reset; always `0` on success. |
| `last_reset_epoch_secs` | `u64` | Unix timestamp of the reset operation. |
## Config Endpoints
### `GET /v1/config`
Returns the current editable config sections as TOML-shaped JSON, plus the current revision. The `access` section (users and secrets) is always stripped and never appears in the response.
**Auth:** requires `Authorization` header when `auth_header` is configured (same as all other endpoints).
**Success `200` response body** (`data` field of the standard envelope):
```json
{
"revision": "<sha256-hex>",
"censorship": {"tls_domain": "front.example.com"},
"general": {"log_level": "normal"}
}
```
Top-level sections absent from the config file are absent from the response. Only `GET` and `PATCH` are accepted; any other method returns `405 Method Not Allowed` with `Allow: GET, PATCH`.
---
### `PATCH /v1/config`
Applies a sparse patch to the editable config sections. The merged config is fully validated before writing; if validation fails the file is not modified.
**Auth:** requires `Authorization` header when `auth_header` is configured.
**Headers:**
| Header | Required | Description |
| --- | --- | --- |
| `Authorization` | when configured | Same token as all other endpoints. |
| `Content-Type: application/json` | recommended | Not enforced, but body must be valid JSON. |
| `If-Match: <revision>` | no | Optimistic concurrency. `<revision>` is the `revision` value from `GET /v1/config` or `config_hash` from `GET /v1/system/info`. If supplied and it does not match the current on-disk revision, returns `409 revision_conflict`. If omitted, the patch applies unconditionally. |
**Editable sections:** `general`, `timeouts`, `censorship`, `upstreams`, `show_link`, `dc_overrides`.
**Rejected keys and their error codes:**
| Key | HTTP | `error.code` |
| --- | --- | --- |
| `access` | `400` | `access_not_editable` |
| `server`, `network`, or any unknown key | `400` | `section_not_editable` |
| Object with no editable key | `400` | `bad_request` |
**Merge semantics:** tables are deep-merged field-by-field; arrays and scalar values replace the existing value wholesale. File comments and untouched sections are preserved.
**Validation:** the merged config is deserialized into the full `ProxyConfig` type and validated before writing. Failures return `400` with a descriptive message; the file is not modified.
**Read-only mode:** returns `403 read_only` when the API runs with `read_only = true`.
**Success `200` response body** (`data` field of the standard envelope):
```json
{
"revision": "<new-sha256-hex>",
"restart_required": true,
"changed": ["censorship"]
}
```
- `revision` — SHA-256 hex of the config file after the write.
- `restart_required``true` when the change affects a field that Telemt cannot hot-reload (e.g. `censorship.*`, `timeouts.*`, `upstreams`, `general.modes`). Hot-reloadable fields (e.g. `general.log_level`) are applied automatically by the config file watcher. Restart-required fields are written to disk but only take effect after the Telemt process is restarted; the caller is responsible for triggering the restart.
- `changed` — list of top-level section names that differed.
**Status codes:**
| HTTP | `error.code` | Condition |
| --- | --- | --- |
| `200` | — | Patch applied successfully. |
| `400` | `bad_request` | Invalid JSON, empty patch, or config validation/deserialization failure. |
| `400` | `access_not_editable` | Patch contains an `access` key. |
| `400` | `section_not_editable` | Patch contains `server`, `network`, or an unknown top-level key. |
| `401` | `unauthorized` | Missing or invalid `Authorization` header. |
| `403` | `read_only` | API is in read-only mode. |
| `405` | `method_not_allowed` | Method other than `GET` or `PATCH` used on `/v1/config`. |
| `409` | `revision_conflict` | `If-Match` header supplied but does not match current revision. |
| `500` | `internal_error` | I/O or serialization failure. |
**curl example:**
```bash
# get current revision
curl -s -H "Authorization: <token>" http://127.0.0.1:<api>/v1/system/info | jq -r .config_hash
# patch the SNI domain with optimistic concurrency
curl -s -X PATCH -H "Authorization: <token>" -H "If-Match: <revision>" \
-H "Content-Type: application/json" \
-d '{"censorship":{"tls_domain":"front.example.com"}}' \
http://127.0.0.1:<api>/v1/config
```
## Mutation Semantics
| Endpoint | Notes |
| --- | --- |
| `PATCH /v1/config` | Deep-merges the patch into editable config sections (tables merged per-field; arrays/scalars replaced wholesale). Validates the merged result before writing. Writes only the touched sections via atomic `tmp + rename`. Returns the new revision and which sections changed. |
| `POST /v1/users` | Creates user, validates config, then atomically updates only affected `access.*` TOML tables (`access.users` always, plus optional per-user tables present in request). |
| `PATCH /v1/users/{username}` | Partial update of provided fields only. Missing fields remain unchanged; explicit `null` removes optional per-user entries. The write path updates only affected `access.*` TOML tables. |
| `POST /v1/users/{username}/rotate-secret` | Replaces the user's secret with a provided valid 32-hex value or a generated value, then returns the effective secret in `CreateUserResponse`. |
+91 -1
View File
@@ -1805,6 +1805,8 @@ This document lists all configuration keys accepted by `config.toml`.
| [`listen_unix_sock`](#listen_unix_sock) | `String` | — | `` |
| [`listen_unix_sock_perm`](#listen_unix_sock_perm) | `String` | — | `` |
| [`listen_tcp`](#listen_tcp) | `bool` | — (auto) | `` |
| [`client_mss`](#client_mss) | `String` | `""` | `` |
| [`client_mss_bulk`](#client_mss_bulk) | `String` | `""` | `` |
| [`proxy_protocol`](#proxy_protocol) | `bool` | `false` | `` |
| [`proxy_protocol_header_timeout_ms`](#proxy_protocol_header_timeout_ms) | `u64` | `500` | `` |
| [`proxy_protocol_trusted_cidrs`](#proxy_protocol_trusted_cidrs) | `IpNetwork[]` | `[]` | `` |
@@ -1887,6 +1889,26 @@ This document lists all configuration keys accepted by `config.toml`.
listen_unix_sock = "/run/telemt.sock"
listen_tcp = true
```
## client_mss
- **Constraints / validation**: `String`. Empty or omitted means do not change kernel MSS. Presets: `"extreme-low"` = `88`, `"tspu"` = `92`, `"2in8"` = `256`. Custom decimal strings must be within `88..=4096`.
- **Description**: Client-facing TCP MSS applied to TCP listener sockets before `listen(2)`, so Linux can announce it in SYN/ACK. This affects only proxy client TCP listeners, not API, metrics, Unix sockets, Telegram upstreams, ME sockets, or mask backend connections. Changes require listener restart/rebind.
- **Performance note**: Low MSS increases packet count predictably. Approximate segment multiplier is `ceil(1460 / client_mss)`.
- **Example**:
```toml
[server]
client_mss = "tspu"
```
## client_mss_bulk
- **Constraints / validation**: `String`. Same grammar as [`client_mss`](#client_mss) (empty/omitted, presets `"extreme-low"`/`"tspu"`/`"2in8"`, or a decimal in `88..=4096`).
- **Description**: Optional bulk-phase MSS. When set, the low `client_mss` is applied only while the TLS handshake (including the DPI-inspected ServerHello) is sent; once the connection transitions to relaying, the client socket MSS is raised to `client_mss_bulk` for the bulk data phase. This keeps the anti-DPI handshake fragmentation but restores normal-size packets for payload, cutting outgoing packets-per-second by roughly the `client_mss` segment multiplier (e.g. ~10x with `"tspu"`). Useful on hosts whose abuse detection counts packets-per-second rather than bandwidth. When empty/omitted, the handshake MSS is kept for the whole connection (previous behavior). Linux only; a no-op elsewhere.
- **Example**:
```toml
[server]
client_mss = "tspu"
client_mss_bulk = "1400"
```
## proxy_protocol
- **Constraints / validation**: `bool`.
- **Description**: Enables HAProxy PROXY protocol parsing on incoming connections (PROXY v1/v2). When enabled, client source address is taken from the PROXY header.
@@ -2207,6 +2229,11 @@ Note: This section also accepts the legacy alias `[server.admin_api]` (same sche
| --- | ---- | ------- | ---------- |
| [`ip`](#ip) | `IpAddr` | — | `` |
| [`port`](#port-serverlisteners) | `u16` | `server.port` | `` |
| [`client_mss`](#client_mss-serverlisteners) | `String` | `[server].client_mss` | `` |
| [`synlimit`](#synlimit-serverlisteners) | `false`, `"iptables"`, or `"nftables"` | `false` | `` |
| [`synlimit_seconds`](#synlimit_seconds-serverlisteners) | `u32` | `1` | `` |
| [`synlimit_hitcount`](#synlimit_hitcount-serverlisteners) | `u32` | `1` | `` |
| [`synlimit_burst`](#synlimit_burst-serverlisteners) | `u32` | `2` | `` |
| [`announce`](#announce) | `String` | — | `` |
| [`announce_ip`](#announce_ip) | `IpAddr` | — | `` |
| [`proxy_protocol`](#proxy_protocol) | `bool` | — | `` |
@@ -2231,6 +2258,69 @@ Note: This section also accepts the legacy alias `[server.admin_api]` (same sche
ip = "0.0.0.0"
port = 443
```
## client_mss (server.listeners)
- **Constraints / validation**: `String` (optional). Same values as `[server].client_mss`.
- **Description**: Per-listener MSS override. When omitted, inherits `[server].client_mss`; when set to an empty string, disables MSS shaping for this listener even if the global value is set. Changes require listener restart/rebind.
- **Example**:
```toml
[[server.listeners]]
ip = "0.0.0.0"
port = 443
client_mss = "256"
```
## synlimit (server.listeners)
- **Constraints / validation**: `false`, `"iptables"`, or `"nftables"`. Omitted or `false` disables SYN limiting for this listener.
- **Description**: Installs per-listener Linux netfilter SYN limiter rules for the listener port. `"iptables"` uses `iptables`/`ip6tables` filter rules with the `hashlimit` match as a per-source token bucket. `"nftables"` uses per-source `meter` rules with `limit rate over` and auto-detects whether the host already uses `inet`, `ip`, or `ip6` table families before creating Telemt-owned tables. The token-bucket rate is `synlimit_hitcount / synlimit_seconds`; `synlimit_burst` controls the burst size. Rules are reconciled at runtime and removed during graceful Telemt shutdown; `SIGKILL` cannot be cleaned up by the process. Requires CAP_NET_ADMIN. `synlimit*` changes hot-reload for existing listener endpoints; changing listener `ip` or `port` still requires restart/rebind.
- **Example**:
```toml
[[server.listeners]]
ip = "0.0.0.0"
port = 443
synlimit = "iptables"
[[server.listeners]]
ip = "::"
port = 443
synlimit = "nftables"
```
## synlimit_seconds (server.listeners)
- **Constraints / validation**: `u32`, must be `> 0`. Default is `1`.
- **Description**: Token-bucket interval for both SYN limiter backends. The rate is `synlimit_hitcount / synlimit_seconds` and is rendered to native netfilter rate units (`second`, `minute`, `hour`, or `day`).
- **Example**:
```toml
[[server.listeners]]
ip = "0.0.0.0"
port = 443
synlimit = "iptables"
synlimit_seconds = 1
```
## synlimit_hitcount (server.listeners)
- **Constraints / validation**: `u32`, must be `> 0`. Default is `1`.
- **Description**: Token-bucket rate amount for both SYN limiter backends. Together with `synlimit_seconds`, it defines the allowed source-IP SYN rate before excess SYN packets are dropped.
- **Example**:
```toml
[[server.listeners]]
ip = "0.0.0.0"
port = 443
synlimit = "iptables"
synlimit_hitcount = 1
```
## synlimit_burst (server.listeners)
- **Constraints / validation**: `u32`, must be `> 0`. Default is `2`.
- **Description**: Token-bucket burst size for both SYN limiter backends. Higher values allow short connection bursts from the same source IP before the steady-state `synlimit_hitcount / synlimit_seconds` rate is enforced.
- **Example**:
```toml
[[server.listeners]]
ip = "0.0.0.0"
port = 443
synlimit = "iptables"
synlimit_burst = 2
```
## announce
- **Constraints / validation**: `String` (optional). Must not be empty when set.
- **Description**: Public IP/domain announced in proxy links for this listener. Takes precedence over `announce_ip`.
@@ -3104,7 +3194,7 @@ If your backend or network is very bandwidth-constrained, reduce cap first. If p
## user_rate_limits
- **Constraints / validation**: Table `username -> { up_bps, down_bps }`. At least one direction must be non-zero.
- **Description**: Per-user bandwidth caps in bytes/sec for upload (`up_bps`) and download (`down_bps`).
- **Description**: Per-user bandwidth caps in bits/sec for upload (`up_bps`) and download (`down_bps`).
- **Example**:
```toml
+91 -1
View File
@@ -1807,6 +1807,8 @@
| [`listen_unix_sock`](#listen_unix_sock) | `String` | — | `` |
| [`listen_unix_sock_perm`](#listen_unix_sock_perm) | `String` | — | `` |
| [`listen_tcp`](#listen_tcp) | `bool` | — (auto) | `` |
| [`client_mss`](#client_mss) | `String` | `""` | `` |
| [`client_mss_bulk`](#client_mss_bulk) | `String` | `""` | `` |
| [`proxy_protocol`](#proxy_protocol) | `bool` | `false` | `` |
| [`proxy_protocol_header_timeout_ms`](#proxy_protocol_header_timeout_ms) | `u64` | `500` | `` |
| [`proxy_protocol_trusted_cidrs`](#proxy_protocol_trusted_cidrs) | `IpNetwork[]` | `[]` | `` |
@@ -1889,6 +1891,26 @@
listen_unix_sock = "/run/telemt.sock"
listen_tcp = true
```
## client_mss
- **Ограничения / валидация**: `String`. Пустое значение или отсутствие параметра означает, что Telemt не изменяет MSS, выбранный ядром. Поддерживаемые presets: `"extreme-low"` = `88`, `"tspu"` = `92`, `"2in8"` = `256`. Пользовательское десятичное значение должно быть строкой в диапазоне `88..=4096`.
- **Описание**: MSS для входящих TCP-соединений клиентов. Значение применяется к TCP listener-сокетам до `listen(2)`, чтобы Linux мог объявить его в SYN/ACK. Параметр влияет только на proxy client TCP listeners и не применяется к API, metrics, Unix sockets, Telegram upstreams, ME sockets или mask backend connections. Изменение требует restart/rebind listener’ов.
- **Performance note**: Низкий MSS предсказуемо увеличивает количество TCP-сегментов. Приблизительный multiplier: `ceil(1460 / client_mss)`.
- **Пример**:
```toml
[server]
client_mss = "tspu"
```
## client_mss_bulk
- **Ограничения / валидация**: `String`. Грамматика та же, что у [`client_mss`](#client_mss) (пусто/не задано, пресеты `"extreme-low"`/`"tspu"`/`"2in8"` либо десятичное число в диапазоне `88..=4096`).
- **Описание**: Необязательный MSS для bulk-фазы. Если задан, низкий `client_mss` применяется только на время TLS-handshake (включая инспектируемый DPI ServerHello); как только соединение переходит в фазу relay, MSS клиентского сокета поднимается до `client_mss_bulk` для передачи полезной нагрузки. Так сохраняется anti-DPI фрагментация handshake, но для данных возвращаются пакеты нормального размера — это снижает исходящий packets-per-second примерно во столько раз, каков segment multiplier у `client_mss` (например, ~10x для `"tspu"`). Полезно на хостингах, где abuse-детекция считает packets-per-second, а не полосу. Если пусто/не задано — MSS handshake сохраняется на всё соединение (прежнее поведение). Только Linux; на прочих платформах — no-op.
- **Пример**:
```toml
[server]
client_mss = "tspu"
client_mss_bulk = "1400"
```
## proxy_protocol
- **Ограничения / валидация**: `bool`.
- **Описание**: Включает поддержку разбора PROXY protocol от HAProxy (v1/v2) на входящих соединениях. При включении исходный IP клиента берётся из PROXY-заголовка.
@@ -2213,6 +2235,11 @@
| --- | ---- | ------- | ---------- |
| [`ip`](#ip) | `IpAddr` | — | `` |
| [`port`](#port-serverlisteners) | `u16` | `server.port` | `` |
| [`client_mss`](#client_mss-serverlisteners) | `String` | `[server].client_mss` | `` |
| [`synlimit`](#synlimit-serverlisteners) | `false`, `"iptables"` или `"nftables"` | `false` | `` |
| [`synlimit_seconds`](#synlimit_seconds-serverlisteners) | `u32` | `1` | `` |
| [`synlimit_hitcount`](#synlimit_hitcount-serverlisteners) | `u32` | `1` | `` |
| [`synlimit_burst`](#synlimit_burst-serverlisteners) | `u32` | `2` | `` |
| [`announce`](#announce) | `String` | — | `` |
| [`announce_ip`](#announce_ip) | `IpAddr` | — | `` |
| [`proxy_protocol`](#proxy_protocol) | `bool` | — | `` |
@@ -2237,6 +2264,69 @@
ip = "0.0.0.0"
port = 443
```
## client_mss (server.listeners)
- **Ограничения / валидация**: `String` (необязательный параметр). Допустимые значения совпадают с `[server].client_mss`.
- **Описание**: Per-listener override для MSS. Если параметр не задан, listener наследует `[server].client_mss`; если задана пустая строка, MSS shaping отключается только для этого listener’а, даже когда глобальный параметр задан. Изменение требует restart/rebind listener’а.
- **Пример**:
```toml
[[server.listeners]]
ip = "0.0.0.0"
port = 443
client_mss = "256"
```
## synlimit (server.listeners)
- **Ограничения / валидация**: `false`, `"iptables"` или `"nftables"`. Если параметр не задан или задан как `false`, SYN limiter для этого listener’а выключен.
- **Описание**: Устанавливает per-listener Linux netfilter SYN limiter rules для порта listener’а. `"iptables"` использует `iptables`/`ip6tables` filter rules с `hashlimit` match как per-source token bucket. `"nftables"` использует per-source `meter` rules с `limit rate over` и автоматически определяет, какие table families уже используются на хосте (`inet`, `ip`, `ip6`), перед созданием Telemt-owned tables. Token-bucket rate равен `synlimit_hitcount / synlimit_seconds`; `synlimit_burst` управляет burst size. Rules reconciled at runtime и удаляются при graceful shutdown Telemt; `SIGKILL` процессом не очищается. Требует CAP_NET_ADMIN. Изменения `synlimit*` hot-reload’ятся для существующих listener endpoints; изменение listener `ip` или `port` по-прежнему требует restart/rebind.
- **Пример**:
```toml
[[server.listeners]]
ip = "0.0.0.0"
port = 443
synlimit = "iptables"
[[server.listeners]]
ip = "::"
port = 443
synlimit = "nftables"
```
## synlimit_seconds (server.listeners)
- **Ограничения / валидация**: `u32`, должно быть `> 0`. Значение по умолчанию: `1`.
- **Описание**: Token-bucket interval для обоих SYN limiter backends. Rate равен `synlimit_hitcount / synlimit_seconds` и рендерится в native netfilter rate units (`second`, `minute`, `hour` или `day`).
- **Пример**:
```toml
[[server.listeners]]
ip = "0.0.0.0"
port = 443
synlimit = "iptables"
synlimit_seconds = 1
```
## synlimit_hitcount (server.listeners)
- **Ограничения / валидация**: `u32`, должно быть `> 0`. Значение по умолчанию: `1`.
- **Описание**: Token-bucket rate amount для обоих SYN limiter backends. Вместе с `synlimit_seconds` задает разрешенный source-IP SYN rate до того, как excess SYN packets начнут drop’аться.
- **Пример**:
```toml
[[server.listeners]]
ip = "0.0.0.0"
port = 443
synlimit = "iptables"
synlimit_hitcount = 1
```
## synlimit_burst (server.listeners)
- **Ограничения / валидация**: `u32`, должно быть `> 0`. Значение по умолчанию: `2`.
- **Описание**: Token-bucket burst size для обоих SYN limiter backends. Более высокие значения разрешают short connection bursts с одного source IP перед применением steady-state rate `synlimit_hitcount / synlimit_seconds`.
- **Пример**:
```toml
[[server.listeners]]
ip = "0.0.0.0"
port = 443
synlimit = "iptables"
synlimit_burst = 2
```
## announce
- **Ограничения / валидация**: `String` (необязательный параметр). Не должен быть пустым, если задан.
- **Описание**: Публичный IP-адрес или домен, объявляемый в proxy-ссылках для данного listener’а. Имеет приоритет над `announce_ip`.
@@ -3100,7 +3190,7 @@
## user_rate_limits
- **Ограничения / валидация**: Таблица `username -> { up_bps, down_bps }`. Должно быть ненулевое значение хотя бы в одном направлении.
- **Описание**: Персональные лимиты скорости по пользователям в байтах/сек для отправки (`up_bps`) и получения (`down_bps`).
- **Описание**: Персональные лимиты скорости по пользователям в битах/сек для отправки (`up_bps`) и получения (`down_bps`).
- **Example**:
```toml
+1 -1
View File
@@ -206,7 +206,7 @@ File content:
"publicKey": "<SERVER_B_PUBLIC_KEY>",
"shortId": "<SHORT_ID>",
"spiderX": "/",
"fingerprint": "chrome"
"fingerprint": "firefox"
},
"xhttpSettings": {
"path": "/<YOUR_RANDOM_PATH>"
+1 -1
View File
@@ -206,7 +206,7 @@ nano /usr/local/etc/xray/config.json
"publicKey": "<SERVER_B_PUBLIC_KEY>",
"shortId": "<SHORT_ID>",
"spiderX": "/",
"fingerprint": "chrome"
"fingerprint": "firefox"
},
"xhttpSettings": {
"path": "/<YOUR_RANDOM_PATH>"
+411
View File
@@ -0,0 +1,411 @@
//! Config-editing API: read managed sections and apply sparse field patches.
//! `access.*` is intentionally not editable here (owned by the users API).
use serde_json::Value as Json;
use toml::Value as Toml;
use super::ApiShared;
use super::config_store::{
EDITABLE_SECTIONS, compute_revision, current_revision, save_sections_to_disk,
};
use super::model::ApiFailure;
use crate::config::ProxyConfig;
use crate::config::hot_reload::classify_config_changes;
use serde::Serialize;
use std::path::Path;
#[derive(Debug, Serialize)]
pub(super) struct PatchConfigResponse {
pub revision: String,
pub restart_required: bool,
pub changed: Vec<String>,
}
/// Shared-state wrapper around [`apply_patch_to_path`]: serializes config
/// mutations behind `mutation_lock`, then records a runtime event. The route
/// handler calls this; the core logic stays decoupled for unit tests.
pub(super) async fn patch_config(
patch_json: Json,
expected_revision: Option<String>,
shared: &ApiShared,
) -> Result<PatchConfigResponse, ApiFailure> {
let _guard = shared.mutation_lock.lock().await;
let resp = apply_patch_to_path(&shared.config_path, &patch_json, expected_revision).await?;
drop(_guard);
shared
.runtime_events
.record("api.config.patch.ok", format!("changed={:?}", resp.changed));
Ok(resp)
}
/// Core patch logic, decoupled from hyper/shared-state so it is unit-testable
/// against a temp file. The route handler holds `mutation_lock` while calling this.
pub(super) async fn apply_patch_to_path(
config_path: &Path,
patch_json: &Json,
expected_revision: Option<String>,
) -> Result<PatchConfigResponse, ApiFailure> {
// 1. optimistic concurrency
let current = current_revision(config_path).await?;
if expected_revision.is_some_and(|expected| expected != current) {
return Err(ApiFailure::new(
hyper::StatusCode::CONFLICT,
"revision_conflict",
"Config revision mismatch",
));
}
// 2. convert + reject access / unknown sections
let patch_toml = json_to_toml(patch_json)
.map_err(|e| ApiFailure::bad_request(format!("invalid patch: {}", e)))?;
let patch_table = patch_toml
.as_table()
.ok_or_else(|| ApiFailure::bad_request("patch must be a JSON object"))?;
if patch_table.contains_key("access") {
return Err(ApiFailure::new(
hyper::StatusCode::BAD_REQUEST,
"access_not_editable",
"access.* is managed via the users API, not editable here",
));
}
for key in patch_table.keys() {
if !EDITABLE_SECTIONS.contains(&key.as_str()) {
return Err(ApiFailure::new(
hyper::StatusCode::BAD_REQUEST,
"section_not_editable",
format!("section not editable: {}", key),
));
}
}
let touched: Vec<&str> = patch_table
.keys()
.map(|k| k.as_str())
.filter(|k| EDITABLE_SECTIONS.contains(k))
.collect();
if touched.is_empty() {
return Err(ApiFailure::bad_request("empty patch: no editable sections"));
}
// 3. Parse old + merged from the SAME deserialize path so the classifier
// sees only the delta this patch introduces. `ProxyConfig::load` applies
// include-expansion / legacy-compat / normalization that a bare
// `try_into` does not; mixing the two paths would make unrelated fields
// compare unequal and spuriously force `restart_required`.
let original = tokio::fs::read_to_string(config_path)
.await
.map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?;
let original_toml: Toml = toml::from_str(&original)
.map_err(|e| ApiFailure::internal(format!("failed to parse config: {}", e)))?;
let old_cfg: ProxyConfig = original_toml
.clone()
.try_into()
.map_err(|e| ApiFailure::internal(format!("config does not deserialize: {}", e)))?;
let mut merged = original_toml;
deep_merge(&mut merged, &patch_toml);
let new_cfg: ProxyConfig = merged
.clone()
.try_into()
.map_err(|e| ApiFailure::bad_request(format!("config does not deserialize: {}", e)))?;
new_cfg
.validate()
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
// 4. classify changes (Telemt's own hot/restart rule)
let class = classify_config_changes(&old_cfg, &new_cfg);
// 5. write only the touched top-level sections
let revision = save_sections_to_disk(config_path, &new_cfg, &touched).await?;
Ok(PatchConfigResponse {
revision,
restart_required: class.restart_required,
changed: class.changed,
})
}
/// Return only the editable config sections + current revision.
pub(super) async fn read_managed_config(config_path: &Path) -> Result<(Toml, String), ApiFailure> {
let original = tokio::fs::read_to_string(config_path)
.await
.map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?;
let parsed: Toml = toml::from_str(&original)
.map_err(|e| ApiFailure::internal(format!("failed to parse config: {}", e)))?;
let parsed_table = parsed
.as_table()
.cloned()
.unwrap_or_else(toml::value::Table::new);
// Whitelist: return ONLY the editable sections. A blacklist (just removing
// `access`) would leak `server` (carries the API `auth_header` + per-node
// identity) and `network` (per-node addresses). Mirror the PATCH contract.
let mut table = toml::value::Table::new();
for section in EDITABLE_SECTIONS {
if let Some(value) = parsed_table.get(*section) {
table.insert((*section).to_string(), value.clone());
}
}
let revision = compute_revision(&original);
Ok((Toml::Table(table), revision))
}
/// Convert a serde_json value to a toml value. `null` is dropped from objects
/// (a patch never sets a key to TOML-null). Numbers become integers when exact,
/// otherwise floats.
fn json_to_toml(j: &Json) -> Result<Toml, String> {
Ok(match j {
Json::Null => return Err("null is not representable in TOML".into()),
Json::Bool(b) => Toml::Boolean(*b),
Json::Number(n) => {
if let Some(i) = n.as_i64() {
Toml::Integer(i)
} else if let Some(f) = n.as_f64() {
Toml::Float(f)
} else {
return Err(format!("unrepresentable number: {}", n));
}
}
Json::String(s) => Toml::String(s.clone()),
Json::Array(items) => {
let mut out = Vec::with_capacity(items.len());
for item in items {
out.push(json_to_toml(item)?);
}
Toml::Array(out)
}
Json::Object(map) => {
let mut table = toml::value::Table::new();
for (k, v) in map {
if v.is_null() {
continue; // skip nulls instead of erroring at object level
}
table.insert(k.clone(), json_to_toml(v)?);
}
Toml::Table(table)
}
})
}
/// Recursively overlay `patch` onto `base`. Tables merge key-by-key; every
/// other value type (scalars, arrays) replaces wholesale.
fn deep_merge(base: &mut Toml, patch: &Toml) {
match (base, patch) {
(Toml::Table(b), Toml::Table(p)) => {
for (k, pv) in p {
match b.get_mut(k) {
Some(bv) => deep_merge(bv, pv),
None => {
b.insert(k.clone(), pv.clone());
}
}
}
}
(b, p) => *b = p.clone(),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn json_object_converts_to_toml_table() {
let j: Json = serde_json::json!({"censorship": {"tls_domain": "a.com"}, "default_dc": 2});
let t = json_to_toml(&j).expect("convertible");
let table = t.as_table().unwrap();
assert_eq!(table["censorship"]["tls_domain"].as_str(), Some("a.com"));
assert_eq!(table["default_dc"].as_integer(), Some(2));
}
#[test]
fn deep_merge_overlays_tables_and_replaces_scalars() {
let mut base: Toml =
toml::from_str("[censorship]\ntls_domain = \"old\"\nfake_cert_len = 100\n").unwrap();
let patch: Toml = toml::from_str("[censorship]\ntls_domain = \"new\"\n").unwrap();
deep_merge(&mut base, &patch);
let cens = base["censorship"].as_table().unwrap();
assert_eq!(cens["tls_domain"].as_str(), Some("new")); // overlaid
assert_eq!(cens["fake_cert_len"].as_integer(), Some(100)); // preserved
}
use std::path::PathBuf;
fn temp_config(body: &str) -> (PathBuf, tempfile::TempDir) {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("config.toml");
std::fs::write(&path, body).unwrap();
(path, dir)
}
#[tokio::test]
async fn patch_rejects_access_section() {
let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n");
let patch: Json = serde_json::json!({"access": {"users": {"x": "y"}}});
let err = apply_patch_to_path(&path, &patch, None).await.unwrap_err();
assert_eq!(err.code, "access_not_editable");
}
#[tokio::test]
async fn patch_revision_conflict() {
let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n");
let patch: Json = serde_json::json!({"censorship": {"tls_domain": "b"}});
let err = apply_patch_to_path(&path, &patch, Some("deadbeef".into()))
.await
.unwrap_err();
assert_eq!(err.code, "revision_conflict");
}
#[tokio::test]
async fn patch_sni_reports_restart_required() {
let (path, _d) =
temp_config("[censorship]\ntls_domain = \"a.com\"\n[server]\nport = 443\n");
let patch: Json = serde_json::json!({"censorship": {"tls_domain": "b.com"}});
let resp = apply_patch_to_path(&path, &patch, None).await.unwrap();
assert!(resp.restart_required);
assert!(resp.changed.iter().any(|c| c == "censorship"));
let written = std::fs::read_to_string(&path).unwrap();
assert!(written.contains("tls_domain = \"b.com\""));
assert_eq!(
resp.revision,
crate::api::config_store::compute_revision(&written)
);
}
#[tokio::test]
async fn read_managed_config_strips_access() {
let (path, _d) = temp_config(
"[censorship]\ntls_domain = \"a.com\"\n[access.users]\nbob = \"deadbeef\"\n",
);
let (value, revision) = read_managed_config(&path).await.unwrap();
let table = value.as_table().unwrap();
assert!(table.contains_key("censorship"));
assert!(!table.contains_key("access")); // secrets never leave the box here
assert_eq!(revision, current_revision(&path).await.unwrap());
}
#[tokio::test]
async fn read_managed_config_returns_only_editable_sections() {
// server carries the API auth_header + per-node identity; network carries
// per-node addresses. Neither must be exposed by GET /v1/config.
let (path, _d) = temp_config(concat!(
"[censorship]\ntls_domain = \"a\"\n",
"[server]\nport = 443\n[server.api]\nauth_header = \"SECRET\"\n",
"[network]\nipv4 = \"1.2.3.4\"\n",
"[access.users]\nbob = \"deadbeef\"\n",
));
let (value, _rev) = read_managed_config(&path).await.unwrap();
let table = value.as_table().unwrap();
assert!(table.contains_key("censorship"));
assert!(!table.contains_key("server")); // no API auth_header / identity leak
assert!(!table.contains_key("network")); // no per-node identity leak
assert!(!table.contains_key("access")); // no users/secrets
}
#[tokio::test]
async fn patch_rejects_server_section() {
let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n");
let patch: Json = serde_json::json!({"server": {"port": 1}});
let err = apply_patch_to_path(&path, &patch, None).await.unwrap_err();
assert_eq!(err.code, "section_not_editable");
}
#[tokio::test]
async fn patch_rejects_show_link_section() {
// show_link is a legacy top-level scalar/array (not a [table]); it cannot
// be upserted safely and is superseded by the editable general.links.show.
let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n");
let patch: Json = serde_json::json!({"show_link": "*"});
let err = apply_patch_to_path(&path, &patch, None).await.unwrap_err();
assert_eq!(err.code, "section_not_editable");
}
#[tokio::test]
async fn patch_general_links_show_is_editable() {
// The supported replacement path: edit show via the general.links sub-table.
let (path, _d) = temp_config(
"[general]\nprefer_ipv6 = false\n[general.links]\nshow = \"*\"\n\
[censorship]\ntls_domain = \"a\"\n",
);
let patch: Json = serde_json::json!({"general": {"links": {"show": ["alice"]}}});
let resp = apply_patch_to_path(&path, &patch, None).await.unwrap();
assert!(resp.changed.iter().any(|c| c == "general"));
let written = tokio::fs::read_to_string(&path).await.unwrap();
let parsed: toml::Value = toml::from_str(&written).unwrap();
assert_eq!(
parsed["general"]["links"]["show"][0].as_str(),
Some("alice"),
"{written}"
);
// No leaked top-level [links]/[modes] and no duplicate sub-tables.
assert_eq!(written.matches("[general.links]").count(), 1, "{written}");
}
#[tokio::test]
async fn patch_links_public_port_written_as_integer_not_float_or_string() {
// A JSON integer must land on disk as a bare TOML integer (443), never
// 443.0 nor "443". The write re-renders from the typed config, so the
// u16 field dictates the output format regardless of JSON quirks.
let (path, _d) = temp_config("[general]\nprefer_ipv6 = false\n");
let patch: Json = serde_json::json!({"general": {"links": {"public_port": 443}}});
apply_patch_to_path(&path, &patch, None).await.unwrap();
let written = tokio::fs::read_to_string(&path).await.unwrap();
assert!(written.contains("public_port = 443"), "{written}");
assert!(
!written.contains("443.0"),
"must not be a float:\n{written}"
);
assert!(
!written.contains("\"443\""),
"must not be a string:\n{written}"
);
let parsed: toml::Value = toml::from_str(&written).unwrap();
assert_eq!(
parsed["general"]["links"]["public_port"].as_integer(),
Some(443),
"{written}"
);
}
#[tokio::test]
async fn patch_links_public_port_rejects_float() {
// 443.0 cannot deserialize into u16 -> rejected, not silently coerced.
let (path, _d) = temp_config("[general]\nprefer_ipv6 = false\n");
let patch: Json = serde_json::json!({"general": {"links": {"public_port": 443.0}}});
let err = apply_patch_to_path(&path, &patch, None).await.unwrap_err();
assert_eq!(err.status, hyper::StatusCode::BAD_REQUEST, "{:?}", err);
}
#[tokio::test]
async fn patch_links_public_port_rejects_string() {
// "443" is a string, not a u16 -> rejected.
let (path, _d) = temp_config("[general]\nprefer_ipv6 = false\n");
let patch: Json = serde_json::json!({"general": {"links": {"public_port": "443"}}});
let err = apply_patch_to_path(&path, &patch, None).await.unwrap_err();
assert_eq!(err.status, hyper::StatusCode::BAD_REQUEST, "{:?}", err);
}
#[tokio::test]
async fn patch_empty_is_rejected() {
let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n");
let patch: Json = serde_json::json!({});
assert!(apply_patch_to_path(&path, &patch, None).await.is_err());
}
#[tokio::test]
async fn patch_log_level_is_hot() {
// general.log_level is hot-reloadable -> a patch changing only it must
// report restart_required = false (exercises the full apply path, not
// just the classifier). Default LogLevel is Normal; patch to "debug".
let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n");
let patch: Json = serde_json::json!({"general": {"log_level": "debug"}});
let resp = apply_patch_to_path(&path, &patch, None).await.unwrap();
assert!(!resp.restart_required);
assert!(resp.changed.iter().any(|c| c == "general"));
}
}
+326 -10
View File
@@ -97,6 +97,90 @@ pub(super) async fn save_config_to_disk(
Ok(compute_revision(&serialized))
}
/// Top-level config tables that may be edited via the config API.
///
/// Intentionally excluded (defense-in-depth, enforces the spec's per-node
/// identity invariant at the Telemt layer too):
///
/// - `access` : owned by the users API.
/// - `server` : carries per-node identity (`port`, `api`/`api_bind`, listeners).
/// - `network` : carries per-node identity (`ipv4`/`ipv6`).
/// - `show_link` : legacy top-level scalar/array (not a `[table]`), superseded
/// by the editable `general.links.show` sub-table. The
/// section-upsert machinery here only handles `[table]` /
/// `[[array-of-tables]]` blocks; a bare top-level key cannot be
/// located or replaced safely, so it is edited via `general`.
///
/// A future field-level allowlist can re-admit specific safe fields
/// (e.g. `network.dns_overrides`) without opening the whole section.
pub(super) const EDITABLE_SECTIONS: &[&str] = &[
"general",
"timeouts",
"censorship",
"upstreams",
"dc_overrides",
];
/// Re-render the given top-level tables from `cfg` and upsert each into the
/// on-disk file, preserving every untouched section (and its comments).
pub(super) async fn save_sections_to_disk(
config_path: &Path,
cfg: &ProxyConfig,
sections: &[&str],
) -> Result<String, ApiFailure> {
let mut content = tokio::fs::read_to_string(config_path)
.await
.map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?;
for section in sections {
let rendered = render_top_level_section(cfg, section)?;
content = upsert_toml_table(&content, section, &rendered);
}
write_atomic(config_path.to_path_buf(), content.clone()).await?;
Ok(compute_revision(&content))
}
/// Render one top-level table as `[section]\n...\n` (or `[[upstreams]]` array
/// of tables) from the typed `cfg`. Serializes via the `toml` crate so the
/// output matches the canonical format Telemt parses.
fn render_top_level_section(cfg: &ProxyConfig, section: &str) -> Result<String, ApiFailure> {
let value = toml::Value::try_from(cfg)
.map_err(|e| ApiFailure::internal(format!("failed to serialize config: {}", e)))?;
let table = value
.get(section)
.ok_or_else(|| ApiFailure::internal(format!("unknown section: {}", section)))?;
// upstreams is an array-of-tables -> render as [[upstreams]] blocks.
if let toml::Value::Array(items) = table {
let mut out = String::new();
for item in items {
out.push_str(&format!("[[{}]]\n", section));
out.push_str(&toml::to_string(item).map_err(|e| {
ApiFailure::internal(format!("failed to serialize {}: {}", section, e))
})?);
if !out.ends_with('\n') {
out.push('\n');
}
}
return Ok(out);
}
// Serialize the table *inside a wrapper keyed by `section`* so the `toml`
// crate emits correctly dotted headers for nested sub-tables, e.g.
// `[general]` + `[general.modes]` + `[general.links]`. Serializing the
// inner table alone would render bare `[modes]`/`[links]` headers, which
// would leak as duplicate top-level tables and break config load.
let mut wrapper = toml::value::Table::new();
wrapper.insert(section.to_string(), table.clone());
let mut out = toml::to_string(&toml::Value::Table(wrapper))
.map_err(|e| ApiFailure::internal(format!("failed to serialize {}: {}", section, e)))?;
if !out.ends_with('\n') {
out.push('\n');
}
Ok(out)
}
pub(super) async fn save_access_sections_to_disk(
config_path: &Path,
cfg: &ProxyConfig,
@@ -253,11 +337,22 @@ fn serialize_toml_key(key: &str) -> Result<String, ApiFailure> {
}
fn upsert_toml_table(source: &str, table_name: &str, replacement: &str) -> String {
if let Some((start, end)) = find_toml_table_bounds(source, table_name) {
let blocks = find_all_table_blocks(source, table_name);
if let Some(&(first_start, first_end)) = blocks.first() {
// Replace the first block in place and delete any further blocks that
// also belong to this table. Telemt writes a section's sub-tables
// contiguously, but a hand-edited config may scatter them; dropping the
// extras here prevents the duplicate-table corruption that would
// otherwise break config load.
let mut out = String::with_capacity(source.len() + replacement.len());
out.push_str(&source[..start]);
out.push_str(&source[..first_start]);
out.push_str(replacement);
out.push_str(&source[end..]);
let mut cursor = first_end;
for &(start, end) in &blocks[1..] {
out.push_str(&source[cursor..start]);
cursor = end;
}
out.push_str(&source[cursor..]);
return out;
}
@@ -272,24 +367,62 @@ fn upsert_toml_table(source: &str, table_name: &str, replacement: &str) -> Strin
out
}
/// Whether a (comment-stripped, trimmed) TOML header line belongs to
/// `table_name`: the table itself (`[X]` / `[[X]]`) or any of its nested
/// sub-tables (`[X.…]` / `[[X.…]]`). The trailing dot guards against sibling
/// prefixes — `access.users` must not match `access.user_enabled`.
fn header_belongs_to(header: &str, table_name: &str) -> bool {
let body = match header.strip_prefix("[[").and_then(|h| h.strip_suffix("]]")) {
Some(body) => body,
None => match header.strip_prefix('[').and_then(|h| h.strip_suffix(']')) {
Some(body) => body,
None => return false,
},
};
let body = body.trim();
body == table_name
|| body
.strip_prefix(table_name)
.is_some_and(|rest| rest.starts_with('.'))
}
/// Locate the first contiguous byte range covering `table_name` and the nested
/// sub-tables immediately following it. Used for existence checks; see
/// [`find_all_table_blocks`] for the full set of (possibly scattered) blocks.
fn find_toml_table_bounds(source: &str, table_name: &str) -> Option<(usize, usize)> {
let target = format!("[{}]", table_name);
find_all_table_blocks(source, table_name).into_iter().next()
}
/// Locate every byte range that belongs to `table_name`: the table header and
/// its nested sub-tables. Returns one range per contiguous run, so a config
/// where a section's sub-tables are scattered (e.g. hand-edited) yields several
/// ranges — letting the caller collapse them into a single rendered block.
fn find_all_table_blocks(source: &str, table_name: &str) -> Vec<(usize, usize)> {
let mut blocks = Vec::new();
let mut offset = 0usize;
let mut start = None;
let mut start: Option<usize> = None;
for line in source.split_inclusive('\n') {
let trimmed = line.trim();
// Drop any inline comment so a hand-edited header like
// `[censorship] # note` still matches. Section names never contain `#`.
let header = line.trim().split('#').next().unwrap_or("").trim();
let is_header = header.starts_with('[');
if let Some(start_offset) = start {
if trimmed.starts_with('[') {
return Some((start_offset, offset));
if is_header && !header_belongs_to(header, table_name) {
blocks.push((start_offset, offset));
start = None;
}
} else if trimmed == target {
}
if start.is_none() && header_belongs_to(header, table_name) {
start = Some(offset);
}
offset = offset.saturating_add(line.len());
}
start.map(|start_offset| (start_offset, source.len()))
if let Some(start_offset) = start {
blocks.push((start_offset, source.len()));
}
blocks
}
async fn write_atomic(path: PathBuf, contents: String) -> Result<(), ApiFailure> {
@@ -336,6 +469,189 @@ fn write_atomic_sync(path: &Path, contents: &str) -> std::io::Result<()> {
mod tests {
use super::*;
#[tokio::test]
async fn save_sections_preserves_other_tables_and_comments() {
let dir = std::env::temp_dir().join(format!("cfgtest-{}", rand::random::<u64>()));
std::fs::create_dir_all(&dir).unwrap();
let path = dir.join("config.toml");
std::fs::write(
&path,
"# top comment\n[censorship]\ntls_domain = \"old.example\"\n\n[server]\nport = 443\n",
)
.unwrap();
let mut cfg = ProxyConfig::default();
cfg.censorship.tls_domain = "new.example".to_string();
cfg.server.port = 443;
let rev = save_sections_to_disk(&path, &cfg, &["censorship"])
.await
.unwrap();
let written = std::fs::read_to_string(&path).unwrap();
assert!(written.contains("tls_domain = \"new.example\""));
assert!(written.contains("# top comment")); // untouched comment kept
assert!(written.contains("[server]\nport = 443")); // untouched table kept
assert_eq!(rev, compute_revision(&written));
std::fs::remove_dir_all(&dir).ok();
}
#[test]
fn find_bounds_matches_array_of_tables() {
let src =
"[server]\nport = 1\n\n[[upstreams]]\nkind = \"a\"\n\n[[upstreams]]\nkind = \"b\"\n";
let bounds = find_toml_table_bounds(src, "upstreams");
assert!(bounds.is_some(), "should locate [[upstreams]] block start");
let (start, end) = bounds.unwrap();
let slice = &src[start..end];
assert!(slice.starts_with("[[upstreams]]"));
assert!(slice.contains("kind = \"b\"")); // spans through the last upstream block
}
#[test]
fn find_bounds_matches_header_with_inline_comment() {
let src = "[censorship] # notes\ntls_domain = \"a\"\n\n[server]\nport = 1\n";
let bounds = find_toml_table_bounds(src, "censorship");
assert!(bounds.is_some(), "commented header must still match");
let (start, end) = bounds.unwrap();
let slice = &src[start..end];
assert!(slice.starts_with("[censorship] # notes"));
assert!(slice.contains("tls_domain"));
assert!(!slice.contains("[server]")); // terminates at the next header
}
#[tokio::test]
async fn save_general_section_keeps_subtables_dotted_without_duplicates() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("config.toml");
tokio::fs::write(
&path,
"[general]\nprefer_ipv6 = false\n\n[general.modes]\ntls = true\n\n\
[general.links]\npublic_host = \"old.example\"\n\n[server]\nport = 443\n",
)
.await
.unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.prefer_ipv6 = true;
save_sections_to_disk(&path, &cfg, &["general"])
.await
.unwrap();
let written = tokio::fs::read_to_string(&path).await.unwrap();
// No bare top-level [modes] / [links] headers leaked.
for line in written.lines() {
let header = line.trim();
assert_ne!(header, "[modes]", "leaked top-level [modes]:\n{written}");
assert_ne!(header, "[links]", "leaked top-level [links]:\n{written}");
}
// Sub-tables kept their dotted prefix exactly once each.
assert_eq!(
written.matches("[general.modes]").count(),
1,
"[general.modes] must appear exactly once:\n{written}"
);
assert_eq!(
written.matches("[general.links]").count(),
1,
"[general.links] must appear exactly once:\n{written}"
);
// Result parses (duplicate tables would error here).
toml::from_str::<toml::Value>(&written)
.unwrap_or_else(|e| panic!("written config must parse: {e}\n{written}"));
assert!(written.contains("[server]\nport = 443")); // untouched table kept
}
#[tokio::test]
async fn save_general_section_is_idempotent_across_repeated_saves() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("config.toml");
tokio::fs::write(
&path,
"[general]\nprefer_ipv6 = false\n\n[general.modes]\ntls = true\n\n\
[general.links]\npublic_host = \"old.example\"\n",
)
.await
.unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.prefer_ipv6 = true;
save_sections_to_disk(&path, &cfg, &["general"])
.await
.unwrap();
save_sections_to_disk(&path, &cfg, &["general"])
.await
.unwrap();
let written = tokio::fs::read_to_string(&path).await.unwrap();
assert_eq!(written.matches("[general.modes]").count(), 1, "{written}");
assert_eq!(written.matches("[general.links]").count(), 1, "{written}");
assert_eq!(written.matches("[general]").count(), 1, "{written}");
toml::from_str::<toml::Value>(&written)
.unwrap_or_else(|e| panic!("written config must parse: {e}\n{written}"));
}
#[test]
fn find_bounds_spans_dotted_subtables() {
let src = "[general]\nprefer_ipv6 = false\n\n[general.modes]\ntls = true\n\n\
[general.links]\npublic_host = \"a\"\n\n[server]\nport = 1\n";
let bounds = find_toml_table_bounds(src, "general");
assert!(bounds.is_some(), "should locate [general] block");
let (start, end) = bounds.unwrap();
let slice = &src[start..end];
assert!(slice.starts_with("[general]"));
assert!(slice.contains("[general.modes]")); // spans nested sub-tables
assert!(slice.contains("[general.links]"));
assert!(!slice.contains("[server]")); // terminates at the next unrelated header
}
#[test]
fn find_bounds_does_not_overrun_sibling_prefix() {
// access.users must not swallow access.user_enabled (dot guards the prefix).
let src = "[access.users]\nalice = \"x\"\n\n[access.user_enabled]\nalice = true\n";
let bounds = find_toml_table_bounds(src, "access.users").unwrap();
let slice = &src[bounds.0..bounds.1];
assert!(slice.starts_with("[access.users]"));
assert!(!slice.contains("[access.user_enabled]"));
}
#[tokio::test]
async fn save_general_handles_non_contiguous_subtables() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("config.toml");
// Hand-edited layout: [general.modes] sits AFTER an unrelated [server].
tokio::fs::write(
&path,
"[general]\nprefer_ipv6 = false\n\n[server]\nport = 443\n\n\
[general.modes]\ntls = true\n",
)
.await
.unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.prefer_ipv6 = true;
save_sections_to_disk(&path, &cfg, &["general"])
.await
.unwrap();
let written = tokio::fs::read_to_string(&path).await.unwrap();
assert_eq!(
written.matches("[general.modes]").count(),
1,
"non-contiguous [general.modes] must not duplicate:\n{written}"
);
toml::from_str::<toml::Value>(&written)
.unwrap_or_else(|e| panic!("written config must parse: {e}\n{written}"));
assert!(written.contains("[server]")); // unrelated section preserved
}
#[test]
fn render_user_rate_limits_section() {
let mut cfg = ProxyConfig::default();
+34
View File
@@ -28,6 +28,7 @@ use crate::stats::Stats;
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::MePool;
mod config_edit;
mod config_store;
mod events;
mod http_utils;
@@ -84,6 +85,7 @@ const ALLOW_GET: &str = "GET";
const ALLOW_POST: &str = "POST";
const ALLOW_GET_POST: &str = "GET, POST";
const ALLOW_GET_PATCH_DELETE: &str = "GET, PATCH, DELETE";
const ALLOW_GET_PATCH: &str = "GET, PATCH";
pub(super) struct ApiRuntimeState {
pub(super) process_started_at_epoch_secs: u64,
@@ -174,6 +176,7 @@ fn allowed_methods_for_path(path: &str) -> Option<&'static str> {
| "/v1/stats/users/quota"
| "/v1/stats/users" => Some(ALLOW_GET),
"/v1/users" => Some(ALLOW_GET_POST),
"/v1/config" => Some(ALLOW_GET_PATCH),
_ if user_action_route_matches(path, "/reset-quota") => Some(ALLOW_POST),
_ if user_action_route_matches(path, "/rotate-secret") => Some(ALLOW_POST),
_ if user_action_route_matches(path, "/enable") => Some(ALLOW_POST),
@@ -643,6 +646,37 @@ async fn handle(
};
Ok(success_response(status, data, revision))
}
("GET", "/v1/config") => {
let (value, revision) =
config_edit::read_managed_config(&shared.config_path).await?;
Ok(success_response(StatusCode::OK, value, revision))
}
("PATCH", "/v1/config") => {
if api_cfg.read_only {
return Ok(error_response(
request_id,
ApiFailure::new(
StatusCode::FORBIDDEN,
"read_only",
"API runs in read-only mode",
),
));
}
let expected_revision = parse_if_match(req.headers());
let body = read_json::<serde_json::Value>(req.into_body(), body_limit).await?;
match config_edit::patch_config(body, expected_revision, &shared).await {
Ok(resp) => {
let revision = resp.revision.clone();
Ok(success_response(StatusCode::OK, resp, revision))
}
Err(error) => {
shared
.runtime_events
.record("api.config.patch.failed", error.code);
Err(error)
}
}
}
_ => {
if method == Method::POST
&& let Some(base_user) = normalized_path
+15
View File
@@ -54,6 +54,9 @@ const DEFAULT_CONNTRACK_CONTROL_ENABLED: bool = true;
const DEFAULT_CONNTRACK_PRESSURE_HIGH_WATERMARK_PCT: u8 = 85;
const DEFAULT_CONNTRACK_PRESSURE_LOW_WATERMARK_PCT: u8 = 70;
const DEFAULT_CONNTRACK_DELETE_BUDGET_PER_SEC: u64 = 4096;
const DEFAULT_SYNLIMIT_SECONDS: u32 = 1;
const DEFAULT_SYNLIMIT_HITCOUNT: u32 = 1;
const DEFAULT_SYNLIMIT_BURST: u32 = 2;
const DEFAULT_UPSTREAM_CONNECT_RETRY_ATTEMPTS: u32 = 2;
const DEFAULT_UPSTREAM_UNHEALTHY_FAIL_THRESHOLD: u32 = 5;
const DEFAULT_UPSTREAM_CONNECT_BUDGET_MS: u64 = 3000;
@@ -243,6 +246,18 @@ pub(crate) fn default_conntrack_delete_budget_per_sec() -> u64 {
DEFAULT_CONNTRACK_DELETE_BUDGET_PER_SEC
}
pub(crate) fn default_synlimit_seconds() -> u32 {
DEFAULT_SYNLIMIT_SECONDS
}
pub(crate) fn default_synlimit_hitcount() -> u32 {
DEFAULT_SYNLIMIT_HITCOUNT
}
pub(crate) fn default_synlimit_burst() -> u32 {
DEFAULT_SYNLIMIT_BURST
}
pub(crate) fn default_prefer_4() -> u8 {
4
}
+138 -2
View File
@@ -16,10 +16,12 @@
//! | `general` | `telemetry` / `me_*_policy` | Applied immediately |
//! | `network` | `dns_overrides` | Applied immediately |
//! | `access` | All user/quota fields | Effective immediately |
//! | `server.listeners` | `synlimit*` for existing endpoints | Netfilter rules reconciled immediately |
//!
//! Fields that require re-binding sockets (`server.listeners`, legacy
//! `server.port`, `censorship.*`, `network.*`, `use_middle_proxy`) are **not**
//! applied; a warning is emitted.
//! applied, except for SYN limiter fields on unchanged listener endpoints; a
//! warning is emitted.
//! Non-hot changes are never mixed into the runtime config snapshot.
use std::collections::BTreeSet;
@@ -34,7 +36,8 @@ use tracing::{error, info, warn};
use super::load::{LoadedConfig, ProxyConfig};
use crate::config::{
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel, MeWriterPickMode,
ListenerConfig, LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel,
MeWriterPickMode, SynLimitMode,
};
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
@@ -131,6 +134,17 @@ pub struct HotFields {
pub user_max_unique_ips_global_each: usize,
pub user_max_unique_ips_mode: crate::config::UserMaxUniqueIpsMode,
pub user_max_unique_ips_window_secs: u64,
pub listener_synlimit: Vec<ListenerSynLimitHotFields>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ListenerSynLimitHotFields {
pub ip: IpAddr,
pub port: Option<u16>,
pub synlimit: SynLimitMode,
pub synlimit_seconds: u32,
pub synlimit_hitcount: u32,
pub synlimit_burst: u32,
}
impl HotFields {
@@ -260,6 +274,25 @@ impl HotFields {
user_max_unique_ips_global_each: cfg.access.user_max_unique_ips_global_each,
user_max_unique_ips_mode: cfg.access.user_max_unique_ips_mode,
user_max_unique_ips_window_secs: cfg.access.user_max_unique_ips_window_secs,
listener_synlimit: cfg
.server
.listeners
.iter()
.map(ListenerSynLimitHotFields::from_listener)
.collect(),
}
}
}
impl ListenerSynLimitHotFields {
fn from_listener(listener: &ListenerConfig) -> Self {
Self {
ip: listener.ip,
port: listener.port,
synlimit: listener.synlimit,
synlimit_seconds: listener.synlimit_seconds,
synlimit_hitcount: listener.synlimit_hitcount,
synlimit_burst: listener.synlimit_burst,
}
}
}
@@ -312,6 +345,7 @@ fn listeners_equal(
lhs.iter().zip(rhs.iter()).all(|(a, b)| {
a.ip == b.ip
&& a.port == b.port
&& a.client_mss == b.client_mss
&& a.announce == b.announce
&& a.announce_ip == b.announce_ip
&& a.proxy_protocol == b.proxy_protocol
@@ -565,6 +599,7 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
cfg.access.user_max_unique_ips_global_each = new.access.user_max_unique_ips_global_each;
cfg.access.user_max_unique_ips_mode = new.access.user_max_unique_ips_mode;
cfg.access.user_max_unique_ips_window_secs = new.access.user_max_unique_ips_window_secs;
overlay_listener_synlimit_fields(&mut cfg.server.listeners, &new.server.listeners);
if cfg.rebuild_runtime_user_auth().is_err() {
cfg.runtime_user_auth = None;
@@ -573,6 +608,21 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
cfg
}
fn overlay_listener_synlimit_fields(old: &mut [ListenerConfig], new: &[ListenerConfig]) {
if old.len() != new.len() {
return;
}
for (old_listener, new_listener) in old.iter_mut().zip(new.iter()) {
if old_listener.ip != new_listener.ip || old_listener.port != new_listener.port {
continue;
}
old_listener.synlimit = new_listener.synlimit;
old_listener.synlimit_seconds = new_listener.synlimit_seconds;
old_listener.synlimit_hitcount = new_listener.synlimit_hitcount;
old_listener.synlimit_burst = new_listener.synlimit_burst;
}
}
/// Warn if any non-hot fields changed (require restart).
fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: bool) {
let mut warned = false;
@@ -608,6 +658,7 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.server.listen_addr_ipv4 != new.server.listen_addr_ipv4
|| old.server.listen_addr_ipv6 != new.server.listen_addr_ipv6
|| old.server.listen_tcp != new.server.listen_tcp
|| old.server.client_mss != new.server.client_mss
|| old.server.listen_unix_sock != new.server.listen_unix_sock
|| old.server.listen_unix_sock_perm != new.server.listen_unix_sock_perm
{
@@ -618,6 +669,7 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.censorship.tls_domains != new.censorship.tls_domains
|| old.censorship.tls_fetch_scope != new.censorship.tls_fetch_scope
|| old.censorship.mask != new.censorship.mask
|| old.censorship.mask_dynamic != new.censorship.mask_dynamic
|| old.censorship.mask_host != new.censorship.mask_host
|| old.censorship.mask_port != new.censorship.mask_port
|| old.censorship.exclusive_mask != new.censorship.exclusive_mask
@@ -847,6 +899,13 @@ fn log_changes(
);
}
if old_hot.listener_synlimit != new_hot.listener_synlimit {
info!(
"config reload: server.listeners SYN limiter updated ({} listeners)",
new_hot.listener_synlimit.len()
);
}
if old_hot.desync_all_full != new_hot.desync_all_full {
info!(
"config reload: desync_all_full: {} → {}",
@@ -1487,6 +1546,48 @@ pub fn spawn_config_watcher(
(config_rx, log_rx)
}
// ── Change classification ─────────────────────────────────────────────────────
/// Which top-level config sections changed and whether any require a restart.
#[derive(Debug, Default, Clone, serde::Serialize)]
pub struct ChangeClassification {
pub changed: Vec<String>,
pub restart_required: bool,
}
/// Classify old->new using Telemt's OWN reload rule: overlay the hot fields and
/// see if anything non-hot remains different. This guarantees `restart_required`
/// matches actual runtime behavior and never drifts as new fields are added.
pub fn classify_config_changes(old: &ProxyConfig, new: &ProxyConfig) -> ChangeClassification {
let applied = overlay_hot_fields(old, new);
let restart_required = !config_equal(&applied, new);
ChangeClassification {
changed: changed_sections(old, new),
restart_required,
}
}
/// Top-level config sections whose canonical serialized form differs between
/// old and new. Uses the same serialize+canonicalize path as `config_equal`.
fn changed_sections(old: &ProxyConfig, new: &ProxyConfig) -> Vec<String> {
let mut lhs = serde_json::to_value(old).unwrap_or(serde_json::Value::Null);
let mut rhs = serde_json::to_value(new).unwrap_or(serde_json::Value::Null);
canonicalize_json(&mut lhs);
canonicalize_json(&mut rhs);
let mut out = Vec::new();
if let (Some(lo), Some(ro)) = (lhs.as_object(), rhs.as_object()) {
let mut keys: std::collections::BTreeSet<&String> = lo.keys().collect();
keys.extend(ro.keys());
for key in keys {
if lo.get(key) != ro.get(key) {
out.push(key.clone());
}
}
}
out
}
#[cfg(test)]
mod tests {
use super::*;
@@ -1659,6 +1760,41 @@ mod tests {
let _ = std::fs::remove_file(path);
}
#[test]
fn classify_sni_change_requires_restart() {
// censorship.* is not in overlay_hot_fields -> restart.
let old = ProxyConfig::default();
let mut new = ProxyConfig::default();
new.censorship.tls_domain = "front.example".to_string();
let class = classify_config_changes(&old, &new);
assert!(class.restart_required);
assert!(class.changed.iter().any(|c| c == "censorship"));
}
#[test]
fn classify_dns_overrides_change_is_hot() {
// network.dns_overrides IS in overlay_hot_fields -> no restart.
let old = ProxyConfig::default();
let mut new = ProxyConfig::default();
new.network.dns_overrides.push("1.1.1.1".to_string());
let class = classify_config_changes(&old, &new);
assert!(!class.restart_required);
assert!(class.changed.iter().any(|c| c == "network"));
}
#[test]
fn classify_timeouts_change_requires_restart() {
// timeouts.* is NOT in overlay_hot_fields -> restart.
let old = ProxyConfig::default();
let mut new = ProxyConfig::default();
new.timeouts.client_handshake = old.timeouts.client_handshake + 1;
let class = classify_config_changes(&old, &new);
assert!(class.restart_required);
}
#[test]
fn reload_recovers_after_parse_error_on_next_attempt() {
let initial_tag = "cccccccccccccccccccccccccccccccc";
+202 -5
View File
@@ -299,6 +299,8 @@ const SERVER_CONFIG_KEYS: &[&str] = &[
"listen_unix_sock",
"listen_unix_sock_perm",
"listen_tcp",
"client_mss",
"client_mss_bulk",
"proxy_protocol",
"proxy_protocol_header_timeout_ms",
"proxy_protocol_trusted_cidrs",
@@ -344,6 +346,11 @@ const CONNTRACK_CONTROL_CONFIG_KEYS: &[&str] = &[
const LISTENER_CONFIG_KEYS: &[&str] = &[
"ip",
"port",
"client_mss",
"synlimit",
"synlimit_seconds",
"synlimit_hitcount",
"synlimit_burst",
"announce",
"announce_ip",
"proxy_protocol",
@@ -370,6 +377,7 @@ const CENSORSHIP_CONFIG_KEYS: &[&str] = &[
"tls_fetch_scope",
"tls_fetch",
"mask",
"mask_dynamic",
"mask_host",
"mask_port",
"exclusive_mask",
@@ -1933,6 +1941,42 @@ impl ProxyConfig {
));
}
if config.server.listen_backlog == 0 || config.server.listen_backlog > i32::MAX as u32 {
return Err(ProxyError::Config(format!(
"server.listen_backlog must be within [1, {}]",
i32::MAX
)));
}
config
.server
.client_mss_value()
.map_err(|error| ProxyError::Config(format!("server.client_mss {error}")))?;
for (idx, listener) in config.server.listeners.iter().enumerate() {
if listener.client_mss.is_some() {
listener
.effective_client_mss(&config.server)
.map_err(|error| {
ProxyError::Config(format!("server.listeners[{idx}].client_mss {error}"))
})?;
}
if listener.synlimit_seconds == 0 {
return Err(ProxyError::Config(format!(
"server.listeners[{idx}].synlimit_seconds must be > 0"
)));
}
if listener.synlimit_hitcount == 0 {
return Err(ProxyError::Config(format!(
"server.listeners[{idx}].synlimit_hitcount must be > 0"
)));
}
if listener.synlimit_burst == 0 {
return Err(ProxyError::Config(format!(
"server.listeners[{idx}].synlimit_burst must be > 0"
)));
}
}
if config.server.accept_permit_timeout_ms > 60_000 {
return Err(ProxyError::Config(
"server.accept_permit_timeout_ms must be within [0, 60000]".to_string(),
@@ -2031,11 +2075,6 @@ impl ProxyConfig {
*mask_host = normalize_mask_host_to_ascii(mask_host, "censorship.mask_host")?;
}
// Default mask_host to tls_domain if not set and no unix socket configured.
if config.censorship.mask_host.is_none() && config.censorship.mask_unix_sock.is_none() {
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
}
for (domain, target) in &config.censorship.exclusive_mask {
if !is_valid_tls_domain_name(domain) {
return Err(ProxyError::Config(format!(
@@ -2173,6 +2212,11 @@ impl ProxyConfig {
config.server.listeners.push(ListenerConfig {
ip: ipv4,
port: Some(config.server.port),
client_mss: None,
synlimit: SynLimitMode::default(),
synlimit_seconds: default_synlimit_seconds(),
synlimit_hitcount: default_synlimit_hitcount(),
synlimit_burst: default_synlimit_burst(),
announce: None,
announce_ip: None,
proxy_protocol: None,
@@ -2185,6 +2229,11 @@ impl ProxyConfig {
config.server.listeners.push(ListenerConfig {
ip: ipv6,
port: Some(config.server.port),
client_mss: None,
synlimit: SynLimitMode::default(),
synlimit_seconds: default_synlimit_seconds(),
synlimit_hitcount: default_synlimit_hitcount(),
synlimit_burst: default_synlimit_burst(),
announce: None,
announce_ip: None,
proxy_protocol: None,
@@ -2460,6 +2509,7 @@ mod tests {
assert_eq!(cfg.general.update_every, default_update_every());
assert_eq!(cfg.server.listen_addr_ipv4, default_listen_addr_ipv4());
assert_eq!(cfg.server.listen_addr_ipv6, default_listen_addr_ipv6_opt());
assert_eq!(cfg.server.client_mss_value(), Ok(None));
assert_eq!(
cfg.server.proxy_protocol_trusted_cidrs,
default_proxy_protocol_trusted_cidrs()
@@ -3787,6 +3837,153 @@ mod tests {
let _ = std::fs::remove_file(path);
}
#[test]
fn client_mss_presets_and_listener_override_are_resolved() {
let toml = r#"
[server]
client_mss = "tspu"
[[server.listeners]]
ip = "127.0.0.1"
port = 1443
[[server.listeners]]
ip = "127.0.0.2"
port = 1444
client_mss = "2in8"
[[server.listeners]]
ip = "127.0.0.3"
port = 1445
client_mss = ""
[[server.listeners]]
ip = "127.0.0.4"
port = 1446
client_mss = "extreme-low"
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_client_mss_valid_test.toml");
std::fs::write(&path, toml).unwrap();
let cfg = ProxyConfig::load(&path).unwrap();
assert_eq!(cfg.server.client_mss_value(), Ok(Some(92)));
assert_eq!(
cfg.server.listeners[0].effective_client_mss(&cfg.server),
Ok(Some(92))
);
assert_eq!(
cfg.server.listeners[1].effective_client_mss(&cfg.server),
Ok(Some(256))
);
assert_eq!(
cfg.server.listeners[2].effective_client_mss(&cfg.server),
Ok(None)
);
assert_eq!(
cfg.server.listeners[3].effective_client_mss(&cfg.server),
Ok(Some(88))
);
let _ = std::fs::remove_file(path);
}
#[test]
fn client_mss_custom_value_is_accepted() {
let toml = r#"
[server]
client_mss = "4096"
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_client_mss_custom_valid_test.toml");
std::fs::write(&path, toml).unwrap();
let cfg = ProxyConfig::load(&path).unwrap();
assert_eq!(cfg.server.client_mss_value(), Ok(Some(4096)));
let _ = std::fs::remove_file(path);
}
#[test]
fn client_mss_out_of_range_is_rejected() {
for value in ["87", "4097"] {
let toml = format!(
r#"
[server]
client_mss = "{value}"
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#
);
let dir = std::env::temp_dir();
let path = dir.join(format!("telemt_client_mss_out_of_range_{value}_test.toml"));
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("server.client_mss custom value must be within [88, 4096]"));
let _ = std::fs::remove_file(path);
}
}
#[test]
fn client_mss_unquoted_number_is_rejected() {
let toml = r#"
[server]
client_mss = 256
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_client_mss_unquoted_number_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("client_mss"));
let _ = std::fs::remove_file(path);
}
#[test]
fn listener_client_mss_invalid_preset_is_rejected() {
let toml = r#"
[[server.listeners]]
ip = "127.0.0.1"
port = 1443
client_mss = "tiny"
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_listener_client_mss_invalid_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("server.listeners[0].client_mss"));
assert!(err.contains("must be \"\", extreme-low, tspu, 2in8"));
let _ = std::fs::remove_file(path);
}
#[test]
fn api_runtime_edge_cache_ttl_out_of_range_is_rejected() {
let toml = r#"
@@ -95,6 +95,44 @@ max_client_frame = 16777217
remove_temp_config(&path);
}
#[test]
fn load_rejects_listen_backlog_above_i32_upper_bound() {
let path = write_temp_config(
r#"
[server]
listen_backlog = 2147483648
"#,
);
let err = ProxyConfig::load(&path).expect_err("listen_backlog above socket cap must fail");
let msg = err.to_string();
assert!(
msg.contains("server.listen_backlog must be within [1, 2147483647]"),
"error must explain listen_backlog hard cap, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_zero_listen_backlog() {
let path = write_temp_config(
r#"
[server]
listen_backlog = 0
"#,
);
let err = ProxyConfig::load(&path).expect_err("zero listen_backlog must fail");
let msg = err.to_string();
assert!(
msg.contains("server.listen_backlog must be within [1, 2147483647]"),
"error must explain listen_backlog lower bound, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_memory_limits_at_hard_upper_bounds() {
let path = write_temp_config(
+172 -1
View File
@@ -429,7 +429,7 @@ pub struct GeneralConfig {
pub ad_tag: Option<String>,
/// Public IP override for middle-proxy NAT environments.
/// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr".
/// When set, this IP is used in ME key derivation and local address translation.
#[serde(default)]
pub middle_proxy_nat_ip: Option<IpAddr>,
@@ -1369,6 +1369,77 @@ impl ConntrackPressureProfile {
}
}
/// Per-listener SYN limiter mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SynLimitMode {
/// Disable SYN limiting for this listener.
#[default]
Off,
/// Use iptables/ip6tables filter rules with the hashlimit match.
Iptables,
/// Use nftables rules with per-source token-bucket meters.
Nftables,
}
impl Serialize for SynLimitMode {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Self::Off => serializer.serialize_bool(false),
Self::Iptables => serializer.serialize_str("iptables"),
Self::Nftables => serializer.serialize_str("nftables"),
}
}
}
impl<'de> Deserialize<'de> for SynLimitMode {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct SynLimitModeVisitor;
impl<'de> serde::de::Visitor<'de> for SynLimitModeVisitor {
type Value = SynLimitMode;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("false, iptables, or nftables")
}
fn visit_bool<E>(self, value: bool) -> std::result::Result<Self::Value, E>
where
E: serde::de::Error,
{
if value {
Err(E::custom(
"synlimit=true is ambiguous; use \"iptables\" or \"nftables\"",
))
} else {
Ok(SynLimitMode::Off)
}
}
fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
where
E: serde::de::Error,
{
match value.trim().to_ascii_lowercase().as_str() {
"false" | "off" | "disabled" | "none" => Ok(SynLimitMode::Off),
"iptables" => Ok(SynLimitMode::Iptables),
"nftables" => Ok(SynLimitMode::Nftables),
_ => Err(E::custom(
"synlimit must be false, \"iptables\", or \"nftables\"",
)),
}
}
}
deserializer.deserialize_any(SynLimitModeVisitor)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConntrackControlConfig {
/// Enables runtime conntrack-control worker for pressure mitigation.
@@ -1451,6 +1522,20 @@ pub struct ServerConfig {
#[serde(default)]
pub listen_tcp: Option<bool>,
/// Client-facing TCP MSS preset or custom value for all TCP listeners.
/// Empty string or omitted value keeps the kernel default.
#[serde(default)]
pub client_mss: Option<String>,
/// Client-facing TCP MSS to switch to AFTER the TLS handshake (ServerHello)
/// is sent. Lets `client_mss` fragment ONLY the handshake (the DPI-inspected
/// part) while the bulk transfer uses normal-size packets — avoids the ~10x
/// packets-per-second blowup that triggers anti-DDoS abuse blocks on
/// pps-policing hosts. Empty/omitted = keep the handshake MSS for the whole
/// connection (previous behavior). Same preset/int grammar as `client_mss`.
#[serde(default)]
pub client_mss_bulk: Option<String>,
/// Accept HAProxy PROXY protocol headers on incoming connections.
/// When enabled, real client IPs are extracted from PROXY v1/v2 headers.
#[serde(default)]
@@ -1517,6 +1602,8 @@ impl Default for ServerConfig {
listen_unix_sock: None,
listen_unix_sock_perm: None,
listen_tcp: None,
client_mss: None,
client_mss_bulk: None,
proxy_protocol: false,
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
proxy_protocol_trusted_cidrs: default_proxy_protocol_trusted_cidrs(),
@@ -1720,6 +1807,10 @@ pub struct AntiCensorshipConfig {
#[serde(default = "default_true")]
pub mask: bool,
/// Use the ClientHello SNI as the mask TCP target for configured TLS domains.
#[serde(default = "default_true")]
pub mask_dynamic: bool,
#[serde(default)]
pub mask_host: Option<String>,
@@ -1855,6 +1946,7 @@ impl Default for AntiCensorshipConfig {
tls_fetch_scope: default_tls_fetch_scope(),
tls_fetch: TlsFetchConfig::default(),
mask: default_true(),
mask_dynamic: default_true(),
mask_host: None,
mask_port: default_mask_port(),
exclusive_mask: HashMap::new(),
@@ -2087,6 +2179,22 @@ pub struct ListenerConfig {
/// Per-listener TCP port. If omitted, falls back to legacy `server.port`.
#[serde(default)]
pub port: Option<u16>,
/// Per-listener client-facing TCP MSS preset or custom value.
/// Empty string disables MSS shaping for this listener.
#[serde(default)]
pub client_mss: Option<String>,
/// Per-listener SYN limiter mode.
#[serde(default)]
pub synlimit: SynLimitMode,
/// Token-bucket rate interval for the per-listener SYN limiter.
#[serde(default = "default_synlimit_seconds")]
pub synlimit_seconds: u32,
/// Token-bucket rate amount for the per-listener SYN limiter.
#[serde(default = "default_synlimit_hitcount")]
pub synlimit_hitcount: u32,
/// Token-bucket burst size for the per-listener SYN limiter.
#[serde(default = "default_synlimit_burst")]
pub synlimit_burst: u32,
/// IP address or hostname to announce in proxy links.
/// Takes precedence over `announce_ip` if both are set.
#[serde(default)]
@@ -2104,6 +2212,69 @@ pub struct ListenerConfig {
pub reuse_allow: bool,
}
/// Client-facing TCP MSS preset for extreme-low fragmentation profiles.
pub const CLIENT_MSS_EXTREME_LOW: u16 = 88;
/// Client-facing TCP MSS preset matching TSPU-oriented deployments.
pub const CLIENT_MSS_TSPU: u16 = 92;
/// Client-facing TCP MSS preset for 2-in-8 segment shaping.
pub const CLIENT_MSS_2IN8: u16 = 256;
/// Minimum accepted custom client-facing TCP MSS value.
pub const CLIENT_MSS_MIN: u16 = CLIENT_MSS_EXTREME_LOW;
/// Maximum accepted custom client-facing TCP MSS value.
pub const CLIENT_MSS_MAX: u16 = 4096;
impl ServerConfig {
/// Resolves the global client-facing TCP MSS setting.
pub fn client_mss_value(&self) -> std::result::Result<Option<u16>, String> {
parse_client_mss(self.client_mss.as_deref())
}
/// Resolves the post-handshake (bulk transfer) client MSS, if configured.
pub fn client_mss_bulk_value(&self) -> std::result::Result<Option<u16>, String> {
parse_client_mss(self.client_mss_bulk.as_deref())
}
}
impl ListenerConfig {
/// Resolves the listener MSS override, falling back to the global server value.
pub fn effective_client_mss(
&self,
server: &ServerConfig,
) -> std::result::Result<Option<u16>, String> {
match self.client_mss.as_deref() {
Some(value) => parse_client_mss(Some(value)),
None => server.client_mss_value(),
}
}
}
fn parse_client_mss(raw: Option<&str>) -> std::result::Result<Option<u16>, String> {
let Some(raw) = raw else {
return Ok(None);
};
let value = raw.trim();
if value.is_empty() {
return Ok(None);
}
match value.to_ascii_lowercase().as_str() {
"extreme-low" => return Ok(Some(CLIENT_MSS_EXTREME_LOW)),
"tspu" => return Ok(Some(CLIENT_MSS_TSPU)),
"2in8" => return Ok(Some(CLIENT_MSS_2IN8)),
_ => {}
}
let parsed = value
.parse::<u16>()
.map_err(|_| "must be \"\", extreme-low, tspu, 2in8, or a decimal value".to_string())?;
if !(CLIENT_MSS_MIN..=CLIENT_MSS_MAX).contains(&parsed) {
return Err(format!(
"custom value must be within [{CLIENT_MSS_MIN}, {CLIENT_MSS_MAX}]"
));
}
Ok(Some(parsed))
}
// ============= ShowLink =============
/// Controls which users' proxy links are displayed at startup.
+24
View File
@@ -47,6 +47,10 @@ fn default_link_port(config: &ProxyConfig) -> u16 {
.unwrap_or(config.server.port)
}
fn mss_segment_multiplier(client_mss: u16) -> u16 {
1460u16.div_ceil(client_mss)
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn bind_listeners(
config: &Arc<ProxyConfig>,
@@ -90,10 +94,22 @@ pub(crate) async fn bind_listeners(
warn!(%addr, "Skipping IPv6 listener: IPv6 disabled by [network]");
continue;
}
let client_mss = match listener_conf.effective_client_mss(&config.server) {
Ok(value) => value,
Err(error) => {
warn!(
%addr,
error = %error,
"Invalid listener client MSS after config validation; using kernel default"
);
None
}
};
let options = ListenOptions {
reuse_port: listener_conf.reuse_allow,
ipv6_only: listener_conf.ip.is_ipv6(),
backlog: config.server.listen_backlog,
client_mss,
..Default::default()
};
@@ -101,6 +117,14 @@ pub(crate) async fn bind_listeners(
Ok(socket) => {
let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr);
if let Some(client_mss) = client_mss {
info!(
%addr,
client_mss,
segment_multiplier = mss_segment_multiplier(client_mss),
"Client-facing TCP MSS configured"
);
}
let listener_proxy_protocol = listener_conf
.proxy_protocol
.unwrap_or(config.server.proxy_protocol);
+2
View File
@@ -208,6 +208,8 @@ pub(crate) async fn initialize_me_pool(
me_nat_probe,
None,
config.network.stun_servers.clone(),
config.network.stun_tcp_fallback,
config.network.http_ip_detect_urls.clone(),
config.general.stun_nat_probe_concurrency,
probe.detected_ipv6,
config.timeouts.me_one_retry,
+4
View File
@@ -45,6 +45,7 @@ use crate::stats::beobachten::BeobachtenStore;
use crate::stats::telemetry::TelemetryPolicy;
use crate::stats::{ReplayChecker, Stats};
use crate::stream::BufferPool;
use crate::synlimit_control;
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::MePool;
use helpers::{
@@ -909,6 +910,9 @@ async fn run_telemt_core(
// On Unix, caller supplies privilege drop after bind (may require root for port < 1024).
drop_after_bind();
synlimit_control::reconcile_synlimit_rules(&config).await;
synlimit_control::spawn_synlimit_controller(config_rx.clone());
runtime_tasks::apply_runtime_log_filter(
has_rust_log,
&effective_log_level,
+5
View File
@@ -19,6 +19,7 @@ use tokio::signal::unix::{SignalKind, signal};
use tracing::{info, warn};
use crate::stats::Stats;
use crate::synlimit_control;
use crate::transport::middle_proxy::MePool;
use super::helpers::{format_uptime, unit_label};
@@ -102,6 +103,10 @@ async fn perform_shutdown(
let uptime_secs = process_started_at.elapsed().as_secs();
info!("Uptime: {}", format_uptime(uptime_secs));
if let Err(error) = synlimit_control::clear_synlimit_rules_all_backends().await {
warn!(error = %error, "Failed to clear SYN limiter rules during shutdown");
}
// Graceful ME pool shutdown
if let Some(pool) = &me_pool {
match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all())
+1
View File
@@ -30,6 +30,7 @@ mod service;
mod startup;
mod stats;
mod stream;
mod synlimit_control;
mod tls_front;
mod transport;
mod util;
+66 -1
View File
@@ -381,11 +381,32 @@ async fn render_tls_front_profile_health(
"# HELP telemt_tls_front_profile_info TLS front profile source and feature flags per configured domain"
);
let _ = writeln!(out, "# TYPE telemt_tls_front_profile_info gauge");
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_quality_info TLS front profile quality and key-share group per configured domain"
);
let _ = writeln!(out, "# TYPE telemt_tls_front_profile_quality_info gauge");
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_age_seconds Age of cached TLS front profile data per configured domain"
);
let _ = writeln!(out, "# TYPE telemt_tls_front_profile_age_seconds gauge");
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_server_hello_bytes TLS front cached ServerHello record body bytes per configured domain"
);
let _ = writeln!(
out,
"# TYPE telemt_tls_front_profile_server_hello_bytes gauge"
);
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_server_hello_extensions TLS front cached visible ServerHello extension count per configured domain"
);
let _ = writeln!(
out,
"# TYPE telemt_tls_front_profile_server_hello_extensions gauge"
);
let _ = writeln!(
out,
"# HELP telemt_tls_front_profile_app_data_records TLS front cached app-data record count per configured domain"
@@ -420,11 +441,26 @@ async fn render_tls_front_profile_health(
"telemt_tls_front_profile_info{{domain=\"{}\",source=\"{}\",is_default=\"{}\",has_cert_info=\"{}\",has_cert_payload=\"{}\"}} 1",
domain, item.source, item.is_default, item.has_cert_info, item.has_cert_payload
);
let _ = writeln!(
out,
"telemt_tls_front_profile_quality_info{{domain=\"{}\",quality=\"{}\",key_share_group=\"{}\"}} 1",
domain, item.quality, item.key_share_group
);
let _ = writeln!(
out,
"telemt_tls_front_profile_age_seconds{{domain=\"{}\"}} {}",
domain, item.age_seconds
);
let _ = writeln!(
out,
"telemt_tls_front_profile_server_hello_bytes{{domain=\"{}\"}} {}",
domain, item.server_hello_record_len
);
let _ = writeln!(
out,
"telemt_tls_front_profile_server_hello_extensions{{domain=\"{}\"}} {}",
domain, item.server_hello_extensions
);
let _ = writeln!(
out,
"telemt_tls_front_profile_app_data_records{{domain=\"{}\"}} {}",
@@ -3901,7 +3937,20 @@ mod tests {
session_id: Vec::new(),
cipher_suite: [0x13, 0x01],
compression: 0,
extensions: Vec::new(),
extensions: {
let mut key_share = vec![0x00, 0x1d, 0x00, 0x20];
key_share.resize(36, 0x42);
vec![
crate::tls_front::types::TlsExtension {
ext_type: 0x002b,
data: vec![0x03, 0x04],
},
crate::tls_front::types::TlsExtension {
ext_type: 0x0033,
data: key_share,
},
]
},
},
cert_info: None,
cert_payload: Some(TlsCertPayload {
@@ -3915,6 +3964,7 @@ mod tests {
app_data_record_sizes: vec![1024, 512],
ticket_record_sizes: vec![69],
source: TlsProfileSource::Merged,
..TlsBehaviorProfile::default()
},
fetched_at: SystemTime::now(),
domain: "primary.example".to_string(),
@@ -3933,6 +3983,18 @@ mod tests {
assert!(
output.contains("telemt_tls_front_profile_info{domain=\"fallback.example\",source=\"default\",is_default=\"true\",has_cert_info=\"false\",has_cert_payload=\"false\"} 1")
);
assert!(
output.contains("telemt_tls_front_profile_quality_info{domain=\"primary.example\",quality=\"raw_strict\",key_share_group=\"x25519\"} 1")
);
assert!(
output.contains("telemt_tls_front_profile_quality_info{domain=\"fallback.example\",quality=\"fallback\",key_share_group=\"none\"} 1")
);
assert!(output.contains(
"telemt_tls_front_profile_server_hello_bytes{domain=\"primary.example\"} 90"
));
assert!(output.contains(
"telemt_tls_front_profile_server_hello_extensions{domain=\"primary.example\"} 2"
));
assert!(
output.contains(
"telemt_tls_front_profile_app_data_records{domain=\"primary.example\"} 2"
@@ -4045,7 +4107,10 @@ mod tests {
);
assert!(output.contains("# TYPE telemt_tls_front_profile_domains gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_info gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_quality_info gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_age_seconds gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_server_hello_bytes gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_server_hello_extensions gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_app_data_records gauge"));
assert!(output.contains("# TYPE telemt_tls_front_profile_ticket_records gauge"));
assert!(
+33 -7
View File
@@ -12,7 +12,7 @@ use tracing::{debug, info, warn};
use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType};
use crate::error::Result;
use crate::network::stun::{
DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind,
DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind_and_tcp_fallback,
};
use crate::transport::UpstreamManager;
@@ -58,6 +58,7 @@ impl NetworkDecision {
}
const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5);
const STUN_BATCH_TCP_FALLBACK_TIMEOUT: Duration = Duration::from_secs(12);
pub async fn run_probe(
config: &NetworkConfig,
@@ -81,8 +82,14 @@ pub async fn run_probe(
warn!("STUN probe is enabled but network.stun_servers is empty");
DualStunResult::default()
} else {
probe_stun_servers_parallel(&servers, stun_nat_probe_concurrency.max(1), None, None)
.await
probe_stun_servers_parallel(
&servers,
stun_nat_probe_concurrency.max(1),
None,
None,
config.stun_tcp_fallback,
)
.await
}
} else if nat_probe {
info!("STUN probe is disabled by network.stun_use=false");
@@ -163,6 +170,7 @@ pub async fn run_probe(
stun_nat_probe_concurrency.max(1),
bind_v4,
bind_v6,
config.stun_tcp_fallback,
)
.await;
if let Some(reflected) = direct_stun_res.v4.map(|r| r.reflected_addr) {
@@ -234,7 +242,7 @@ pub async fn run_probe(
Ok(probe)
}
async fn detect_public_ipv4_http(urls: &[String]) -> Option<Ipv4Addr> {
pub(crate) async fn detect_public_ipv4_http(urls: &[String]) -> Option<Ipv4Addr> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(3))
.build()
@@ -277,6 +285,7 @@ async fn probe_stun_servers_parallel(
concurrency: usize,
bind_v4: Option<IpAddr>,
bind_v6: Option<IpAddr>,
tcp_fallback: bool,
) -> DualStunResult {
let mut join_set = JoinSet::new();
let mut next_idx = 0usize;
@@ -288,9 +297,26 @@ async fn probe_stun_servers_parallel(
let stun_addr = servers[next_idx].clone();
next_idx += 1;
join_set.spawn(async move {
let res = timeout(STUN_BATCH_TIMEOUT, async {
let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?;
let v6 = stun_probe_family_with_bind(&stun_addr, IpFamily::V6, bind_v6).await?;
let batch_timeout = if tcp_fallback {
STUN_BATCH_TCP_FALLBACK_TIMEOUT
} else {
STUN_BATCH_TIMEOUT
};
let res = timeout(batch_timeout, async {
let v4 = stun_probe_family_with_bind_and_tcp_fallback(
&stun_addr,
IpFamily::V4,
bind_v4,
tcp_fallback,
)
.await?;
let v6 = stun_probe_family_with_bind_and_tcp_fallback(
&stun_addr,
IpFamily::V6,
bind_v6,
tcp_fallback,
)
.await?;
Ok::<DualStunResult, crate::error::ProxyError>(DualStunResult { v4, v6 })
})
.await;
+241 -41
View File
@@ -4,7 +4,8 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::OnceLock;
use tokio::net::{UdpSocket, lookup_host};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpSocket, UdpSocket, lookup_host};
use tokio::time::{Duration, sleep, timeout};
use crate::crypto::SecureRandom;
@@ -36,9 +37,16 @@ pub struct DualStunResult {
}
pub async fn stun_probe_dual(stun_addr: &str) -> Result<DualStunResult> {
stun_probe_dual_with_tcp_fallback(stun_addr, false).await
}
pub async fn stun_probe_dual_with_tcp_fallback(
stun_addr: &str,
tcp_fallback: bool,
) -> Result<DualStunResult> {
let (v4, v6) = tokio::join!(
stun_probe_family(stun_addr, IpFamily::V4),
stun_probe_family(stun_addr, IpFamily::V6),
stun_probe_family_with_tcp_fallback(stun_addr, IpFamily::V4, tcp_fallback),
stun_probe_family_with_tcp_fallback(stun_addr, IpFamily::V6, tcp_fallback),
);
Ok(DualStunResult { v4: v4?, v6: v6? })
@@ -48,13 +56,44 @@ pub async fn stun_probe_family(
stun_addr: &str,
family: IpFamily,
) -> Result<Option<StunProbeResult>> {
stun_probe_family_with_bind(stun_addr, family, None).await
stun_probe_family_with_tcp_fallback(stun_addr, family, false).await
}
pub async fn stun_probe_family_with_tcp_fallback(
stun_addr: &str,
family: IpFamily,
tcp_fallback: bool,
) -> Result<Option<StunProbeResult>> {
stun_probe_family_with_bind_and_tcp_fallback(stun_addr, family, None, tcp_fallback).await
}
pub async fn stun_probe_family_with_bind(
stun_addr: &str,
family: IpFamily,
bind_ip: Option<IpAddr>,
) -> Result<Option<StunProbeResult>> {
stun_probe_family_with_bind_and_tcp_fallback(stun_addr, family, bind_ip, false).await
}
pub async fn stun_probe_family_with_bind_and_tcp_fallback(
stun_addr: &str,
family: IpFamily,
bind_ip: Option<IpAddr>,
tcp_fallback: bool,
) -> Result<Option<StunProbeResult>> {
let udp_attempts = if tcp_fallback { 1 } else { 3 };
let udp_result = stun_probe_family_udp(stun_addr, family, bind_ip, udp_attempts).await?;
if udp_result.is_some() || !tcp_fallback {
return Ok(udp_result);
}
stun_probe_family_tcp(stun_addr, family, bind_ip).await
}
async fn stun_probe_family_udp(
stun_addr: &str,
family: IpFamily,
bind_ip: Option<IpAddr>,
max_attempts: u8,
) -> Result<Option<StunProbeResult>> {
let bind_addr = match (family, bind_ip) {
(IpFamily::V4, Some(IpAddr::V4(ip))) => SocketAddr::new(IpAddr::V4(ip), 0),
@@ -94,12 +133,7 @@ pub async fn stun_probe_family_with_bind(
return Ok(None);
}
let mut req = [0u8; 20];
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request
req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie
stun_rng().fill(&mut req[8..20]); // transaction ID
let req = build_binding_request();
let mut buf = [0u8; 256];
let mut attempt = 0;
let mut backoff = Duration::from_secs(1);
@@ -115,7 +149,7 @@ pub async fn stun_probe_family_with_bind(
Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN recv failed: {e}"))),
Err(_) => {
attempt += 1;
if attempt >= 3 {
if attempt >= max_attempts {
return Ok(None);
}
sleep(backoff).await;
@@ -128,19 +162,139 @@ pub async fn stun_probe_family_with_bind(
return Ok(None);
}
let magic = 0x2112A442u32.to_be_bytes();
let txid = &req[8..20];
let mut idx = 20;
while idx + 4 <= n {
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap());
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize;
idx += 4;
if idx + alen > n {
break;
}
if let Some(reflected_addr) = parse_reflected_addr(&buf[..n], txid) {
let local_addr = socket
.local_addr()
.map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?;
return Ok(Some(StunProbeResult {
local_addr,
reflected_addr,
family,
}));
}
}
match atype {
0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => {
Ok(None)
}
async fn stun_probe_family_tcp(
stun_addr: &str,
family: IpFamily,
bind_ip: Option<IpAddr>,
) -> Result<Option<StunProbeResult>> {
let target_addr = match resolve_stun_addr(stun_addr, family).await? {
Some(addr) => addr,
None => return Ok(None),
};
let socket = match family {
IpFamily::V4 => TcpSocket::new_v4(),
IpFamily::V6 => TcpSocket::new_v6(),
}
.map_err(|e| ProxyError::Proxy(format!("STUN TCP socket failed: {e}")))?;
match (family, bind_ip) {
(IpFamily::V4, Some(IpAddr::V4(ip))) => {
if socket.bind(SocketAddr::new(IpAddr::V4(ip), 0)).is_err() {
return Ok(None);
}
}
(IpFamily::V6, Some(IpAddr::V6(ip))) => {
if socket.bind(SocketAddr::new(IpAddr::V6(ip), 0)).is_err() {
return Ok(None);
}
}
(IpFamily::V4, Some(IpAddr::V6(_))) | (IpFamily::V6, Some(IpAddr::V4(_))) => {
return Ok(None);
}
(_, None) => {}
}
let connect_res = timeout(Duration::from_secs(3), socket.connect(target_addr)).await;
let mut stream = match connect_res {
Ok(Ok(stream)) => stream,
Ok(Err(e))
if family == IpFamily::V6
&& matches!(
e.kind(),
std::io::ErrorKind::NetworkUnreachable
| std::io::ErrorKind::HostUnreachable
| std::io::ErrorKind::Unsupported
| std::io::ErrorKind::NetworkDown
) =>
{
return Ok(None);
}
Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN TCP connect failed: {e}"))),
Err(_) => return Ok(None),
};
let req = build_binding_request();
timeout(Duration::from_secs(3), stream.write_all(&req))
.await
.map_err(|_| ProxyError::Proxy("STUN TCP send timeout".to_string()))?
.map_err(|e| ProxyError::Proxy(format!("STUN TCP send failed: {e}")))?;
let mut header = [0u8; 20];
timeout(Duration::from_secs(3), stream.read_exact(&mut header))
.await
.map_err(|_| ProxyError::Proxy("STUN TCP header timeout".to_string()))?
.map_err(|e| ProxyError::Proxy(format!("STUN TCP header read failed: {e}")))?;
let body_len = u16::from_be_bytes([header[2], header[3]]) as usize;
if body_len > 236 {
return Ok(None);
}
let mut buf = [0u8; 256];
buf[..20].copy_from_slice(&header);
if body_len > 0 {
timeout(
Duration::from_secs(3),
stream.read_exact(&mut buf[20..20 + body_len]),
)
.await
.map_err(|_| ProxyError::Proxy("STUN TCP body timeout".to_string()))?
.map_err(|e| ProxyError::Proxy(format!("STUN TCP body read failed: {e}")))?;
}
let txid = &req[8..20];
let Some(reflected_addr) = parse_reflected_addr(&buf[..20 + body_len], txid) else {
return Ok(None);
};
let local_addr = stream
.local_addr()
.map_err(|e| ProxyError::Proxy(format!("STUN TCP local_addr failed: {e}")))?;
Ok(Some(StunProbeResult {
local_addr,
reflected_addr,
family,
}))
}
fn build_binding_request() -> [u8; 20] {
let mut req = [0u8; 20];
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes());
req[2..4].copy_from_slice(&0u16.to_be_bytes());
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes());
stun_rng().fill(&mut req[8..20]);
req
}
fn parse_reflected_addr(buf: &[u8], txid: &[u8]) -> Option<SocketAddr> {
if buf.len() < 20 {
return None;
}
let magic = 0x2112A442u32.to_be_bytes();
let mut idx = 20;
while idx + 4 <= buf.len() {
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().ok()?);
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().ok()?) as usize;
idx += 4;
if idx + alen > buf.len() {
break;
}
match atype {
0x0020 | 0x0001 => {
if alen < 8 {
break;
}
@@ -157,7 +311,6 @@ pub async fn stun_probe_family_with_bind(
let raw_ip = &buf[idx + 4..idx + 4 + len_check];
let mut port = u16::from_be_bytes(port_bytes);
let reflected_ip = if atype == 0x0020 {
port ^= ((magic[0] as u16) << 8) | magic[1] as u16;
match family_byte {
@@ -172,7 +325,9 @@ pub async fn stun_probe_family_with_bind(
}
0x02 => {
let mut ip = [0u8; 16];
let xor_key = [magic.as_slice(), txid].concat();
let mut xor_key = [0u8; 16];
xor_key[..4].copy_from_slice(&magic);
xor_key[4..].copy_from_slice(txid.get(..12)?);
for (i, b) in raw_ip.iter().enumerate().take(16) {
ip[i] = *b ^ xor_key[i];
}
@@ -185,34 +340,24 @@ pub async fn stun_probe_family_with_bind(
}
} else {
match family_byte {
0x01 => IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3])),
0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).unwrap())),
0x01 => {
IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3]))
}
0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).ok()?)),
_ => {
idx += (alen + 3) & !3;
continue;
}
}
};
let reflected_addr = SocketAddr::new(reflected_ip, port);
let local_addr = socket
.local_addr()
.map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?;
return Ok(Some(StunProbeResult {
local_addr,
reflected_addr,
family,
}));
return Some(SocketAddr::new(reflected_ip, port));
}
_ => {}
}
idx += (alen + 3) & !3;
}
idx += (alen + 3) & !3;
}
Ok(None)
None
}
async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<SocketAddr>> {
@@ -245,3 +390,58 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<S
});
Ok(target)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_reflected_addr_reads_mapped_ipv4() {
let txid = [0u8; 12];
let mut response = [0u8; 32];
response[0..2].copy_from_slice(&0x0101u16.to_be_bytes());
response[2..4].copy_from_slice(&12u16.to_be_bytes());
response[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes());
response[20..22].copy_from_slice(&0x0001u16.to_be_bytes());
response[22..24].copy_from_slice(&8u16.to_be_bytes());
response[25] = 0x01;
response[26..28].copy_from_slice(&443u16.to_be_bytes());
response[28..32].copy_from_slice(&[203, 0, 113, 9]);
let reflected = parse_reflected_addr(&response, &txid).unwrap();
assert_eq!(
reflected,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 443)
);
}
#[test]
fn parse_reflected_addr_reads_xor_mapped_ipv4() {
let txid = [0u8; 12];
let magic = 0x2112A442u32.to_be_bytes();
let port = 443u16;
let ip = [203u8, 0, 113, 9];
let xport = port ^ (((magic[0] as u16) << 8) | magic[1] as u16);
let xip = [
ip[0] ^ magic[0],
ip[1] ^ magic[1],
ip[2] ^ magic[2],
ip[3] ^ magic[3],
];
let mut response = [0u8; 32];
response[0..2].copy_from_slice(&0x0101u16.to_be_bytes());
response[2..4].copy_from_slice(&12u16.to_be_bytes());
response[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes());
response[20..22].copy_from_slice(&0x0020u16.to_be_bytes());
response[22..24].copy_from_slice(&8u16.to_be_bytes());
response[25] = 0x01;
response[26..28].copy_from_slice(&xport.to_be_bytes());
response[28..32].copy_from_slice(&xip);
let reflected = parse_reflected_addr(&response, &txid).unwrap();
assert_eq!(
reflected,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 443)
);
}
}
+20 -15
View File
@@ -5,6 +5,9 @@
use std::net::{IpAddr, Ipv4Addr};
use crate::crypto::SecureRandom;
use crate::protocol::framing::{
secure_version_d_body_len_from_wire_len, secure_version_d_padding_len,
};
use std::sync::LazyLock;
// ============= Telegram Datacenters =============
@@ -236,22 +239,20 @@ pub fn is_valid_secure_payload_len(data_len: usize) -> bool {
}
/// Compute Secure Intermediate payload length from wire length.
/// Secure mode strips up to 3 random tail bytes by truncating to 4-byte boundary.
/// Secure mode cannot distinguish full-word padding from payload, so only the
/// non-aligned tail bytes are stripped.
pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option<usize> {
if wire_len < 4 {
return None;
}
Some(wire_len - (wire_len % 4))
secure_version_d_body_len_from_wire_len(wire_len)
}
/// Generate padding length for Secure Intermediate protocol.
/// Data must be 4-byte aligned; padding is 1..=3 so total is never divisible by 4.
/// Telegram Desktop uses a 4-bit random padding length for VersionD packets.
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
debug_assert!(
is_valid_secure_payload_len(data_len),
"Secure payload must be 4-byte aligned, got {data_len}"
);
rng.range(3) + 1
secure_version_d_padding_len(rng)
}
// ============= Timeouts =============
@@ -424,21 +425,15 @@ mod tests {
}
#[test]
fn secure_padding_never_produces_aligned_total() {
fn secure_padding_matches_tdesktop_range() {
let rng = SecureRandom::new();
for data_len in (0..1000).step_by(4) {
for _ in 0..100 {
let padding = secure_padding_len(data_len, &rng);
assert!(
padding <= 3,
padding <= 15,
"padding out of range: data_len={data_len}, padding={padding}"
);
assert_ne!(
(data_len + padding) % 4,
0,
"invariant violated: data_len={data_len}, padding={padding}, total={}",
data_len + padding
);
}
}
}
@@ -454,6 +449,16 @@ mod tests {
}
}
#[test]
fn secure_wire_len_preserves_full_word_tail() {
let payload_len = 64;
for padding in [4usize, 8, 12] {
let wire_len = payload_len + padding;
let recovered = secure_payload_len_from_wire_len(wire_len);
assert_eq!(recovered, Some(wire_len));
}
}
#[test]
fn secure_wire_len_rejects_too_short_frames() {
assert_eq!(secure_payload_len_from_wire_len(0), None);
+92
View File
@@ -0,0 +1,92 @@
//! Shared MTProto transport framing helpers.
use crate::crypto::SecureRandom;
/// QuickACK marker bit used by Intermediate and Secure Intermediate headers.
pub(crate) const INTERMEDIATE_QUICKACK_FLAG: u32 = 0x8000_0000;
/// Payload length mask used by Intermediate and Secure Intermediate headers.
pub(crate) const INTERMEDIATE_WIRE_LEN_MASK: u32 = 0x7fff_ffff;
/// Maximum random tail length used by Telegram Desktop VersionD packets.
pub(crate) const SECURE_VERSION_D_PADDING_MAX: usize = 15;
/// Parsed Intermediate/Secure Intermediate length header.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct IntermediateHeader {
/// Payload length on the wire, excluding the four-byte header.
pub(crate) wire_len: usize,
/// Whether the QuickACK marker bit was set in the length header.
pub(crate) quickack: bool,
}
/// Parse an Intermediate/Secure Intermediate length header.
pub(crate) fn parse_intermediate_header(header: [u8; 4]) -> IntermediateHeader {
let raw = u32::from_le_bytes(header);
IntermediateHeader {
wire_len: (raw & INTERMEDIATE_WIRE_LEN_MASK) as usize,
quickack: (raw & INTERMEDIATE_QUICKACK_FLAG) != 0,
}
}
/// Encode an Intermediate/Secure Intermediate length header.
pub(crate) fn encode_intermediate_header(wire_len: usize, quickack: bool) -> Option<u32> {
if wire_len > INTERMEDIATE_WIRE_LEN_MASK as usize {
return None;
}
let mut raw = u32::try_from(wire_len).ok()?;
if quickack {
raw |= INTERMEDIATE_QUICKACK_FLAG;
}
Some(raw)
}
/// Recover the VersionD body length visible to MTProto from the encrypted wire length.
pub(crate) fn secure_version_d_body_len_from_wire_len(wire_len: usize) -> Option<usize> {
if wire_len < 4 {
return None;
}
Some(wire_len - (wire_len % 4))
}
/// Generate Telegram Desktop-compatible VersionD random tail length.
pub(crate) fn secure_version_d_padding_len(rng: &SecureRandom) -> usize {
rng.range(SECURE_VERSION_D_PADDING_MAX + 1)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn intermediate_header_roundtrip_preserves_quickack_zero_length() {
let encoded = encode_intermediate_header(0, true).unwrap();
assert_eq!(encoded, INTERMEDIATE_QUICKACK_FLAG);
let parsed = parse_intermediate_header(encoded.to_le_bytes());
assert_eq!(parsed.wire_len, 0);
assert!(parsed.quickack);
}
#[test]
fn intermediate_header_rejects_lengths_above_31_bits() {
assert_eq!(
encode_intermediate_header(INTERMEDIATE_WIRE_LEN_MASK as usize, false),
Some(INTERMEDIATE_WIRE_LEN_MASK)
);
assert_eq!(
encode_intermediate_header(INTERMEDIATE_WIRE_LEN_MASK as usize + 1, false),
None
);
}
#[test]
fn secure_version_d_body_len_strips_only_non_word_tail() {
assert_eq!(secure_version_d_body_len_from_wire_len(3), None);
assert_eq!(secure_version_d_body_len_from_wire_len(8), Some(8));
assert_eq!(secure_version_d_body_len_from_wire_len(11), Some(8));
assert_eq!(secure_version_d_body_len_from_wire_len(12), Some(12));
}
}
+1
View File
@@ -2,6 +2,7 @@
pub mod constants;
pub mod frame;
pub(crate) mod framing;
pub mod obfuscation;
pub mod tls;
pub mod tls_fingerprint;
+359 -12
View File
@@ -1239,6 +1239,18 @@ fn test_gen_fake_x25519_key() {
assert_ne!(key1, key2);
}
#[test]
fn test_gen_fake_x25519mlkem768_server_key_share_shape() {
let rng = crate::crypto::SecureRandom::new();
let key_share = gen_fake_x25519mlkem768_server_key_share(&rng);
assert_eq!(key_share.len(), X25519MLKEM768_SERVER_KEY_SHARE_LEN);
assert!(
key_share.iter().any(|byte| *byte != 0),
"hybrid ServerHello key_share must not collapse to all-zero bytes"
);
}
#[test]
fn test_fake_x25519_key_is_nonzero_and_varies() {
let rng = crate::crypto::SecureRandom::new();
@@ -1325,6 +1337,69 @@ fn server_hello_extension_types(record: &[u8]) -> Vec<u16> {
out
}
fn server_hello_key_share(record: &[u8]) -> Option<(u16, usize)> {
if record.len() < 9 || record[0] != TLS_RECORD_HANDSHAKE || record[5] != 0x02 {
return None;
}
let record_len = u16::from_be_bytes([record[3], record[4]]) as usize;
if record.len() < 5 + record_len {
return None;
}
let hs_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize;
let hs_start = 5;
let hs_end = hs_start + 4 + hs_len;
if hs_end > record.len() {
return None;
}
let mut pos = hs_start + 4 + 2 + 32;
if pos >= hs_end {
return None;
}
let sid_len = record[pos] as usize;
pos += 1 + sid_len;
if pos + 2 + 1 + 2 > hs_end {
return None;
}
pos += 2 + 1;
let ext_len = u16::from_be_bytes([record[pos], record[pos + 1]]) as usize;
pos += 2;
let ext_end = pos + ext_len;
if ext_end > hs_end {
return None;
}
while pos + 4 <= ext_end {
let etype = u16::from_be_bytes([record[pos], record[pos + 1]]);
let elen = u16::from_be_bytes([record[pos + 2], record[pos + 3]]) as usize;
pos += 4;
if pos + elen > ext_end {
return None;
}
if etype == extension_type::KEY_SHARE {
if elen < 4 {
return None;
}
let group = u16::from_be_bytes([record[pos], record[pos + 1]]);
let key_exchange_len = u16::from_be_bytes([record[pos + 2], record[pos + 3]]) as usize;
if 4 + key_exchange_len != elen {
return None;
}
return Some((group, key_exchange_len));
}
pos += elen;
}
None
}
fn test_server_key_share(group: u16, len: usize) -> ServerHelloKeyShare {
ServerHelloKeyShare::new(group, vec![0x42; len])
}
#[test]
fn build_server_hello_never_places_alpn_in_server_hello_extensions() {
let secret = b"alpn_sh_forbidden";
@@ -1372,6 +1447,7 @@ fn emulated_server_hello_never_places_alpn_in_server_hello_extensions() {
app_data_record_sizes: vec![1024],
ticket_record_sizes: Vec::new(),
source: TlsProfileSource::Default,
..TlsBehaviorProfile::default()
},
fetched_at: SystemTime::now(),
domain: "example.com".to_string(),
@@ -1386,6 +1462,10 @@ fn emulated_server_hello_never_places_alpn_in_server_hello_extensions() {
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_SERVER_KEY_SHARE_LEN,
),
&rng,
Some(b"h2".to_vec()),
0,
@@ -1395,14 +1475,21 @@ fn emulated_server_hello_never_places_alpn_in_server_hello_extensions() {
!exts.contains(&0x0010),
"ALPN extension must not appear in emulated ServerHello"
);
assert_eq!(
server_hello_key_share(&response),
Some((
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_SERVER_KEY_SHARE_LEN
))
);
}
#[test]
fn test_tls_extension_builder() {
let key = [0x42u8; 32];
let key = vec![0x42u8; X25519MLKEM768_SERVER_KEY_SHARE_LEN];
let mut builder = TlsExtensionBuilder::new();
builder.add_key_share(&key);
builder.add_key_share(TLS_NAMED_GROUP_X25519MLKEM768, &key);
builder.add_supported_versions(0x0304);
let result = builder.build();
@@ -1415,10 +1502,10 @@ fn test_tls_extension_builder() {
#[test]
fn test_server_hello_builder() {
let session_id = vec![0x01, 0x02, 0x03, 0x04];
let key = [0x55u8; 32];
let key = vec![0x55u8; X25519MLKEM768_SERVER_KEY_SHARE_LEN];
let builder = ServerHelloBuilder::new(session_id.clone())
.with_x25519_key(&key)
.with_key_share(TLS_NAMED_GROUP_X25519MLKEM768, &key)
.with_tls13_version();
let record = builder.build_record();
@@ -1452,6 +1539,41 @@ fn test_build_server_hello_structure() {
let app_start = ccs_start + ccs_len;
assert!(response.len() > app_start + 5);
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
assert_eq!(
server_hello_key_share(&response),
Some((
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_SERVER_KEY_SHARE_LEN
))
);
}
#[test]
fn test_build_server_hello_with_cipher_uses_selected_key_share_group() {
let secret = b"test secret";
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let key_share =
ServerHelloKeyShare::new(TLS_NAMED_GROUP_X25519, vec![0x55u8; X25519_KEY_SHARE_LEN]);
let rng = crate::crypto::SecureRandom::new();
let response = build_server_hello_with_cipher(
secret,
&client_digest,
&session_id,
2048,
&rng,
[0x13, 0x01],
&key_share,
None,
0,
);
assert_eq!(
server_hello_key_share(&response),
Some((TLS_NAMED_GROUP_X25519, X25519_KEY_SHARE_LEN))
);
}
#[test]
@@ -1474,10 +1596,10 @@ fn test_build_server_hello_digest() {
#[test]
fn test_server_hello_extensions_length() {
let session_id = vec![0x01; 32];
let key = [0x55u8; 32];
let key = vec![0x55u8; X25519MLKEM768_SERVER_KEY_SHARE_LEN];
let builder = ServerHelloBuilder::new(session_id)
.with_x25519_key(&key)
.with_key_share(TLS_NAMED_GROUP_X25519MLKEM768, &key)
.with_tls13_version();
let record = builder.build_record();
@@ -1513,6 +1635,39 @@ fn build_client_hello_with_exts(exts: Vec<(u16, Vec<u8>)>, host: &str) -> Vec<u8
build_client_hello_with_ciphers_and_exts(&[[0x13, 0x01]], exts, host)
}
fn client_key_share_extension(entries: &[(u16, usize)]) -> Vec<u8> {
let mut shares = Vec::new();
for (group, key_exchange_len) in entries {
assert!(*key_exchange_len <= u16::MAX as usize);
shares.extend_from_slice(&group.to_be_bytes());
shares.extend_from_slice(&(*key_exchange_len as u16).to_be_bytes());
let start = shares.len();
shares.resize(start + *key_exchange_len, 0x42);
}
assert!(shares.len() <= u16::MAX as usize);
let mut extension = Vec::new();
extension.extend_from_slice(&(shares.len() as u16).to_be_bytes());
extension.extend_from_slice(&shares);
extension
}
fn client_key_share_extension_with_payloads(entries: &[(u16, &[u8])]) -> Vec<u8> {
let mut shares = Vec::new();
for (group, key_exchange) in entries {
assert!(key_exchange.len() <= u16::MAX as usize);
shares.extend_from_slice(&group.to_be_bytes());
shares.extend_from_slice(&(key_exchange.len() as u16).to_be_bytes());
shares.extend_from_slice(key_exchange);
}
assert!(shares.len() <= u16::MAX as usize);
let mut extension = Vec::new();
extension.extend_from_slice(&(shares.len() as u16).to_be_bytes());
extension.extend_from_slice(&shares);
extension
}
fn build_client_hello_with_ciphers_and_exts(
cipher_suites: &[[u8; 2]],
exts: Vec<(u16, Vec<u8>)>,
@@ -1674,7 +1829,7 @@ fn select_server_hello_cipher_suite_keeps_profile_cipher_when_offered() {
);
assert_eq!(
select_server_hello_cipher_suite(&ch, [0x13, 0x03]),
[0x13, 0x03]
Some([0x13, 0x03])
);
}
@@ -1687,30 +1842,222 @@ fn select_server_hello_cipher_suite_ignores_profile_tls12_cipher() {
);
assert_eq!(
select_server_hello_cipher_suite(&ch, [0xc0, 0x2f]),
[0x13, 0x03]
Some([0x13, 0x03])
);
}
#[test]
fn select_server_hello_cipher_suite_rejects_without_offered_tls13_suite() {
let ch = build_client_hello_with_ciphers_and_exts(&[[0xc0, 0x2f]], Vec::new(), "example.com");
assert_eq!(select_server_hello_cipher_suite(&ch, [0x13, 0x01]), None);
}
#[test]
fn select_server_hello_cipher_suite_falls_back_to_offered_tls13_suite() {
let ch = build_client_hello_with_ciphers_and_exts(&[[0x13, 0x03]], Vec::new(), "example.com");
assert_eq!(
select_server_hello_cipher_suite(&ch, [0x13, 0x01]),
[0x13, 0x03]
Some([0x13, 0x03])
);
}
#[test]
fn select_server_hello_cipher_suite_keeps_preferred_for_malformed_clienthello() {
fn select_server_hello_cipher_suite_rejects_malformed_clienthello() {
let mut ch =
build_client_hello_with_ciphers_and_exts(&[[0x13, 0x03]], Vec::new(), "example.com");
ch.truncate(12);
assert_eq!(select_server_hello_cipher_suite(&ch, [0x13, 0x01]), None);
}
#[test]
fn select_server_hello_key_share_group_prefers_hybrid_when_valid_share_is_offered() {
let key_share = client_key_share_extension(&[
(0x0a0a, 1),
(
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_CLIENT_KEY_SHARE_LEN,
),
(TLS_NAMED_GROUP_X25519, X25519_KEY_SHARE_LEN),
]);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
assert_eq!(
select_server_hello_cipher_suite(&ch, [0x13, 0x01]),
[0x13, 0x01]
select_server_hello_key_share_group(&ch),
Some(TLS_NAMED_GROUP_X25519MLKEM768)
);
}
#[test]
fn select_server_hello_key_share_group_prefers_profiled_x25519_when_valid_share_is_offered() {
let key_share = client_key_share_extension(&[
(
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_CLIENT_KEY_SHARE_LEN,
),
(TLS_NAMED_GROUP_X25519, X25519_KEY_SHARE_LEN),
]);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
assert_eq!(
select_server_hello_key_share_group_with_preference(&ch, Some(TLS_NAMED_GROUP_X25519)),
Some(TLS_NAMED_GROUP_X25519)
);
}
#[test]
fn build_x25519mlkem768_server_key_share_accepts_tdesktop_canonical_share() {
let key_share = client_key_share_extension(&[
(
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_CLIENT_KEY_SHARE_LEN,
),
(TLS_NAMED_GROUP_X25519, X25519_KEY_SHARE_LEN),
]);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
let rng = crate::crypto::SecureRandom::new();
let server_key_share = build_x25519mlkem768_server_key_share(&ch, &rng)
.expect("tdesktop-like canonical share must build a ServerHello share");
assert_eq!(server_key_share.len(), X25519MLKEM768_SERVER_KEY_SHARE_LEN);
assert!(
server_key_share[..MLKEM768_SERVER_CIPHERTEXT_LEN]
.iter()
.any(|byte| *byte != 0),
"ML-KEM ciphertext must not be all zero"
);
assert!(
server_key_share[MLKEM768_SERVER_CIPHERTEXT_LEN..]
.iter()
.any(|byte| *byte != 0),
"X25519 server share must not be all zero"
);
}
#[test]
fn build_x25519_server_key_share_accepts_tdesktop_fallback_share() {
let key_share = client_key_share_extension(&[
(
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_CLIENT_KEY_SHARE_LEN,
),
(TLS_NAMED_GROUP_X25519, X25519_KEY_SHARE_LEN),
]);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
let rng = crate::crypto::SecureRandom::new();
let server_key_share = build_x25519_server_key_share(&ch, &rng)
.expect("tdesktop-like X25519 share must build a ServerHello share");
assert_eq!(server_key_share.len(), X25519_KEY_SHARE_LEN);
assert!(
server_key_share.iter().any(|byte| *byte != 0),
"X25519 server share must not be all zero"
);
}
#[test]
fn build_server_hello_key_share_prefers_profiled_x25519() {
let key_share = client_key_share_extension(&[
(
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_CLIENT_KEY_SHARE_LEN,
),
(TLS_NAMED_GROUP_X25519, X25519_KEY_SHARE_LEN),
]);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
let rng = crate::crypto::SecureRandom::new();
let server_key_share = build_server_hello_key_share(&ch, Some(TLS_NAMED_GROUP_X25519), &rng)
.expect("profiled X25519 share must be selected when client offers it");
assert_eq!(server_key_share.group(), TLS_NAMED_GROUP_X25519);
assert_eq!(server_key_share.key_exchange().len(), X25519_KEY_SHARE_LEN);
}
#[test]
fn build_server_hello_key_share_falls_back_from_bad_profiled_x25519_to_hybrid() {
let key_share = client_key_share_extension(&[(
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_CLIENT_KEY_SHARE_LEN,
)]);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
let rng = crate::crypto::SecureRandom::new();
let server_key_share = build_server_hello_key_share(&ch, Some(TLS_NAMED_GROUP_X25519), &rng)
.expect("hybrid share must be selected when profiled X25519 is unavailable");
assert_eq!(server_key_share.group(), TLS_NAMED_GROUP_X25519MLKEM768);
assert_eq!(
server_key_share.key_exchange().len(),
X25519MLKEM768_SERVER_KEY_SHARE_LEN
);
}
#[test]
fn build_x25519mlkem768_server_key_share_rejects_noncanonical_mlkem_key() {
let mut key_exchange = vec![0x42; X25519MLKEM768_CLIENT_KEY_SHARE_LEN];
key_exchange[..3].copy_from_slice(&[0xff, 0xff, 0xff]);
let key_share = client_key_share_extension_with_payloads(&[(
TLS_NAMED_GROUP_X25519MLKEM768,
&key_exchange,
)]);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
let rng = crate::crypto::SecureRandom::new();
assert!(build_x25519mlkem768_server_key_share(&ch, &rng).is_none());
}
#[test]
fn build_x25519mlkem768_server_key_share_rejects_all_zero_x25519_share() {
let mut key_exchange = vec![0x42; X25519MLKEM768_CLIENT_KEY_SHARE_LEN];
key_exchange[MLKEM768_CLIENT_ENCAPSULATION_KEY_LEN..].fill(0);
let key_share = client_key_share_extension_with_payloads(&[(
TLS_NAMED_GROUP_X25519MLKEM768,
&key_exchange,
)]);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
let rng = crate::crypto::SecureRandom::new();
assert!(build_x25519mlkem768_server_key_share(&ch, &rng).is_none());
}
#[test]
fn select_server_hello_key_share_group_accepts_x25519_when_hybrid_is_absent() {
let key_share = client_key_share_extension(&[(TLS_NAMED_GROUP_X25519, X25519_KEY_SHARE_LEN)]);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
assert_eq!(
select_server_hello_key_share_group(&ch),
Some(TLS_NAMED_GROUP_X25519)
);
}
#[test]
fn select_server_hello_key_share_group_rejects_malformed_hybrid_len() {
let key_share = client_key_share_extension(&[(
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_CLIENT_KEY_SHARE_LEN - 1,
)]);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
assert_eq!(select_server_hello_key_share_group(&ch), None);
}
#[test]
fn select_server_hello_key_share_group_rejects_malformed_key_share_tail() {
let mut key_share = client_key_share_extension(&[(
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_CLIENT_KEY_SHARE_LEN,
)]);
let shares_len = u16::from_be_bytes([key_share[0], key_share[1]]) + 1;
key_share[0..2].copy_from_slice(&shares_len.to_be_bytes());
key_share.push(0);
let ch = build_client_hello_with_exts(vec![(0x0033, key_share)], "example.com");
assert_eq!(select_server_hello_key_share_group(&ch), None);
}
#[test]
fn extract_sni_rejects_zero_length_host_name() {
let mut sni_ext = Vec::new();
+378 -29
View File
@@ -65,6 +65,7 @@ use super::constants::*;
use crate::crypto::{SecureRandom, sha256_hmac};
#[cfg(test)]
use crate::error::ProxyError;
use ml_kem::{B32, EncapsulationKey as MlKemEncapsulationKey, Key as MlKemKey, MlKem768};
use std::time::{SystemTime, UNIX_EPOCH};
use subtle::ConstantTimeEq;
use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519};
@@ -109,9 +110,45 @@ mod cipher_suite {
pub const TLS_CHACHA20_POLY1305_SHA256: [u8; 2] = [0x13, 0x03];
}
/// TLS Named Curves
/// TLS named groups used in KeyShare extensions.
mod named_curve {
pub const X25519: u16 = 0x001d;
pub const X25519MLKEM768: u16 = 0x11ec;
}
/// TLS X25519 named group.
pub(crate) const TLS_NAMED_GROUP_X25519: u16 = named_curve::X25519;
/// TLS X25519MLKEM768 named group.
pub(crate) const TLS_NAMED_GROUP_X25519MLKEM768: u16 = named_curve::X25519MLKEM768;
const X25519_KEY_SHARE_LEN: usize = 32;
const X25519MLKEM768_CLIENT_KEY_SHARE_LEN: usize = 1216;
const X25519MLKEM768_SERVER_KEY_SHARE_LEN: usize = 1120;
const MLKEM768_CLIENT_ENCAPSULATION_KEY_LEN: usize = 1184;
const MLKEM768_SERVER_CIPHERTEXT_LEN: usize = 1088;
/// ServerHello key_share selected for the authenticated ClientHello.
#[derive(Clone, Debug)]
pub(crate) struct ServerHelloKeyShare {
group: u16,
key_exchange: Vec<u8>,
}
impl ServerHelloKeyShare {
pub(crate) fn new(group: u16, key_exchange: Vec<u8>) -> Self {
Self {
group,
key_exchange,
}
}
pub(crate) fn group(&self) -> u16 {
self.group
}
pub(crate) fn key_exchange(&self) -> &[u8] {
&self.key_exchange
}
}
// ============= TLS Validation Result =============
@@ -144,26 +181,28 @@ impl TlsExtensionBuilder {
}
}
/// Add Key Share extension with X25519 key
fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self {
/// Add KeyShare extension with the selected named group.
fn add_key_share(&mut self, group: u16, key_exchange: &[u8]) -> &mut Self {
let Ok(key_exchange_len) = u16::try_from(key_exchange.len()) else {
return self;
};
let Some(entry_len) = key_exchange.len().checked_add(4) else {
return self;
};
let Ok(entry_len) = u16::try_from(entry_len) else {
return self;
};
// Extension type: key_share (0x0033)
self.extensions
.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes());
// Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes
// Extension data length
let entry_len: u16 = 2 + 2 + 32; // curve + length + key
// ServerHello key_share data is exactly one KeyShareEntry.
self.extensions.extend_from_slice(&entry_len.to_be_bytes());
// Named curve: x25519
self.extensions.extend_from_slice(&group.to_be_bytes());
self.extensions
.extend_from_slice(&named_curve::X25519.to_be_bytes());
// Key length
self.extensions.extend_from_slice(&(32u16).to_be_bytes());
// Key data
self.extensions.extend_from_slice(public_key);
.extend_from_slice(&key_exchange_len.to_be_bytes());
self.extensions.extend_from_slice(key_exchange);
self
}
@@ -232,8 +271,8 @@ impl ServerHelloBuilder {
}
}
fn with_x25519_key(mut self, key: &[u8; 32]) -> Self {
self.extensions.add_key_share(key);
fn with_key_share(mut self, group: u16, key_exchange: &[u8]) -> Self {
self.extensions.add_key_share(group, key_exchange);
self
}
@@ -508,9 +547,137 @@ fn validate_tls_handshake_at_time_with_boot_cap(
/// Uses RFC 7748 X25519 scalar multiplication over the canonical basepoint,
/// yielding distribution-consistent public keys for anti-fingerprinting.
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
let mut scalar = [0u8; 32];
scalar.copy_from_slice(&rng.bytes(32));
x25519(scalar, X25519_BASEPOINT_BYTES)
let (_scalar, public_key) = gen_x25519_key_pair(rng);
public_key
}
fn gen_x25519_key_pair(rng: &SecureRandom) -> ([u8; 32], [u8; 32]) {
let mut scalar = [0u8; X25519_KEY_SHARE_LEN];
rng.fill(&mut scalar);
let public_key = x25519(scalar, X25519_BASEPOINT_BYTES);
(scalar, public_key)
}
/// Generate a fake X25519MLKEM768 ServerHello key_share payload.
pub(crate) fn gen_fake_x25519mlkem768_server_key_share(rng: &SecureRandom) -> Vec<u8> {
let mut key_share = vec![0u8; X25519MLKEM768_SERVER_KEY_SHARE_LEN];
// FakeTLS never derives TLS traffic secrets from this payload; only the
// externally visible named group and vector lengths are protocol-facing.
rng.fill(&mut key_share[..MLKEM768_SERVER_CIPHERTEXT_LEN]);
let x25519_key = gen_fake_x25519_key(rng);
key_share[MLKEM768_SERVER_CIPHERTEXT_LEN..].copy_from_slice(&x25519_key);
key_share
}
fn mlkem768_encapsulate_to_client(client_key: &[u8], rng: &SecureRandom) -> Option<Vec<u8>> {
let key_bytes = MlKemKey::<MlKemEncapsulationKey<MlKem768>>::try_from(client_key).ok()?;
let encapsulation_key = MlKemEncapsulationKey::<MlKem768>::new(&key_bytes).ok()?;
let mut randomness = [0u8; 32];
rng.fill(&mut randomness);
let randomness = B32::try_from(randomness.as_slice()).ok()?;
let (ciphertext, _shared_key) = encapsulation_key.encapsulate_deterministic(&randomness);
let ciphertext = ciphertext.as_slice().to_vec();
if ciphertext.len() == MLKEM768_SERVER_CIPHERTEXT_LEN {
Some(ciphertext)
} else {
None
}
}
/// Build a valid X25519MLKEM768 ServerHello key_share for the authenticated ClientHello.
pub(crate) fn build_x25519mlkem768_server_key_share(
handshake: &[u8],
rng: &SecureRandom,
) -> Option<Vec<u8>> {
let client_key_exchange = client_hello_key_share_group_entry(
handshake,
TLS_NAMED_GROUP_X25519MLKEM768,
X25519MLKEM768_CLIENT_KEY_SHARE_LEN,
)?;
let client_mlkem_key = client_key_exchange.get(..MLKEM768_CLIENT_ENCAPSULATION_KEY_LEN)?;
let client_x25519_key = client_key_exchange.get(MLKEM768_CLIENT_ENCAPSULATION_KEY_LEN..)?;
let mlkem_ciphertext = mlkem768_encapsulate_to_client(client_mlkem_key, rng)?;
let mut client_x25519 = [0u8; X25519_KEY_SHARE_LEN];
client_x25519.copy_from_slice(client_x25519_key);
let (server_x25519_scalar, server_x25519_key) = gen_x25519_key_pair(rng);
let x25519_shared = x25519(server_x25519_scalar, client_x25519);
if bool::from(x25519_shared.ct_eq(&[0u8; X25519_KEY_SHARE_LEN])) {
return None;
}
let mut key_share = Vec::with_capacity(X25519MLKEM768_SERVER_KEY_SHARE_LEN);
key_share.extend_from_slice(&mlkem_ciphertext);
key_share.extend_from_slice(&server_x25519_key);
Some(key_share)
}
/// Build a valid X25519 ServerHello key_share for the authenticated ClientHello.
pub(crate) fn build_x25519_server_key_share(
handshake: &[u8],
rng: &SecureRandom,
) -> Option<Vec<u8>> {
let client_key_exchange = client_hello_key_share_group_entry(
handshake,
TLS_NAMED_GROUP_X25519,
X25519_KEY_SHARE_LEN,
)?;
let mut client_x25519 = [0u8; X25519_KEY_SHARE_LEN];
client_x25519.copy_from_slice(client_key_exchange);
let (server_x25519_scalar, server_x25519_key) = gen_x25519_key_pair(rng);
let x25519_shared = x25519(server_x25519_scalar, client_x25519);
if bool::from(x25519_shared.ct_eq(&[0u8; X25519_KEY_SHARE_LEN])) {
return None;
}
Some(server_x25519_key.to_vec())
}
fn build_server_hello_key_share_for_group(
handshake: &[u8],
group: u16,
rng: &SecureRandom,
) -> Option<ServerHelloKeyShare> {
let expected_key_exchange_len = client_hello_key_share_group_len(group)?;
client_hello_key_share_group_entry(handshake, group, expected_key_exchange_len)?;
// FakeTLS clients validate ServerHello shape and digest, not TLS traffic
// secrets, so the response must mirror the offered group without binding to
// the camouflage key bytes embedded in ClientHello.
match group {
TLS_NAMED_GROUP_X25519MLKEM768 => Some(ServerHelloKeyShare::new(
group,
gen_fake_x25519mlkem768_server_key_share(rng),
)),
TLS_NAMED_GROUP_X25519 => Some(ServerHelloKeyShare::new(
group,
gen_fake_x25519_key(rng).to_vec(),
)),
_ => None,
}
}
fn server_hello_key_share_candidate_order(preferred_group: Option<u16>) -> [u16; 2] {
if preferred_group == Some(TLS_NAMED_GROUP_X25519) {
[TLS_NAMED_GROUP_X25519, TLS_NAMED_GROUP_X25519MLKEM768]
} else {
[TLS_NAMED_GROUP_X25519MLKEM768, TLS_NAMED_GROUP_X25519]
}
}
/// Build a ServerHello key_share using a profile-preferred group when possible.
pub(crate) fn build_server_hello_key_share(
handshake: &[u8],
preferred_group: Option<u16>,
rng: &SecureRandom,
) -> Option<ServerHelloKeyShare> {
for group in server_hello_key_share_candidate_order(preferred_group) {
if let Some(key_share) = build_server_hello_key_share_for_group(handshake, group, rng) {
return Some(key_share);
}
}
None
}
/// Build TLS ServerHello response
@@ -530,6 +697,10 @@ pub fn build_server_hello(
alpn: Option<Vec<u8>>,
new_session_tickets: u8,
) -> Vec<u8> {
let server_key_share = ServerHelloKeyShare::new(
TLS_NAMED_GROUP_X25519MLKEM768,
gen_fake_x25519mlkem768_server_key_share(rng),
);
build_server_hello_with_cipher(
secret,
client_digest,
@@ -537,6 +708,7 @@ pub fn build_server_hello(
fake_cert_len,
rng,
cipher_suite::TLS_AES_128_GCM_SHA256,
&server_key_share,
alpn,
new_session_tickets,
)
@@ -554,18 +726,18 @@ pub(crate) fn build_server_hello_with_cipher(
fake_cert_len: usize,
rng: &SecureRandom,
selected_cipher_suite: [u8; 2],
server_key_share: &ServerHelloKeyShare,
alpn: Option<Vec<u8>>,
new_session_tickets: u8,
) -> Vec<u8> {
const MIN_APP_DATA: usize = 64;
const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE;
let fake_cert_len = fake_cert_len.clamp(MIN_APP_DATA, MAX_APP_DATA);
let x25519_key = gen_fake_x25519_key(rng);
// Build ServerHello
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
.with_cipher_suite(selected_cipher_suite)
.with_x25519_key(&x25519_key)
.with_key_share(server_key_share.group(), server_key_share.key_exchange())
.with_tls13_version()
.build_record();
@@ -1003,6 +1175,148 @@ fn client_hello_cipher_suites_range(handshake: &[u8]) -> Option<(usize, usize)>
Some((pos, cipher_end))
}
fn client_hello_extensions_range(handshake: &[u8]) -> Option<(usize, usize)> {
if handshake.len() < 5 || handshake[0] != TLS_RECORD_HANDSHAKE {
return None;
}
let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize;
let record_end = 5usize.checked_add(record_len)?;
if record_end > handshake.len() {
return None;
}
let mut pos = 5;
if handshake.get(pos) != Some(&0x01) {
return None;
}
pos += 1;
if pos + 3 > record_end {
return None;
}
let handshake_len = ((handshake[pos] as usize) << 16)
| ((handshake[pos + 1] as usize) << 8)
| handshake[pos + 2] as usize;
pos += 3;
let handshake_end = pos.checked_add(handshake_len)?;
if handshake_end > record_end {
return None;
}
if pos + 2 + 32 > handshake_end {
return None;
}
pos += 2 + 32;
let session_id_len = *handshake.get(pos)? as usize;
pos = pos.checked_add(1)?.checked_add(session_id_len)?;
if pos + 2 > handshake_end {
return None;
}
let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
if cipher_len == 0 || cipher_len % 2 != 0 {
return None;
}
pos += 2;
pos = pos.checked_add(cipher_len)?;
if pos + 1 > handshake_end {
return None;
}
let compression_len = *handshake.get(pos)? as usize;
pos = pos.checked_add(1)?.checked_add(compression_len)?;
if pos == handshake_end {
return Some((handshake_end, handshake_end));
}
if pos + 2 > handshake_end {
return None;
}
let extensions_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
pos += 2;
let extensions_end = pos.checked_add(extensions_len)?;
if extensions_end > handshake_end {
return None;
}
Some((pos, extensions_end))
}
fn key_share_extension_group_entry<'a>(
data: &'a [u8],
group: u16,
expected_key_exchange_len: usize,
) -> Option<&'a [u8]> {
if data.len() < 2 {
return None;
}
let shares_len = u16::from_be_bytes([data[0], data[1]]) as usize;
if shares_len != data.len().saturating_sub(2) {
return None;
}
let mut pos = 2usize;
let shares_end = 2 + shares_len;
let mut found_group = None;
while pos + 4 <= shares_end {
let entry_group = u16::from_be_bytes([data[pos], data[pos + 1]]);
let key_exchange_len = u16::from_be_bytes([data[pos + 2], data[pos + 3]]) as usize;
pos += 4;
let Some(key_exchange_end) = pos.checked_add(key_exchange_len) else {
return None;
};
if key_exchange_end > shares_end {
return None;
}
if entry_group == group {
if key_exchange_len != expected_key_exchange_len || found_group.is_some() {
return None;
}
found_group = Some(&data[pos..key_exchange_end]);
}
pos = key_exchange_end;
}
if pos == shares_end { found_group } else { None }
}
fn client_hello_key_share_group_entry<'a>(
handshake: &'a [u8],
group: u16,
expected_key_exchange_len: usize,
) -> Option<&'a [u8]> {
let Some((mut pos, extensions_end)) = client_hello_extensions_range(handshake) else {
return None;
};
while pos + 4 <= extensions_end {
let ext_type = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]);
let ext_len = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize;
pos += 4;
let Some(ext_end) = pos.checked_add(ext_len) else {
return None;
};
if ext_end > extensions_end {
return None;
}
if ext_type == extension_type::KEY_SHARE {
return key_share_extension_group_entry(
&handshake[pos..ext_end],
group,
expected_key_exchange_len,
);
}
pos = ext_end;
}
None
}
fn client_hello_offers_cipher_suite(
handshake: &[u8],
range: (usize, usize),
@@ -1027,20 +1341,23 @@ fn is_tls13_cipher_suite(suite: [u8; 2]) -> bool {
/// Select the ServerHello cipher suite from the already-received ClientHello.
///
/// This is intentionally a borrowed, zero-allocation scan. It runs only for an
/// authenticated success response and keeps malformed or unexpected ClientHello
/// shapes on the previous fallback behavior.
pub(crate) fn select_server_hello_cipher_suite(handshake: &[u8], preferred: [u8; 2]) -> [u8; 2] {
/// authenticated success response and fails closed for malformed or unsupported
/// ClientHello shapes that cannot produce a DPI-consistent ServerHello.
pub(crate) fn select_server_hello_cipher_suite(
handshake: &[u8],
preferred: [u8; 2],
) -> Option<[u8; 2]> {
let preferred = if is_tls13_cipher_suite(preferred) {
preferred
} else {
cipher_suite::TLS_AES_128_GCM_SHA256
};
let Some(range) = client_hello_cipher_suites_range(handshake) else {
return preferred;
return None;
};
if client_hello_offers_cipher_suite(handshake, range, preferred) {
return preferred;
return Some(preferred);
}
for fallback in [
@@ -1049,11 +1366,43 @@ pub(crate) fn select_server_hello_cipher_suite(handshake: &[u8], preferred: [u8;
cipher_suite::TLS_AES_256_GCM_SHA384,
] {
if client_hello_offers_cipher_suite(handshake, range, fallback) {
return fallback;
return Some(fallback);
}
}
preferred
None
}
fn client_hello_key_share_group_len(group: u16) -> Option<usize> {
match group {
TLS_NAMED_GROUP_X25519MLKEM768 => Some(X25519MLKEM768_CLIENT_KEY_SHARE_LEN),
TLS_NAMED_GROUP_X25519 => Some(X25519_KEY_SHARE_LEN),
_ => None,
}
}
/// Select the ServerHello key_share named group from the authenticated ClientHello.
///
/// Malformed key_share structures fail closed so authenticated but
/// DPI-inconsistent ClientHellos take the ordinary masking fallback path.
pub(crate) fn select_server_hello_key_share_group(handshake: &[u8]) -> Option<u16> {
select_server_hello_key_share_group_with_preference(handshake, None)
}
/// Select the ServerHello key_share named group with an origin-profile preference.
pub(crate) fn select_server_hello_key_share_group_with_preference(
handshake: &[u8],
preferred_group: Option<u16>,
) -> Option<u16> {
for group in server_hello_key_share_candidate_order(preferred_group) {
let expected_key_exchange_len = client_hello_key_share_group_len(group)?;
if client_hello_key_share_group_entry(handshake, group, expected_key_exchange_len).is_some()
{
return Some(group);
}
}
None
}
/// Check if bytes look like a TLS ClientHello
+32 -2
View File
@@ -113,7 +113,7 @@ use crate::proxy::handshake::{
};
#[cfg(test)]
use crate::proxy::handshake::{handle_mtproto_handshake, handle_tls_handshake};
use crate::proxy::masking::handle_bad_client;
use crate::proxy::masking::handle_bad_client_with_shared;
use crate::proxy::middle_relay::handle_via_middle_proxy;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::proxy::shared_state::ProxySharedState;
@@ -310,6 +310,7 @@ fn masking_outcome<R, W>(
local_addr: SocketAddr,
config: Arc<ProxyConfig>,
beobachten: Arc<BeobachtenStore>,
shared: Arc<ProxySharedState>,
) -> HandshakeOutcome
where
R: AsyncRead + Unpin + Send + 'static,
@@ -325,7 +326,7 @@ where
)
.await;
handle_bad_client(
handle_bad_client_with_shared(
reader,
writer,
&initial_data,
@@ -333,6 +334,7 @@ where
local_addr,
&config,
&beobachten,
shared.as_ref(),
)
.await;
Ok(())
@@ -718,6 +720,7 @@ where
local_addr,
config.clone(),
beobachten.clone(),
shared.clone(),
));
}
@@ -739,6 +742,7 @@ where
local_addr,
config.clone(),
beobachten.clone(),
shared.clone(),
));
}
};
@@ -757,6 +761,7 @@ where
local_addr,
config.clone(),
beobachten.clone(),
shared.clone(),
));
}
@@ -787,6 +792,7 @@ where
local_addr,
config.clone(),
beobachten.clone(),
shared.clone(),
));
}
HandshakeResult::Error(e) => {
@@ -844,6 +850,7 @@ where
local_addr,
config.clone(),
beobachten.clone(),
shared.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
@@ -873,6 +880,7 @@ where
local_addr,
config.clone(),
beobachten.clone(),
shared.clone(),
));
}
@@ -898,6 +906,7 @@ where
local_addr,
config.clone(),
beobachten.clone(),
shared.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
@@ -1096,6 +1105,12 @@ impl RunningClientHandler {
#[cfg(unix)]
let raw_fd = self.raw_fd;
let rst_on_close = self.rst_on_close;
// MSS for the bulk data phase: once the handshake (incl. ServerHello) is
// sent, restore a normal MSS so only the handshake stays fragmented by the
// low listener `client_mss`. Cuts pps ~10x (anti-DDoS abuse on pps-policing
// hosts like FastVPS). None = keep handshake MSS for the whole connection.
#[cfg(unix)]
let bulk_mss: Option<u16> = self.config.server.client_mss_bulk_value().ok().flatten();
let outcome = match self.do_handshake().await? {
Some(outcome) => outcome,
@@ -1109,6 +1124,14 @@ impl RunningClientHandler {
if matches!(rst_on_close, crate::config::RstOnCloseMode::Errors) {
let _ = crate::transport::socket::clear_linger_fd(raw_fd);
}
// Handshake (ServerHello) done — raise MSS for bulk transfer.
#[cfg(unix)]
if let Some(mss) = bulk_mss {
if let Err(e) = crate::transport::socket::set_tcp_mss_fd(raw_fd, u32::from(mss))
{
debug!(error = %e, "Failed to raise bulk MSS; keeping handshake MSS");
}
}
fut.await
}
HandshakeOutcome::NeedsMasking(fut) => fut.await,
@@ -1329,6 +1352,7 @@ impl RunningClientHandler {
local_addr,
self.config.clone(),
self.beobachten.clone(),
self.shared.clone(),
));
}
@@ -1350,6 +1374,7 @@ impl RunningClientHandler {
local_addr,
self.config.clone(),
self.beobachten.clone(),
self.shared.clone(),
));
}
};
@@ -1369,6 +1394,7 @@ impl RunningClientHandler {
local_addr,
self.config.clone(),
self.beobachten.clone(),
self.shared.clone(),
));
}
@@ -1416,6 +1442,7 @@ impl RunningClientHandler {
local_addr,
config.clone(),
self.beobachten.clone(),
self.shared.clone(),
));
}
HandshakeResult::Error(e) => {
@@ -1483,6 +1510,7 @@ impl RunningClientHandler {
local_addr,
config.clone(),
self.beobachten.clone(),
self.shared.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
@@ -1530,6 +1558,7 @@ impl RunningClientHandler {
local_addr,
self.config.clone(),
self.beobachten.clone(),
self.shared.clone(),
));
}
@@ -1568,6 +1597,7 @@ impl RunningClientHandler {
local_addr,
config.clone(),
self.beobachten.clone(),
self.shared.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
+70 -137
View File
@@ -4,7 +4,6 @@
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use hmac::{Hmac, Mac};
#[cfg(test)]
use std::collections::HashSet;
use std::collections::hash_map::DefaultHasher;
@@ -33,8 +32,10 @@ use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
use crate::tls_front::{TlsFrontCache, emulator};
#[cfg(test)]
use rand::RngExt;
use sha2::Sha256;
use subtle::ConstantTimeEq;
mod tls_auth;
use self::tls_auth::{parse_tls_auth_material, validate_tls_secret_candidate};
const ACCESS_SECRET_BYTES: usize = 16;
const UNKNOWN_SNI_WARN_COOLDOWN_SECS: u64 = 5;
@@ -58,8 +59,6 @@ const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8;
const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64;
const RECENT_USER_RING_SCAN_LIMIT: usize = 32;
type HmacSha256 = Hmac<Sha256>;
#[cfg(test)]
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
#[cfg(not(test))]
@@ -104,23 +103,6 @@ fn should_emit_unknown_sni_warn_in(shared: &ProxySharedState, now: Instant) -> b
true
}
#[derive(Clone, Copy)]
struct ParsedTlsAuthMaterial {
digest: [u8; tls::TLS_DIGEST_LEN],
session_id: [u8; 32],
session_id_len: usize,
now: i64,
ignore_time_skew: bool,
boot_time_cap_secs: u32,
}
#[derive(Clone, Copy)]
struct TlsCandidateValidation {
digest: [u8; tls::TLS_DIGEST_LEN],
session_id: [u8; 32],
session_id_len: usize,
}
struct MtprotoCandidateValidation {
proto_tag: ProtoTag,
dc_idx: i16,
@@ -251,104 +233,6 @@ fn budget_for_validation(total_users: usize, overload: bool, has_hint: bool) ->
total_users.min(cap.max(1))
}
fn parse_tls_auth_material(
handshake: &[u8],
ignore_time_skew: bool,
replay_window_secs: u64,
) -> Option<ParsedTlsAuthMaterial> {
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
return None;
}
let digest: [u8; tls::TLS_DIGEST_LEN] = handshake
[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.try_into()
.ok()?;
let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN;
let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?);
if session_id_len > 32 {
return None;
}
let session_id_start = session_id_len_pos + 1;
if handshake.len() < session_id_start + session_id_len {
return None;
}
let mut session_id = [0u8; 32];
session_id[..session_id_len]
.copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]);
let now = if !ignore_time_skew {
let d = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.ok()?;
i64::try_from(d.as_secs()).ok()?
} else {
0_i64
};
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
let boot_time_cap_secs = if ignore_time_skew {
0
} else {
tls::BOOT_TIME_MAX_SECS
.min(replay_window_u32)
.min(tls::BOOT_TIME_COMPAT_MAX_SECS)
};
Some(ParsedTlsAuthMaterial {
digest,
session_id,
session_id_len,
now,
ignore_time_skew,
boot_time_cap_secs,
})
}
fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> [u8; 32] {
let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
mac.update(&handshake[..tls::TLS_DIGEST_POS]);
mac.update(&[0u8; tls::TLS_DIGEST_LEN]);
mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]);
mac.finalize().into_bytes().into()
}
fn validate_tls_secret_candidate(
parsed: &ParsedTlsAuthMaterial,
handshake: &[u8],
secret: &[u8],
) -> Option<TlsCandidateValidation> {
let computed = compute_tls_hmac_zeroed_digest(secret, handshake);
if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) {
return None;
}
let timestamp = u32::from_le_bytes([
parsed.digest[28] ^ computed[28],
parsed.digest[29] ^ computed[29],
parsed.digest[30] ^ computed[30],
parsed.digest[31] ^ computed[31],
]);
if !parsed.ignore_time_skew {
let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs;
if !is_boot_time {
let time_diff = parsed.now - i64::from(timestamp);
if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) {
return None;
}
}
}
Some(TlsCandidateValidation {
digest: parsed.digest,
session_id: parsed.session_id,
session_id_len: parsed.session_id_len,
})
}
fn validate_mtproto_secret_candidate(
handshake: &[u8; HANDSHAKE_LEN],
dec_prekey: &[u8; PREKEY_LEN],
@@ -1473,14 +1357,60 @@ where
return HandshakeResult::BadClient { reader, writer };
}
let cached = if config.censorship.tls_emulation {
let cached_entry = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() {
let selected_domain =
matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str());
let cached_entry = cache.get(selected_domain).await;
let use_full_cert_payload = if config.censorship.serverhello_compact
&& matches!(client_tls_version, tls::ClientHelloTlsVersion::Tls12)
{
Some(cached_entry)
} else {
None
}
} else {
None
};
let preferred_key_share_group = cached_entry
.as_ref()
.and_then(|cached_entry| emulator::profiled_server_hello_key_share_group(cached_entry));
let Some(server_key_share) =
tls::build_server_hello_key_share(handshake, preferred_key_share_group, rng)
else {
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
"TLS handshake rejected: ClientHello did not offer a usable TLS 1.3 key_share"
);
return HandshakeResult::BadClient { reader, writer };
};
let preferred_cipher_suite = if let Some(cached_entry) = cached_entry.as_ref() {
if cached_entry.server_hello_template.cipher_suite == [0, 0] {
[0x13, 0x01]
} else {
cached_entry.server_hello_template.cipher_suite
}
} else {
[0x13, 0x01]
};
let Some(selected_cipher_suite) =
tls::select_server_hello_cipher_suite(handshake, preferred_cipher_suite)
else {
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
"TLS handshake rejected: ClientHello did not offer a supported TLS 1.3 cipher suite"
);
return HandshakeResult::BadClient { reader, writer };
};
let cached = if let Some(cached_entry) = cached_entry {
let use_full_cert_payload = if config.censorship.serverhello_compact
&& matches!(client_tls_version, tls::ClientHelloTlsVersion::Tls12)
{
if let Some(cache) = tls_cache.as_ref() {
cache
.take_full_cert_budget_for_ip(
peer.ip(),
@@ -1489,11 +1419,11 @@ where
.await
} else {
true
};
Some((cached_entry, use_full_cert_payload))
}
} else {
None
}
true
};
Some((cached_entry, use_full_cert_payload))
} else {
None
};
@@ -1504,13 +1434,6 @@ where
let validation_session_id_slice = &validation_session_id[..validation_session_id_len];
let response = if let Some((cached_entry, use_full_cert_payload)) = cached {
let preferred_cipher_suite = if cached_entry.server_hello_template.cipher_suite == [0, 0] {
[0x13, 0x01]
} else {
cached_entry.server_hello_template.cipher_suite
};
let selected_cipher_suite =
tls::select_server_hello_cipher_suite(handshake, preferred_cipher_suite);
emulator::build_emulated_server_hello(
&validated_secret,
&validation_digest,
@@ -1520,12 +1443,12 @@ where
config.censorship.serverhello_compact,
client_tls_version,
selected_cipher_suite,
&server_key_share,
rng,
selected_alpn.clone(),
config.censorship.tls_new_session_tickets,
)
} else {
let selected_cipher_suite = tls::select_server_hello_cipher_suite(handshake, [0x13, 0x01]);
tls::build_server_hello_with_cipher(
&validated_secret,
&validation_digest,
@@ -1533,6 +1456,7 @@ where
config.censorship.fake_cert_len,
rng,
selected_cipher_suite,
&server_key_share,
selected_alpn.clone(),
config.censorship.tls_new_session_tickets,
)
@@ -1817,7 +1741,16 @@ where
return HandshakeResult::BadClient { reader, writer };
}
let validation = matched_validation.expect("validation must exist when matched");
let Some(validation) = matched_validation else {
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(
peer = %peer,
user = %matched_user,
"MTProto handshake matched user without validation material"
);
return HandshakeResult::BadClient { reader, writer };
};
if config
.access
+126
View File
@@ -0,0 +1,126 @@
use hmac::{Hmac, Mac};
use sha2::Sha256;
use subtle::ConstantTimeEq;
use crate::protocol::tls;
type HmacSha256 = Hmac<Sha256>;
/// Parsed TLS authentication material extracted from a ClientHello candidate.
#[derive(Clone, Copy)]
pub(super) struct ParsedTlsAuthMaterial {
digest: [u8; tls::TLS_DIGEST_LEN],
session_id: [u8; 32],
session_id_len: usize,
now: i64,
ignore_time_skew: bool,
boot_time_cap_secs: u32,
}
/// Successful TLS secret validation output used by the handshake state machine.
#[derive(Clone, Copy)]
pub(super) struct TlsCandidateValidation {
pub(super) digest: [u8; tls::TLS_DIGEST_LEN],
pub(super) session_id: [u8; 32],
pub(super) session_id_len: usize,
}
/// Parse TLS auth digest and session-id material from a candidate handshake.
pub(super) fn parse_tls_auth_material(
handshake: &[u8],
ignore_time_skew: bool,
replay_window_secs: u64,
) -> Option<ParsedTlsAuthMaterial> {
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
return None;
}
let digest: [u8; tls::TLS_DIGEST_LEN] = handshake
[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.try_into()
.ok()?;
let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN;
let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?);
if session_id_len > 32 {
return None;
}
let session_id_start = session_id_len_pos + 1;
if handshake.len() < session_id_start + session_id_len {
return None;
}
let mut session_id = [0u8; 32];
session_id[..session_id_len]
.copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]);
let now = if !ignore_time_skew {
let d = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.ok()?;
i64::try_from(d.as_secs()).ok()?
} else {
0_i64
};
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
let boot_time_cap_secs = if ignore_time_skew {
0
} else {
tls::BOOT_TIME_MAX_SECS
.min(replay_window_u32)
.min(tls::BOOT_TIME_COMPAT_MAX_SECS)
};
Some(ParsedTlsAuthMaterial {
digest,
session_id,
session_id_len,
now,
ignore_time_skew,
boot_time_cap_secs,
})
}
fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> Option<[u8; 32]> {
let mut mac = HmacSha256::new_from_slice(secret).ok()?;
mac.update(&handshake[..tls::TLS_DIGEST_POS]);
mac.update(&[0u8; tls::TLS_DIGEST_LEN]);
mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]);
Some(mac.finalize().into_bytes().into())
}
/// Validate a candidate secret against parsed TLS authentication material.
pub(super) fn validate_tls_secret_candidate(
parsed: &ParsedTlsAuthMaterial,
handshake: &[u8],
secret: &[u8],
) -> Option<TlsCandidateValidation> {
let computed = compute_tls_hmac_zeroed_digest(secret, handshake)?;
if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) {
return None;
}
let timestamp = u32::from_le_bytes([
parsed.digest[28] ^ computed[28],
parsed.digest[29] ^ computed[29],
parsed.digest[30] ^ computed[30],
parsed.digest[31] ^ computed[31],
]);
if !parsed.ignore_time_skew {
let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs;
if !is_boot_time {
let time_diff = parsed.now - i64::from(timestamp);
if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) {
return None;
}
}
}
Some(TlsCandidateValidation {
digest: parsed.digest,
session_id: parsed.session_id,
session_id_len: parsed.session_id_len,
})
}
+211 -86
View File
@@ -3,12 +3,15 @@
use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
use crate::protocol::tls;
use crate::proxy::shared_state::ProxySharedState;
use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use crate::transport::socket::configure_tcp_socket;
#[cfg(unix)]
use nix::ifaddrs::getifaddrs;
use rand::rngs::StdRng;
use rand::{Rng, RngExt, SeedableRng};
use std::io::{Error as IoError, ErrorKind};
use std::net::{IpAddr, SocketAddr};
use std::str;
#[cfg(test)]
@@ -17,7 +20,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant as StdInstant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::net::{TcpStream, lookup_host};
#[cfg(unix)]
use tokio::net::UnixStream;
#[cfg(unix)]
@@ -36,6 +39,8 @@ const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
#[cfg(test)]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
const MASK_BUFFER_SIZE: usize = 8192;
const MASK_BUFFER_GROW_AFTER_BYTES: usize = 256 * 1024;
const MASK_BUFFER_MAX_SIZE: usize = 64 * 1024;
#[cfg(unix)]
#[cfg(not(test))]
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300);
@@ -53,6 +58,27 @@ struct MaskTcpTarget<'a> {
port: u16,
}
fn mask_copy_read_len(total: usize, byte_cap: usize) -> usize {
// Keep short scanner probes on the small baseline buffer and grow only
// after the session has proven to be sustained masking relay traffic.
let active_buffer_size = if total >= MASK_BUFFER_GROW_AFTER_BYTES {
MASK_BUFFER_MAX_SIZE
} else {
MASK_BUFFER_SIZE
};
if byte_cap == 0 {
return active_buffer_size;
}
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
return 0;
}
remaining_budget.min(active_buffer_size)
}
async fn copy_with_idle_timeout<R, W>(
reader: &mut R,
writer: &mut W,
@@ -64,21 +90,18 @@ where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
let mut total = 0usize;
let mut ended_by_eof = false;
let unlimited = byte_cap == 0;
loop {
let read_len = if unlimited {
MASK_BUFFER_SIZE
} else {
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
remaining_budget.min(MASK_BUFFER_SIZE)
};
let read_len = mask_copy_read_len(total, byte_cap);
if read_len == 0 {
break;
}
if buf.len() < read_len {
buf.resize(read_len, 0);
}
let read_res = timeout(idle_timeout, reader.read(&mut buf[..read_len])).await;
let n = match read_res {
Ok(Ok(n)) => n,
@@ -250,6 +273,32 @@ async fn consume_client_data_with_timeout_and_cap<R>(
}
}
fn mask_failure_drain_cap(config: &ProxyConfig) -> usize {
let configured_cap = config.censorship.mask_relay_max_bytes;
if configured_cap == 0 {
return MASK_BUFFER_SIZE;
}
configured_cap.min(MASK_BUFFER_SIZE)
}
async fn consume_mask_failure_path<R>(
reader: R,
config: &ProxyConfig,
relay_timeout: Duration,
idle_timeout: Duration,
) where
R: AsyncRead + Unpin,
{
consume_client_data_with_timeout_and_cap(
reader,
mask_failure_drain_cap(config),
relay_timeout,
idle_timeout,
)
.await;
}
async fn wait_mask_connect_budget(started: Instant) {
let elapsed = started.elapsed();
if elapsed < MASK_TIMEOUT {
@@ -385,7 +434,7 @@ mod tls_domain_mask_host_tests {
let mut config = ProxyConfig::default();
config.censorship.tls_domain = "a.com".to_string();
config.censorship.tls_domains = vec!["b.com".to_string(), "c.com".to_string()];
config.censorship.mask_host = Some("a.com".to_string());
config.censorship.mask_host = None;
config
}
@@ -419,6 +468,15 @@ mod tls_domain_mask_host_tests {
assert_eq!(mask_host_for_initial_data(&config, &initial_data), "b.com");
}
#[test]
fn mask_host_uses_primary_domain_when_dynamic_masking_is_disabled() {
let mut config = config_with_tls_domains();
config.censorship.mask_dynamic = false;
let initial_data = client_hello_with_sni("b.com");
assert_eq!(mask_host_for_initial_data(&config, &initial_data), "a.com");
}
#[test]
fn exclusive_mask_target_overrides_only_matching_sni() {
let mut config = config_with_tls_domains();
@@ -471,6 +529,32 @@ fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
host.parse::<IpAddr>().ok()
}
async fn resolve_mask_target_addrs(
mask_host: &str,
mask_port: u16,
) -> std::io::Result<Vec<SocketAddr>> {
if let Some(addr) = resolve_socket_addr(mask_host, mask_port) {
return Ok(vec![addr]);
}
if let Some(ip) = parse_mask_host_ip_literal(mask_host) {
return Ok(vec![SocketAddr::new(ip, mask_port)]);
}
let addrs = timeout(MASK_TIMEOUT, lookup_host((mask_host, mask_port)))
.await
.map_err(|_| IoError::new(ErrorKind::TimedOut, "mask target DNS lookup timed out"))??;
let addrs = addrs.collect::<Vec<_>>();
if addrs.is_empty() {
return Err(IoError::new(
ErrorKind::NotFound,
"mask target DNS lookup returned no addresses",
));
}
Ok(addrs)
}
fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> {
if config.censorship.tls_domain.eq_ignore_ascii_case(sni) {
return Some(config.censorship.tls_domain.as_str());
@@ -577,24 +661,32 @@ fn default_mask_tcp_target_for_initial_data<'a>(
.as_deref()
.unwrap_or(&config.censorship.tls_domain);
if !configured_mask_host.eq_ignore_ascii_case(&config.censorship.tls_domain) {
if config.censorship.mask_host.is_none() && config.censorship.mask_dynamic {
let extracted_sni = if sni.is_none() {
tls::extract_sni_from_client_hello(initial_data)
} else {
None
};
if let Some(host) = sni
.or(extracted_sni.as_deref())
.and_then(|sni| matching_tls_domain_for_sni(config, sni))
{
return MaskTcpTarget {
host,
port: config.censorship.mask_port,
};
}
}
if let Some(mask_host) = config.censorship.mask_host.as_deref() {
return MaskTcpTarget {
host: configured_mask_host,
host: mask_host,
port: config.censorship.mask_port,
};
}
let extracted_sni = if sni.is_none() {
tls::extract_sni_from_client_hello(initial_data)
} else {
None
};
let host = sni
.or(extracted_sni.as_deref())
.and_then(|sni| matching_tls_domain_for_sni(config, sni))
.unwrap_or(configured_mask_host);
MaskTcpTarget {
host,
host: configured_mask_host,
port: config.censorship.mask_port,
}
}
@@ -744,7 +836,7 @@ fn is_mask_target_local_listener_with_interfaces(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
resolved_addrs: &[SocketAddr],
interface_ips: &[IpAddr],
) -> bool {
if mask_port != local_addr.port() {
@@ -754,7 +846,7 @@ fn is_mask_target_local_listener_with_interfaces(
let local_ip = canonical_ip(local_addr.ip());
let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
if let Some(addr) = resolved_override {
for addr in resolved_addrs {
let resolved_ip = canonical_ip(addr.ip());
if resolved_ip == local_ip {
return true;
@@ -791,7 +883,7 @@ fn is_mask_target_local_listener(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
resolved_addrs: &[SocketAddr],
) -> bool {
if mask_port != local_addr.port() {
return false;
@@ -802,7 +894,7 @@ fn is_mask_target_local_listener(
mask_host,
mask_port,
local_addr,
resolved_override,
resolved_addrs,
&interfaces,
)
}
@@ -811,7 +903,7 @@ async fn is_mask_target_local_listener_async(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
resolved_addrs: &[SocketAddr],
) -> bool {
if mask_port != local_addr.port() {
return false;
@@ -822,7 +914,7 @@ async fn is_mask_target_local_listener_async(
mask_host,
mask_port,
local_addr,
resolved_override,
resolved_addrs,
&interfaces,
)
}
@@ -860,7 +952,13 @@ fn build_mask_proxy_header(
}
}
/// Handle a bad client by forwarding to mask host
fn configure_mask_backend_socket(stream: &TcpStream) {
if let Err(e) = configure_tcp_socket(stream, false, Duration::from_secs(0)) {
debug!(error = %e, "Failed to configure mask backend socket");
}
}
/// Handles a bad client by forwarding it to the configured mask target.
pub async fn handle_bad_client<R, W>(
reader: R,
writer: W,
@@ -872,6 +970,34 @@ pub async fn handle_bad_client<R, W>(
) where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let shared = ProxySharedState::new();
handle_bad_client_with_shared(
reader,
writer,
initial_data,
peer,
local_addr,
config,
beobachten,
shared.as_ref(),
)
.await;
}
/// Handles a bad client with shared pre-auth fallback admission state.
pub(crate) async fn handle_bad_client_with_shared<R, W>(
reader: R,
writer: W,
initial_data: &[u8],
peer: SocketAddr,
local_addr: SocketAddr,
config: &ProxyConfig,
beobachten: &BeobachtenStore,
shared: &ProxySharedState,
) where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let client_type = detect_client_type(initial_data);
if config.general.beobachten {
@@ -894,6 +1020,17 @@ pub async fn handle_bad_client<R, W>(
return;
}
let Some(_masking_permit) = shared.try_acquire_masking_fallback_permit() else {
let outcome_started = Instant::now();
debug!(
client_type = client_type,
"Masking fallback concurrency limit reached"
);
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
wait_mask_outcome_budget(outcome_started, config).await;
return;
};
let client_sni = tls::extract_sni_from_client_hello(initial_data);
let exclusive_tcp_target = client_sni
.as_deref()
@@ -956,24 +1093,12 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
relay_timeout,
idle_timeout,
)
.await;
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
relay_timeout,
idle_timeout,
)
.await;
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
@@ -986,11 +1111,27 @@ pub async fn handle_bad_client<R, W>(
let mask_host = mask_target.host;
let mask_port = mask_target.port;
let resolved_mask_addrs = match resolve_mask_target_addrs(mask_host, mask_port).await {
Ok(addrs) => addrs,
Err(e) => {
let outcome_started = Instant::now();
debug!(
client_type = client_type,
host = %mask_host,
port = mask_port,
error = %e,
"Failed to resolve mask target"
);
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
wait_mask_outcome_budget(outcome_started, config).await;
return;
}
};
// Fail closed when fallback points at our own listener endpoint.
// Self-referential masking can create recursive proxy loops under
// misconfiguration and leak distinguishable load spikes to adversaries.
let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port);
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr)
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, &resolved_mask_addrs)
.await
{
let outcome_started = Instant::now();
@@ -1001,13 +1142,7 @@ pub async fn handle_bad_client<R, W>(
local = %local_addr,
"Mask target resolves to local listener; refusing self-referential masking fallback"
);
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
relay_timeout,
idle_timeout,
)
.await;
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
wait_mask_outcome_budget(outcome_started, config).await;
return;
}
@@ -1022,14 +1157,15 @@ pub async fn handle_bad_client<R, W>(
"Forwarding bad client to mask host"
);
// Apply runtime DNS override for mask target when configured.
let mask_addr = resolved_mask_addr
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let connect_started = Instant::now();
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
let connect_result = timeout(
MASK_TIMEOUT,
TcpStream::connect(resolved_mask_addrs.as_slice()),
)
.await;
match connect_result {
Ok(Ok(stream)) => {
configure_mask_backend_socket(&stream);
let proxy_header =
build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr);
@@ -1068,24 +1204,12 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask host");
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
relay_timeout,
idle_timeout,
)
.await;
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
relay_timeout,
idle_timeout,
)
.await;
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
@@ -1173,20 +1297,17 @@ async fn consume_client_data<R: AsyncRead + Unpin>(
idle_timeout: Duration,
) {
// Keep drain path fail-closed under slow-loris stalls.
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
let mut total = 0usize;
let unlimited = byte_cap == 0;
loop {
let read_len = if unlimited {
MASK_BUFFER_SIZE
} else {
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
remaining_budget.min(MASK_BUFFER_SIZE)
};
let read_len = mask_copy_read_len(total, byte_cap);
if read_len == 0 {
break;
}
if buf.len() < read_len {
buf.resize(read_len, 0);
}
let n = match timeout(idle_timeout, reader.read(&mut buf[..read_len])).await {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
@@ -1197,7 +1318,7 @@ async fn consume_client_data<R: AsyncRead + Unpin>(
}
total = total.saturating_add(n);
if !unlimited && total >= byte_cap {
if byte_cap != 0 && total >= byte_cap {
break;
}
}
@@ -1315,6 +1436,10 @@ mod masking_interface_cache_concurrency_security_tests;
#[path = "tests/masking_production_cap_regression_security_tests.rs"]
mod masking_production_cap_regression_security_tests;
#[cfg(test)]
#[path = "tests/masking_relay_manual_perf_tests.rs"]
mod masking_relay_manual_perf_tests;
#[cfg(test)]
#[path = "tests/masking_extended_attack_surface_security_tests.rs"]
mod masking_extended_attack_surface_security_tests;
+2 -1
View File
@@ -52,7 +52,8 @@ use self::c2me::{
};
use self::d2c::{
MeD2cFlushPolicy, MeWriterResponseOutcome, classify_me_d2c_flush_reason,
flush_client_or_cancel, observe_me_d2c_flush_event,
flush_client_or_cancel, me_d2c_flush_reason_requires_client_flush,
observe_me_d2c_flush_event,
process_me_writer_response_with_traffic_lease,
};
use self::desync::{RelayForensicsState, hash_ip_in, report_desync_frame_too_large_in};
+39 -11
View File
@@ -55,6 +55,37 @@ pub(super) fn classify_me_d2c_flush_reason(
MeD2cFlushReason::QueueDrain
}
pub(super) fn me_d2c_flush_reason_requires_client_flush(reason: MeD2cFlushReason) -> bool {
!matches!(reason, MeD2cFlushReason::QueueDrain)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn queue_drain_is_not_a_physical_flush_trigger() {
assert!(!me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::QueueDrain
));
assert!(me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::AckImmediate
));
assert!(me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::BatchFrames
));
assert!(me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::BatchBytes
));
assert!(me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::MaxDelay
));
assert!(me_d2c_flush_reason_requires_client_flush(
MeD2cFlushReason::Close
));
}
}
pub(super) fn observe_me_d2c_flush_event(
stats: &Stats,
reason: MeD2cFlushReason,
@@ -276,20 +307,17 @@ pub(in crate::proxy::middle_relay) fn compute_intermediate_secure_wire_len(
let wire_len = data_len
.checked_add(padding_len)
.ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?;
if wire_len > 0x7fff_ffffusize {
return Err(ProxyError::Proxy(format!(
"Intermediate/Secure frame too large: {wire_len}"
)));
}
let len_val =
crate::protocol::framing::encode_intermediate_header(wire_len, quickack).ok_or_else(
|| {
ProxyError::Proxy(format!(
"Intermediate/Secure frame too large: {wire_len}"
))
},
)?;
let total = 4usize
.checked_add(wire_len)
.ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?;
let mut len_val = u32::try_from(wire_len)
.map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?;
if quickack {
len_val |= 0x8000_0000;
}
Ok((len_val, total))
}
+5 -4
View File
@@ -236,10 +236,10 @@ where
}
Err(e) => return Err(e),
}
let quickack = (len_buf[3] & 0x80) != 0;
let header = crate::protocol::framing::parse_intermediate_header(len_buf);
(
(u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize,
quickack,
header.wire_len,
header.quickack,
Some(len_buf),
)
}
@@ -331,7 +331,8 @@ where
)
.await?;
// Secure Intermediate: strip validated trailing padding bytes.
// Secure Intermediate strips only non-aligned tail padding; full-word
// padding is indistinguishable from payload in VersionD framing.
if proto_tag == ProtoTag::Secure {
payload.truncate(secure_payload_len);
}
+8 -2
View File
@@ -491,12 +491,18 @@ where
d2c_flush_policy.max_bytes,
max_delay_fired,
);
let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() {
let physical_flush =
me_d2c_flush_reason_requires_client_flush(flush_reason);
let flush_started_at = if physical_flush
&& stats_clone.telemetry_policy().me_level.allows_debug()
{
Some(Instant::now())
} else {
None
};
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
if physical_flush {
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
}
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
+12 -1
View File
@@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
use std::time::Instant;
use dashmap::DashMap;
use tokio::sync::mpsc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
use tokio_util::sync::CancellationToken;
use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
@@ -14,6 +14,7 @@ use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateReg
use crate::proxy::traffic_limiter::TrafficLimiter;
const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64;
const MASKING_FALLBACK_MAX_CONCURRENT: usize = 512;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ConntrackCloseReason {
@@ -72,6 +73,7 @@ pub(crate) struct ProxySharedState {
active_user_sessions: DashMap<(String, u64), CancellationToken>,
pub(crate) conntrack_pressure_active: AtomicBool,
pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>,
masking_fallback_permits: Arc<Semaphore>,
}
#[must_use = "registered user sessions must be kept alive until relay completion"]
@@ -131,9 +133,18 @@ impl ProxySharedState {
active_user_sessions: DashMap::new(),
conntrack_pressure_active: AtomicBool::new(false),
conntrack_close_tx: Mutex::new(None),
masking_fallback_permits: Arc::new(Semaphore::new(MASKING_FALLBACK_MAX_CONCURRENT)),
})
}
/// Attempts to reserve one masking fallback slot for a pre-auth connection.
pub(crate) fn try_acquire_masking_fallback_permit(&self) -> Option<OwnedSemaphorePermit> {
self.masking_fallback_permits
.clone()
.try_acquire_owned()
.ok()
}
pub(crate) fn is_user_enabled(&self, user: &str) -> bool {
!self.disabled_users.contains_key(user)
}
@@ -86,17 +86,72 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fi
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const TLS_EXTENSION_PADDING: u16 = 0x0015;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let base_tls_len = 4
+ 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
assert!(
tls_len == base_tls_len || tls_len >= base_tls_len + 4,
"TLS length must leave room for a complete padding extension"
);
if tls_len > base_tls_len {
let padding_len = tls_len - base_tls_len - 4;
extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes());
extensions.extend_from_slice(&(padding_len as u16).to_be_bytes());
extensions.resize(extensions.len() + padding_len, fill);
}
let body_len = tls_len - 4;
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + tls_len);
handshake.push(0x16);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&(tls_len as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
@@ -183,10 +238,11 @@ async fn run_tls_success_mtproto_fail_capture(
assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side.write_all(&bad_mtproto_record).await.unwrap();
let mut client_payload = bad_mtproto_record;
for record in trailing_records {
client_side.write_all(&record).await.unwrap();
client_payload.extend_from_slice(&record);
}
client_side.write_all(&client_payload).await.unwrap();
let got = tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
@@ -435,11 +491,9 @@ async fn blackhat_campaign_06_replayed_tls_hello_is_masked_without_serverhello()
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&first_tail).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&first_tail);
client_side.write_all(&client_payload).await.unwrap();
} else {
let mut one = [0u8; 1];
let no_server_hello = tokio::time::timeout(
@@ -741,8 +795,9 @@ async fn blackhat_campaign_12_parallel_tls_success_mtproto_fail_sessions_keep_is
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
read_and_discard_tls_record_body(&mut client_side, head).await;
client_side.write_all(&bad).await.unwrap();
client_side.write_all(&tail).await.unwrap();
let mut client_payload = bad;
client_payload.extend_from_slice(&tail);
client_side.write_all(&client_payload).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(5), handler)
@@ -65,17 +65,72 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fi
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const TLS_EXTENSION_PADDING: u16 = 0x0015;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let base_tls_len = 4
+ 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
assert!(
tls_len == base_tls_len || tls_len >= base_tls_len + 4,
"TLS length must leave room for a complete padding extension"
);
if tls_len > base_tls_len {
let padding_len = tls_len - base_tls_len - 4;
extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes());
extensions.extend_from_slice(&(padding_len as u16).to_be_bytes());
extensions.resize(extensions.len() + padding_len, fill);
}
let body_len = tls_len - 4;
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + tls_len);
handshake.push(0x16);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&(tls_len as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
@@ -240,11 +295,9 @@ async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend(
assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -80,17 +80,72 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fi
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const TLS_EXTENSION_PADDING: u16 = 0x0015;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let base_tls_len = 4
+ 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
assert!(
tls_len == base_tls_len || tls_len >= base_tls_len + 4,
"TLS length must leave room for a complete padding extension"
);
if tls_len > base_tls_len {
let padding_len = tls_len - base_tls_len - 4;
extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes());
extensions.extend_from_slice(&(padding_len as u16).to_be_bytes());
extensions.resize(extensions.len() + padding_len, fill);
}
let body_len = tls_len - 4;
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + tls_len);
handshake.push(0x16);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&(tls_len as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
@@ -173,13 +228,11 @@ async fn run_tls_success_mtproto_fail_capture(
assert_eq!(tls_response_head[0], 0x16);
read_tls_record_body(&mut client_side, tls_response_head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let mut client_payload = invalid_mtproto_record;
for record in trailing_records {
client_side.write_all(&record).await.unwrap();
client_payload.extend_from_slice(&record);
}
client_side.write_all(&client_payload).await.unwrap();
let got = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -344,11 +397,9 @@ async fn replayed_tls_hello_gets_no_serverhello_and_is_masked() {
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
read_tls_record_body(&mut client_side, head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&first_tail).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&first_tail);
client_side.write_all(&client_payload).await.unwrap();
} else {
let mut one = [0u8; 1];
let no_server_hello = tokio::time::timeout(
@@ -419,11 +470,9 @@ async fn connects_bad_increments_once_per_invalid_mtproto() {
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
read_tls_record_body(&mut client_side, head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&tail).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&tail);
client_side.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -676,8 +725,9 @@ async fn concurrent_tls_mtproto_fail_sessions_are_isolated() {
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
read_tls_record_body(&mut client_side, head).await;
client_side.write_all(&invalid_mtproto).await.unwrap();
client_side.write_all(&trailing).await.unwrap();
let mut client_payload = invalid_mtproto;
client_payload.extend_from_slice(&trailing);
client_side.write_all(&client_payload).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
@@ -71,17 +71,77 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness {
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const TLS_EXTENSION_PADDING: u16 = 0x0015;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let base_tls_len = 4
+ 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
assert!(
tls_len == base_tls_len || tls_len >= base_tls_len + 4,
"TLS length must leave room for a complete padding extension"
);
if tls_len > base_tls_len {
let padding_len = tls_len - base_tls_len - 4;
extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes());
extensions.extend_from_slice(&(padding_len as u16).to_be_bytes());
extensions.resize(extensions.len() + padding_len, fill);
}
let body_len = tls_len - 4;
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + tls_len);
handshake.push(0x16);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&(tls_len as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
@@ -250,11 +310,9 @@ async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clea
assert_eq!(head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
client_side.shutdown().await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
@@ -77,17 +77,73 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fi
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const TLS_EXTENSION_PADDING: u16 = 0x0015;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let base_tls_len = 4
+ 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
assert!(
tls_len == base_tls_len || tls_len >= base_tls_len + 4,
"TLS length must leave room for a complete padding extension"
);
if tls_len > base_tls_len {
let padding_len = tls_len - base_tls_len - 4;
extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes());
extensions.extend_from_slice(&(padding_len as u16).to_be_bytes());
extensions.resize(extensions.len() + padding_len, fill);
}
let body_len = tls_len - 4;
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + tls_len);
handshake.push(0x16);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&(tls_len as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
@@ -156,14 +212,9 @@ async fn run_tls_success_mtproto_fail_session(
let mut body = vec![0u8; body_len];
client_side.read_exact(&mut body).await.unwrap();
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side
.write_all(&wrap_tls_application_data(&tail))
.await
.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&wrap_tls_application_data(&tail));
client_side.write_all(&client_payload).await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -34,17 +34,77 @@ fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const TLS_EXTENSION_PADDING: u16 = 0x0015;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let base_tls_len = 4
+ 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
assert!(
tls_len == base_tls_len || tls_len >= base_tls_len + 4,
"TLS length must leave room for a complete padding extension"
);
if tls_len > base_tls_len {
let padding_len = tls_len - base_tls_len - 4;
extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes());
extensions.extend_from_slice(&(padding_len as u16).to_be_bytes());
extensions.resize(extensions.len() + padding_len, fill);
}
let body_len = tls_len - 4;
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + tls_len);
handshake.push(0x16);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&(tls_len as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
@@ -119,14 +179,9 @@ async fn run_replay_candidate_session(
invalid_mtproto_record.extend_from_slice(&TLS_VERSION);
invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes());
invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side
.write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n")
.await
.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n");
client_side.write_all(&client_payload).await.unwrap();
}
client_side.shutdown().await.unwrap();
@@ -80,17 +80,72 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fi
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const TLS_EXTENSION_PADDING: u16 = 0x0015;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let base_tls_len = 4
+ 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
assert!(
tls_len == base_tls_len || tls_len >= base_tls_len + 4,
"TLS length must leave room for a complete padding extension"
);
if tls_len > base_tls_len {
let padding_len = tls_len - base_tls_len - 4;
extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes());
extensions.extend_from_slice(&(padding_len as u16).to_be_bytes());
extensions.resize(extensions.len() + padding_len, fill);
}
let body_len = tls_len - 4;
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + tls_len);
handshake.push(0x16);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&(tls_len as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
@@ -205,8 +260,13 @@ async fn run_parallel_tail_fallback_case(
assert_eq!(server_hello_head[0], 0x16);
read_tls_record_body(&mut client_side, server_hello_head).await;
client_side.write_all(&invalid_mtproto).await.unwrap();
for chunk in trailing.chunks(write_chunk.max(1)) {
let mut chunks = trailing.chunks(write_chunk.max(1));
let mut client_payload = invalid_mtproto;
if let Some(first_chunk) = chunks.next() {
client_payload.extend_from_slice(first_chunk);
}
client_side.write_all(&client_payload).await.unwrap();
for chunk in chunks {
client_side.write_all(chunk).await.unwrap();
}
client_side.shutdown().await.unwrap();
+88 -14
View File
@@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::{AesCtr, sha256, sha256_hmac};
use crate::protocol::constants::{
DC_IDX_POS, HANDSHAKE_LEN, IV_LEN, PREKEY_LEN, PROTO_TAG_POS, ProtoTag, SKIP_LEN,
TLS_RECORD_CHANGE_CIPHER,
TLS_RECORD_CHANGE_CIPHER, TLS_VERSION,
};
use crate::protocol::tls;
use crate::proxy::handshake::HandshakeSuccess;
@@ -1630,17 +1630,73 @@ fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len:
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![0x42u8; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const TLS_EXTENSION_PADDING: u16 = 0x0015;
const X25519_KEY_SHARE_LEN: usize = 32;
let fill = 0x42u8;
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let base_tls_len = 4
+ 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
assert!(
tls_len == base_tls_len || tls_len >= base_tls_len + 4,
"TLS length must leave room for a complete padding extension"
);
if tls_len > base_tls_len {
let padding_len = tls_len - base_tls_len - 4;
extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes());
extensions.extend_from_slice(&(padding_len as u16).to_be_bytes());
extensions.resize(extensions.len() + padding_len, fill);
}
let body_len = tls_len - 4;
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + tls_len);
handshake.push(0x16);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&(tls_len as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
@@ -1663,6 +1719,9 @@ fn make_valid_tls_client_hello_with_alpn(
timestamp: u32,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
@@ -1674,6 +1733,19 @@ fn make_valid_tls_client_hello_with_alpn(
body.push(0);
let mut ext_blob = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
ext_blob.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
ext_blob.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&key_share_extension);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
@@ -2062,8 +2134,9 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() {
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&tls_app_record).await.unwrap();
client_side.write_all(&trailing_tls_record).await.unwrap();
let mut client_payload = tls_app_record;
client_payload.extend_from_slice(&trailing_tls_record);
client_side.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -2188,8 +2261,9 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() {
client.read_exact(&mut tls_response_head).await.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client.write_all(&tls_app_record).await.unwrap();
client.write_all(&trailing_tls_record).await.unwrap();
let mut client_payload = tls_app_record;
client_payload.extend_from_slice(&trailing_tls_record);
client.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), mask_accept_task)
.await
@@ -79,17 +79,72 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fi
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const TLS_EXTENSION_PADDING: u16 = 0x0015;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let base_tls_len = 4
+ 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
assert!(
tls_len == base_tls_len || tls_len >= base_tls_len + 4,
"TLS length must leave room for a complete padding extension"
);
if tls_len > base_tls_len {
let padding_len = tls_len - base_tls_len - 4;
extensions.extend_from_slice(&TLS_EXTENSION_PADDING.to_be_bytes());
extensions.extend_from_slice(&(padding_len as u16).to_be_bytes());
extensions.resize(extensions.len() + padding_len, fill);
}
let body_len = tls_len - 4;
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + tls_len);
handshake.push(0x16);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&(tls_len as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
@@ -191,11 +246,9 @@ async fn tls_bad_mtproto_fallback_preserves_wire_and_backend_response() {
assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -261,11 +314,9 @@ async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() {
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -335,11 +386,9 @@ async fn tls_bad_mtproto_fallback_forwards_zero_length_tls_record_verbatim() {
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -403,11 +452,9 @@ async fn tls_bad_mtproto_fallback_forwards_max_tls_record_verbatim() {
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -481,11 +528,9 @@ async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() {
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -586,11 +631,9 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() {
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
@@ -660,12 +703,14 @@ async fn tls_bad_mtproto_fallback_forwards_fragmented_client_writes_verbatim() {
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let mut chunks = trailing_record.chunks(3);
let mut client_payload = invalid_mtproto_record;
if let Some(first_chunk) = chunks.next() {
client_payload.extend_from_slice(first_chunk);
}
client_side.write_all(&client_payload).await.unwrap();
for chunk in trailing_record.chunks(3) {
for chunk in chunks {
client_side.write_all(chunk).await.unwrap();
}
@@ -729,11 +774,13 @@ async fn tls_bad_mtproto_fallback_header_fragmentation_bytewise_is_verbatim() {
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for b in trailing_record.iter().copied() {
let mut bytes = trailing_record.iter().copied();
let mut client_payload = invalid_mtproto_record;
if let Some(first_byte) = bytes.next() {
client_payload.push(first_byte);
}
client_side.write_all(&client_payload).await.unwrap();
for b in bytes {
client_side.write_all(&[b]).await.unwrap();
}
@@ -802,14 +849,16 @@ async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() {
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17];
let mut pos = 0usize;
let mut idx = 0usize;
let mut client_payload = invalid_mtproto_record;
let first_step = chaos[idx % chaos.len()];
let first_end = first_step.min(trailing_record.len());
client_payload.extend_from_slice(&trailing_record[..first_end]);
client_side.write_all(&client_payload).await.unwrap();
pos = first_end;
idx += 1;
while pos < trailing_record.len() {
let step = chaos[idx % chaos.len()];
let end = (pos + step).min(trailing_record.len());
@@ -884,11 +933,9 @@ async fn tls_bad_mtproto_fallback_multiple_tls_records_are_forwarded_in_order()
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&r1).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&r1);
client_side.write_all(&client_payload).await.unwrap();
client_side.write_all(&r2).await.unwrap();
client_side.write_all(&r3).await.unwrap();
@@ -958,11 +1005,9 @@ async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend()
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
client_side.shutdown().await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
@@ -1029,11 +1074,9 @@ async fn tls_bad_mtproto_fallback_backend_half_close_after_response_is_tolerated
assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
@@ -1090,11 +1133,9 @@ async fn tls_bad_mtproto_fallback_backend_reset_after_clienthello_is_handled() {
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let write_res = client_side.write_all(&trailing_record).await;
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
let write_res = client_side.write_all(&client_payload).await;
assert!(
write_res.is_ok() || write_res.is_err(),
"write completion is environment dependent under backend reset"
@@ -1170,11 +1211,9 @@ async fn tls_bad_mtproto_fallback_backend_slow_reader_preserves_byte_identity()
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
tokio::time::timeout(Duration::from_secs(5), accept_task)
.await
@@ -1254,11 +1293,9 @@ async fn tls_bad_mtproto_fallback_replay_pressure_masks_replay_without_serverhel
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&trailing_record);
client_side.write_all(&client_payload).await.unwrap();
} else {
let mut one = [0u8; 1];
let no_server_hello = tokio::time::timeout(
@@ -1352,13 +1389,29 @@ async fn tls_bad_mtproto_fallback_large_multi_record_chaos_under_backpressure()
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31];
for record in [&a, &b, &c] {
let records = [&a, &b, &c];
let mut records_iter = records.iter().copied();
let mut client_payload = invalid_mtproto_record;
if let Some(first_record) = records_iter.next() {
let first_step = chaos[0].min(first_record.len());
client_payload.extend_from_slice(&first_record[..first_step]);
client_side.write_all(&client_payload).await.unwrap();
let mut pos = first_step;
let mut idx = 1usize;
while pos < first_record.len() {
let step = chaos[idx % chaos.len()];
let end = (pos + step).min(first_record.len());
client_side
.write_all(&first_record[pos..end])
.await
.unwrap();
pos = end;
idx += 1;
}
}
for record in records_iter {
let mut pos = 0usize;
let mut idx = 0usize;
while pos < record.len() {
@@ -1433,11 +1486,9 @@ async fn tls_bad_mtproto_fallback_interleaved_control_and_application_records_ve
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&ccs).await.unwrap();
let mut client_payload = invalid_mtproto_record;
client_payload.extend_from_slice(&ccs);
client_side.write_all(&client_payload).await.unwrap();
client_side.write_all(&app).await.unwrap();
client_side.write_all(&alert).await.unwrap();
@@ -1533,11 +1584,13 @@ async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak()
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for chunk in record.chunks((idx % 9) + 1) {
let mut chunks = record.chunks((idx % 9) + 1);
let mut client_payload = invalid_mtproto_record;
if let Some(first_chunk) = chunks.next() {
client_payload.extend_from_slice(first_chunk);
}
client_side.write_all(&client_payload).await.unwrap();
for chunk in chunks {
client_side.write_all(chunk).await.unwrap();
}
@@ -21,11 +21,59 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
let fill = 0x42u8;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let body_len = 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + 4 + body_len);
handshake.push(TLS_RECORD_HANDSHAKE);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&((4 + body_len) as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
@@ -85,6 +133,9 @@ fn make_valid_tls_client_hello_with_alpn(
timestamp: u32,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
@@ -96,6 +147,19 @@ fn make_valid_tls_client_hello_with_alpn(
body.push(0);
let mut ext_blob = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
ext_blob.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
ext_blob.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&key_share_extension);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
@@ -150,13 +214,7 @@ async fn tls_minimum_viable_length_boundary() {
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1;
let mut exact_min_handshake = vec![0x42u8; min_len];
exact_min_handshake[min_len - 1] = 0;
exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let digest = sha256_hmac(&secret, &exact_min_handshake);
exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
let exact_min_handshake = make_valid_tls_handshake(&secret, 0);
let res = handle_tls_handshake(
&exact_min_handshake,
@@ -171,12 +229,12 @@ async fn tls_minimum_viable_length_boundary() {
.await;
assert!(
matches!(res, HandshakeResult::Success(_)),
"Exact minimum length TLS handshake must succeed"
"Minimum valid TLS ClientHello must succeed"
);
let short_handshake = vec![0x42u8; min_len - 1];
let short_handshake = &exact_min_handshake[..exact_min_handshake.len() - 1];
let res_short = handle_tls_handshake(
&short_handshake,
short_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
@@ -188,7 +246,7 @@ async fn tls_minimum_viable_length_boundary() {
.await;
assert!(
matches!(res_short, HandshakeResult::BadClient { .. }),
"Handshake 1 byte shorter than minimum must fail closed"
"Handshake 1 byte shorter than minimum valid ClientHello must fail closed"
);
}
@@ -1,5 +1,6 @@
use super::*;
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{TLS_RECORD_HANDSHAKE, TLS_VERSION};
use crate::stats::ReplayChecker;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::{Duration, Instant};
@@ -17,11 +18,59 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
let fill = 0x42u8;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let body_len = 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + 4 + body_len);
handshake.push(TLS_RECORD_HANDSHAKE);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&((4 + body_len) as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
+67 -3
View File
@@ -25,11 +25,59 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
let fill = 0x42u8;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let body_len = 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + 4 + body_len);
handshake.push(TLS_RECORD_HANDSHAKE);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&((4 + body_len) as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
@@ -90,6 +138,9 @@ fn make_valid_tls_client_hello_with_sni_and_alpn(
sni_host: &str,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
@@ -112,6 +163,19 @@ fn make_valid_tls_client_hello_with_sni_and_alpn(
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
ext_blob.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
ext_blob.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&key_share_extension);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
@@ -24,6 +24,9 @@ fn make_valid_tls_client_hello_with_alpn(
timestamp: u32,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
@@ -35,6 +38,19 @@ fn make_valid_tls_client_hello_with_alpn(
body.push(0);
let mut ext_blob = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
ext_blob.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
ext_blob.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&key_share_extension);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
+94 -25
View File
@@ -10,11 +10,62 @@ use std::time::{Duration, Instant};
use tokio::sync::Barrier;
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
make_valid_tls_handshake_with_fill(secret, timestamp, 0x42)
}
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
fn make_valid_tls_handshake_with_fill(secret: &[u8], timestamp: u32, fill: u8) -> Vec<u8> {
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let body_len = 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + 4 + body_len);
handshake.push(TLS_RECORD_HANDSHAKE);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&((4 + body_len) as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
@@ -34,6 +85,9 @@ fn make_valid_tls_client_hello_with_alpn(
timestamp: u32,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
@@ -45,6 +99,19 @@ fn make_valid_tls_client_hello_with_alpn(
body.push(0);
let mut ext_blob = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
ext_blob.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
ext_blob.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&key_share_extension);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
@@ -92,6 +159,9 @@ fn make_valid_tls_client_hello_with_sni_and_alpn(
sni_host: &str,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
@@ -114,6 +184,19 @@ fn make_valid_tls_client_hello_with_sni_and_alpn(
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
ext_blob.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
ext_blob.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&key_share_extension);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
@@ -549,25 +632,6 @@ async fn adversarial_tls_replay_churn_allows_only_unique_digests() {
let replay_checker = Arc::new(ReplayChecker::new(8192, Duration::from_secs(60)));
let rng = Arc::new(SecureRandom::new());
let make_tagged_handshake = |timestamp: u32, tag: u8| {
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![tag; len];
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(&secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
};
let mut tasks = Vec::new();
// 128 exact duplicates: only one should pass.
@@ -596,12 +660,17 @@ async fn adversarial_tls_replay_churn_allows_only_unique_digests() {
}));
}
// 128 unique timestamps: all should pass because HMAC digest differs.
// 128 unique ClientHello bodies: all should pass because replay tracks the
// first digest half, while timestamp skew is encoded in the last bytes.
for i in 0..128u16 {
let config = Arc::clone(&config);
let replay_checker = Arc::clone(&replay_checker);
let rng = Arc::clone(&rng);
let handshake = make_tagged_handshake(10_000 + i as u32, (i as u8).wrapping_add(0x80));
let handshake = make_valid_tls_handshake_with_fill(
&secret,
10_000 + i as u32,
(i as u8).wrapping_add(0x80),
);
tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(198, 18, 0, ((i % 250) + 1) as u8)),
@@ -47,11 +47,59 @@ fn make_valid_mtproto_handshake(
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
const TLS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01];
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
let fill = 0x42u8;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
let mut extensions = Vec::new();
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
extensions.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
extensions.extend_from_slice(&key_share_extension);
let body_len = 2
+ 32
+ 1
+ session_id_len
+ 2
+ TLS_AES_128_GCM_SHA256.len()
+ 1
+ 1
+ 2
+ extensions.len();
let mut body = Vec::with_capacity(body_len);
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[fill; 32]);
body.push(session_id_len as u8);
body.extend_from_slice(&[fill; 32]);
body.extend_from_slice(&(TLS_AES_128_GCM_SHA256.len() as u16).to_be_bytes());
body.extend_from_slice(&TLS_AES_128_GCM_SHA256);
body.push(1);
body.push(0);
body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
body.extend_from_slice(&extensions);
assert_eq!(body.len(), body_len);
let mut handshake = Vec::with_capacity(5 + 4 + body_len);
handshake.push(TLS_RECORD_HANDSHAKE);
handshake.extend_from_slice(&[0x03, 0x01]);
handshake.extend_from_slice(&((4 + body_len) as u16).to_be_bytes());
handshake.push(0x01);
let body_len_bytes = (body_len as u32).to_be_bytes();
handshake.extend_from_slice(&body_len_bytes[1..4]);
handshake.extend_from_slice(&body);
// The proxy authenticates TLS-fronted clients through the random field.
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
@@ -72,6 +120,9 @@ fn make_valid_tls_client_hello_with_sni_and_alpn(
sni_host: &str,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
const TLS_EXTENSION_KEY_SHARE: u16 = 0x0033;
const X25519_KEY_SHARE_LEN: usize = 32;
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
@@ -93,6 +144,19 @@ fn make_valid_tls_client_hello_with_sni_and_alpn(
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
let mut key_share = Vec::new();
key_share.extend_from_slice(&tls::TLS_NAMED_GROUP_X25519.to_be_bytes());
key_share.extend_from_slice(&(X25519_KEY_SHARE_LEN as u16).to_be_bytes());
key_share.push(9);
key_share.resize(key_share.len() + X25519_KEY_SHARE_LEN - 1, 0);
let mut key_share_extension = Vec::new();
key_share_extension.extend_from_slice(&(key_share.len() as u16).to_be_bytes());
key_share_extension.extend_from_slice(&key_share);
ext_blob.extend_from_slice(&TLS_EXTENSION_KEY_SHARE.to_be_bytes());
ext_blob.extend_from_slice(&(key_share_extension.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&key_share_extension);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
@@ -34,7 +34,7 @@ fn loop_guard_unspecified_bind_uses_interface_inventory() {
"mask.example",
443,
local,
Some(resolved),
&[resolved],
&interfaces,
));
}
@@ -25,7 +25,7 @@ async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() {
let barrier = std::sync::Arc::clone(&barrier);
tasks.push(tokio::spawn(async move {
barrier.wait().await;
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await
}));
}
@@ -17,8 +17,8 @@ async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await;
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await;
assert_eq!(
local_interface_enumerations_for_tests(),
@@ -35,7 +35,7 @@ async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await;
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, &[]).await;
assert!(
!is_local,
@@ -0,0 +1,111 @@
use super::*;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::time::{Duration, Instant};
const PERF_TOTAL_BYTES: usize = 64 * 1024 * 1024;
struct PatternReader {
remaining: usize,
chunk: usize,
read_calls: AtomicUsize,
}
impl PatternReader {
fn new(total: usize, chunk: usize) -> Self {
Self {
remaining: total,
chunk,
read_calls: AtomicUsize::new(0),
}
}
fn read_calls(&self) -> usize {
self.read_calls.load(Ordering::Relaxed)
}
}
impl AsyncRead for PatternReader {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.read_calls.fetch_add(1, Ordering::Relaxed);
if self.remaining == 0 {
return Poll::Ready(Ok(()));
}
let take = self.remaining.min(self.chunk).min(buf.remaining());
if take == 0 {
return Poll::Ready(Ok(()));
}
static PATTERN: [u8; MASK_BUFFER_MAX_SIZE] = [0xA5; MASK_BUFFER_MAX_SIZE];
buf.put_slice(&PATTERN[..take]);
self.remaining -= take;
Poll::Ready(Ok(()))
}
}
#[derive(Default)]
struct CountingWriter {
written: usize,
}
impl AsyncWrite for CountingWriter {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.written = self.written.saturating_add(buf.len());
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[tokio::test]
#[ignore = "manual benchmark: throughput-sensitive and host-dependent"]
async fn masking_copy_with_idle_timeout_manual_throughput() {
let mut reader = PatternReader::new(PERF_TOTAL_BYTES, MASK_BUFFER_MAX_SIZE);
let mut writer = CountingWriter::default();
let started = Instant::now();
let outcome = copy_with_idle_timeout(
&mut reader,
&mut writer,
PERF_TOTAL_BYTES,
true,
Duration::from_secs(30),
)
.await;
let elapsed = started.elapsed();
let mb = PERF_TOTAL_BYTES as f64 / (1024.0 * 1024.0);
let mbps = mb / elapsed.as_secs_f64();
assert_eq!(outcome.total, PERF_TOTAL_BYTES);
assert_eq!(writer.written, PERF_TOTAL_BYTES);
assert!(
!outcome.ended_by_eof,
"manual throughput run should terminate at byte cap"
);
eprintln!(
"masking manual throughput: bytes={} elapsed_ms={} mib_per_sec={:.2} read_calls={}",
PERF_TOTAL_BYTES,
elapsed.as_millis(),
mbps,
reader.read_calls()
);
}
@@ -15,38 +15,49 @@ fn closed_local_port() -> u16 {
#[tokio::test]
async fn self_target_detection_matches_literal_ipv4_listener() {
let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await);
assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, &[],).await);
}
#[tokio::test]
async fn self_target_detection_matches_bracketed_ipv6_listener() {
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await);
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, &[],).await);
}
#[tokio::test]
async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await);
assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, &[],).await);
}
#[tokio::test]
async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await);
assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, &[],).await);
}
#[tokio::test]
async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await);
assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, &[],).await);
}
#[tokio::test]
async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await);
assert!(!is_mask_target_local_listener_async("mask.example", 443, local, &[remote],).await);
}
#[tokio::test]
async fn self_target_detection_checks_all_resolved_addresses() {
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
let loopback: SocketAddr = "127.0.0.1:443".parse().unwrap();
assert!(
is_mask_target_local_listener_async("mask.example", 443, local, &[remote, loopback],).await
);
}
#[tokio::test]
+40 -30
View File
@@ -15,6 +15,7 @@ use crate::crypto::SecureRandom;
use crate::protocol::constants::{
ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len,
};
use crate::protocol::framing::{encode_intermediate_header, parse_intermediate_header};
// ============= Unified Codec =============
@@ -197,13 +198,9 @@ fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result<Option
}
let mut meta = FrameMeta::new();
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
// Check QuickACK flag
if len >= 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
let header = parse_intermediate_header([src[0], src[1], src[2], src[3]]);
let len = header.wire_len;
meta.quickack = header.quickack;
// Validate size
if len > max_size {
@@ -239,10 +236,12 @@ fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
dst.reserve(4 + data.len());
let mut len = data.len() as u32;
if frame.meta.quickack {
len |= 0x80000000;
}
let len = encode_intermediate_header(data.len(), frame.meta.quickack).ok_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
format!("frame too large: {} bytes", data.len()),
)
})?;
dst.extend_from_slice(&len.to_le_bytes());
dst.extend_from_slice(data);
@@ -258,13 +257,9 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
}
let mut meta = FrameMeta::new();
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
// Check QuickACK flag
if len >= 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
let header = parse_intermediate_header([src[0], src[1], src[2], src[3]]);
let len = header.wire_len;
meta.quickack = header.quickack;
// Validate size
if len > max_size {
@@ -317,16 +312,18 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
));
}
// Generate padding that keeps total length non-divisible by 4.
// Telegram Desktop VersionD uses a 4-bit random padding length.
let padding_len = secure_padding_len(data.len(), rng);
let total_len = data.len() + padding_len;
dst.reserve(4 + total_len);
let mut len = total_len as u32;
if frame.meta.quickack {
len |= 0x80000000;
}
let len = encode_intermediate_header(total_len, frame.meta.quickack).ok_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
format!("frame too large: {} bytes", total_len),
)
})?;
dst.extend_from_slice(&len.to_le_bytes());
dst.extend_from_slice(data);
@@ -523,6 +520,16 @@ mod tests {
use tokio::io::duplex;
use tokio_util::codec::{FramedRead, FramedWrite};
fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) {
assert!(decoded.starts_with(original));
assert!(
(original.len()..=original.len() + 12).contains(&decoded.len()),
"Secure decoded payload may retain up to 12 bytes of full-word padding, got {}",
decoded.len()
);
assert_eq!(decoded.len() % 4, 0);
}
#[tokio::test]
async fn test_framed_abridged() {
let (client, server) = duplex(4096);
@@ -565,7 +572,7 @@ mod tests {
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(&received.data[..], &original[..]);
assert_secure_decoded_payload(&received.data, &original);
}
#[tokio::test]
@@ -588,7 +595,11 @@ mod tests {
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(received.data.len(), 8);
if proto_tag == ProtoTag::Secure {
assert_secure_decoded_payload(&received.data, &original);
} else {
assert_eq!(received.data.len(), original.len());
}
}
}
@@ -642,7 +653,7 @@ mod tests {
}
#[test]
fn secure_codec_always_adds_padding_and_jitters_wire_length() {
fn secure_codec_uses_tdesktop_padding_range_and_jitters_wire_length() {
let codec = SecureCodec::new(Arc::new(SecureRandom::new()));
let payload = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let mut wire_lens = HashSet::new();
@@ -652,13 +663,12 @@ mod tests {
let mut out = BytesMut::new();
codec.encode(&frame, &mut out).unwrap();
assert!(out.len() > 4 + payload.len());
let wire_len = u32::from_le_bytes([out[0], out[1], out[2], out[3]]) as usize;
assert_eq!(out.len(), 4 + wire_len);
assert!(
(payload.len() + 1..=payload.len() + 3).contains(&wire_len),
"Secure wire length must be payload+1..3, got {wire_len}"
(payload.len()..=payload.len() + 15).contains(&wire_len),
"Secure wire length must be payload+0..15, got {wire_len}"
);
assert_ne!(wire_len % 4, 0, "Secure wire length must be non-4-aligned");
wire_lens.insert(wire_len);
}
+212 -40
View File
@@ -5,21 +5,47 @@
use super::traits::{FrameMeta, LayeredStream};
use crate::crypto::{SecureRandom, crc32};
use crate::protocol::constants::*;
use crate::protocol::framing::{encode_intermediate_header, parse_intermediate_header};
use bytes::Bytes;
use std::io::{Error, ErrorKind, Result};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
const DEFAULT_MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
fn reject_oversize_frame(len: usize, max_frame_size: usize, protocol: &str) -> Result<()> {
if len > max_frame_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("{protocol} frame too large: {len} bytes (max {max_frame_size})"),
));
}
Ok(())
}
// ============= Abridged (Compact) Frame =============
/// Reader for abridged MTProto framing
pub struct AbridgedFrameReader<R> {
upstream: R,
max_frame_size: usize,
}
impl<R> AbridgedFrameReader<R> {
/// Creates a reader with the default maximum frame size.
pub fn new(upstream: R) -> Self {
Self { upstream }
Self {
upstream,
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
}
}
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
Self {
upstream,
max_frame_size,
}
}
}
@@ -47,10 +73,12 @@ impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
}
// Length is in 4-byte words
let byte_len = len * 4;
// Length is in 4-byte words.
let byte_len = len
.checked_mul(4)
.ok_or_else(|| Error::new(ErrorKind::InvalidData, "abridged frame length overflow"))?;
reject_oversize_frame(byte_len, self.max_frame_size, "abridged")?;
// Read data
let mut data = vec![0u8; byte_len];
self.upstream.read_exact(&mut data).await?;
@@ -105,10 +133,17 @@ impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
if len_div_4 < 0x7f {
// Short length (1 byte)
self.upstream.write_all(&[len_div_4 as u8]).await?;
let mut first = len_div_4 as u8;
if meta.quickack {
first |= 0x80;
}
self.upstream.write_all(&[first]).await?;
} else if len_div_4 < (1 << 24) {
// Long length (4 bytes: 0x7f + 3 bytes)
let mut header = [0x7f, 0, 0, 0];
if meta.quickack {
header[0] |= 0x80;
}
header[1..4].copy_from_slice(&(len_div_4 as u32).to_le_bytes()[..3]);
self.upstream.write_all(&header).await?;
} else {
@@ -144,11 +179,23 @@ impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
/// Reader for intermediate MTProto framing
pub struct IntermediateFrameReader<R> {
upstream: R,
max_frame_size: usize,
}
impl<R> IntermediateFrameReader<R> {
/// Creates a reader with the default maximum frame size.
pub fn new(upstream: R) -> Self {
Self { upstream }
Self {
upstream,
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
}
}
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
Self {
upstream,
max_frame_size,
}
}
}
@@ -160,15 +207,11 @@ impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let mut len = u32::from_le_bytes(len_bytes) as usize;
let header = parse_intermediate_header(len_bytes);
let len = header.wire_len;
meta.quickack = header.quickack;
reject_oversize_frame(len, self.max_frame_size, "intermediate")?;
// Check QuickACK flag (high bit)
if len > 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Read data
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
@@ -204,7 +247,13 @@ impl<W: AsyncWrite + Unpin> IntermediateFrameWriter<W> {
if meta.simple_ack {
self.upstream.write_all(data).await?;
} else {
let len_bytes = (data.len() as u32).to_le_bytes();
let len = encode_intermediate_header(data.len(), meta.quickack).ok_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
format!("Frame too large: {} bytes", data.len()),
)
})?;
let len_bytes = len.to_le_bytes();
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(data).await?;
}
@@ -233,11 +282,23 @@ impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
/// Reader for secure intermediate MTProto framing (with padding)
pub struct SecureIntermediateFrameReader<R> {
upstream: R,
max_frame_size: usize,
}
impl<R> SecureIntermediateFrameReader<R> {
/// Creates a reader with the default maximum frame size.
pub fn new(upstream: R) -> Self {
Self { upstream }
Self {
upstream,
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
}
}
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
Self {
upstream,
max_frame_size,
}
}
}
@@ -249,24 +310,19 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let mut len = u32::from_le_bytes(len_bytes) as usize;
// Check QuickACK flag
if len > 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Read data (including padding)
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
let header = parse_intermediate_header(len_bytes);
let len = header.wire_len;
meta.quickack = header.quickack;
reject_oversize_frame(len, self.max_frame_size, "secure intermediate")?;
let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
Error::new(
ErrorKind::InvalidData,
format!("Invalid secure frame length: {len}"),
)
})?;
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
data.truncate(payload_len);
Ok((Bytes::from(data), meta))
@@ -311,12 +367,20 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
));
}
// Add padding so total length is never divisible by 4 (MTProto Secure)
// Telegram Desktop VersionD uses a 4-bit random padding length.
let padding_len = secure_padding_len(data.len(), &self.rng);
let padding = self.rng.bytes(padding_len);
let total_len = data.len() + padding_len;
let len_bytes = (total_len as u32).to_le_bytes();
let total_len = data.len().checked_add(padding_len).ok_or_else(|| {
Error::new(ErrorKind::InvalidInput, "secure frame length overflow")
})?;
let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
format!("Frame too large: {total_len} bytes"),
)
})?;
let len_bytes = len.to_le_bytes();
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(data).await?;
@@ -495,15 +559,26 @@ pub enum FrameReaderKind<R> {
}
impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
/// Creates a frame reader with the default maximum frame size.
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
Self::with_max_frame_size(upstream, proto_tag, DEFAULT_MAX_FRAME_SIZE)
}
fn with_max_frame_size(
upstream: R,
proto_tag: ProtoTag,
max_frame_size: usize,
) -> Self {
match proto_tag {
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
ProtoTag::Intermediate => {
FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream))
}
ProtoTag::Secure => {
FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream))
}
ProtoTag::Abridged => FrameReaderKind::Abridged(
AbridgedFrameReader::with_max_frame_size(upstream, max_frame_size),
),
ProtoTag::Intermediate => FrameReaderKind::Intermediate(
IntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
),
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(
SecureIntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
),
}
}
@@ -557,7 +632,18 @@ mod tests {
use super::*;
use crate::crypto::SecureRandom;
use std::sync::Arc;
use tokio::io::duplex;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, timeout};
fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) {
assert!(decoded.starts_with(original));
assert!(
(original.len()..=original.len() + 12).contains(&decoded.len()),
"Secure decoded payload may retain up to 12 bytes of full-word padding, got {}",
decoded.len()
);
assert_eq!(decoded.len() % 4, 0);
}
#[tokio::test]
async fn test_abridged_roundtrip() {
@@ -613,6 +699,92 @@ mod tests {
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_intermediate_quickack_zero_length_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = IntermediateFrameWriter::new(client);
let mut reader = IntermediateFrameReader::new(server);
writer
.write_frame(&[], &FrameMeta::new().with_quickack())
.await
.unwrap();
writer.flush().await.unwrap();
let (received, meta) = reader.read_frame().await.unwrap();
assert!(received.is_empty());
assert!(meta.quickack);
}
#[tokio::test]
async fn test_abridged_quickack_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = AbridgedFrameWriter::new(client);
let mut reader = AbridgedFrameReader::new(server);
let data = vec![1u8, 2, 3, 4];
writer
.write_frame(&data, &FrameMeta::new().with_quickack())
.await
.unwrap();
writer.flush().await.unwrap();
let (received, meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
assert!(meta.quickack);
}
#[tokio::test]
async fn abridged_reader_rejects_oversize_frame_before_body_read() {
let (mut client, server) = duplex(1024);
let mut reader = AbridgedFrameReader::new(server);
let len_words = (DEFAULT_MAX_FRAME_SIZE / 4) + 1;
let encoded = (len_words as u32).to_le_bytes();
client
.write_all(&[0x7f, encoded[0], encoded[1], encoded[2]])
.await
.unwrap();
let err = timeout(Duration::from_millis(50), reader.read_frame())
.await
.unwrap()
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::InvalidData);
}
#[tokio::test]
async fn intermediate_reader_rejects_oversize_frame_before_body_read() {
let (mut client, server) = duplex(1024);
let mut reader = IntermediateFrameReader::new(server);
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 1, false).unwrap();
client.write_all(&len.to_le_bytes()).await.unwrap();
let err = timeout(Duration::from_millis(50), reader.read_frame())
.await
.unwrap()
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::InvalidData);
}
#[tokio::test]
async fn secure_reader_rejects_oversize_frame_before_body_read() {
let (mut client, server) = duplex(1024);
let mut reader = SecureIntermediateFrameReader::new(server);
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 4, false).unwrap();
client.write_all(&len.to_le_bytes()).await.unwrap();
let err = timeout(Duration::from_millis(50), reader.read_frame())
.await
.unwrap()
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::InvalidData);
}
#[tokio::test]
async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024);
@@ -625,7 +797,7 @@ mod tests {
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(received.len(), data.len());
assert_secure_decoded_payload(&received, &data);
}
#[tokio::test]
+601
View File
@@ -0,0 +1,601 @@
use std::collections::BTreeSet;
use std::net::IpAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::sync::watch;
use tracing::warn;
use crate::config::{ProxyConfig, SynLimitMode};
const IPTABLES_CHAIN: &str = "TELEMT_SYNLIMIT";
const IPTABLES_HASHLIMIT_NAME: &str = "TELEMT-BUMPER";
const NFT_TABLE: &str = "telemt_synlimit";
const NFT_CHAIN: &str = "input";
type SynLimitTarget = (Option<IpAddr>, u16, u32, u32, u32);
#[derive(Default)]
struct SynLimitTargets {
iptables_v4: Vec<SynLimitTarget>,
iptables_v6: Vec<SynLimitTarget>,
nft_v4: Vec<SynLimitTarget>,
nft_v6: Vec<SynLimitTarget>,
}
#[derive(Clone, Copy)]
struct NftTableFamilies {
inet: bool,
ip: bool,
ip6: bool,
}
#[derive(Clone, Copy)]
enum NftFamily {
Inet,
Ip,
Ip6,
}
struct NftApplyPlan<'a> {
family: NftFamily,
v4_targets: &'a [SynLimitTarget],
v6_targets: &'a [SynLimitTarget],
}
impl SynLimitTargets {
fn is_empty(&self) -> bool {
self.iptables_v4.is_empty()
&& self.iptables_v6.is_empty()
&& self.nft_v4.is_empty()
&& self.nft_v6.is_empty()
}
fn has_iptables_targets(&self) -> bool {
!self.iptables_v4.is_empty() || !self.iptables_v6.is_empty()
}
fn has_nft_targets(&self) -> bool {
!self.nft_v4.is_empty() || !self.nft_v6.is_empty()
}
}
impl NftFamily {
fn as_str(self) -> &'static str {
match self {
Self::Inet => "inet",
Self::Ip => "ip",
Self::Ip6 => "ip6",
}
}
}
pub(crate) fn spawn_synlimit_controller(config_rx: watch::Receiver<Arc<ProxyConfig>>) {
if !cfg!(target_os = "linux") {
if has_synlimit_config(&config_rx.borrow()) {
warn!("SYN limiter is configured but unsupported on this OS; skipping netfilter rules");
}
return;
}
tokio::spawn(async move {
wait_for_config_channel_close_and_reconcile(config_rx).await;
if let Err(error) = clear_synlimit_rules_all_backends().await {
warn!(error = %error, "Failed to clear SYN limiter rules after config channel close");
}
});
}
async fn wait_for_config_channel_close_and_reconcile(
mut config_rx: watch::Receiver<Arc<ProxyConfig>>,
) {
while config_rx.changed().await.is_ok() {
let cfg = config_rx.borrow_and_update().clone();
reconcile_synlimit_rules(&cfg).await;
}
}
pub(crate) async fn reconcile_synlimit_rules(cfg: &ProxyConfig) {
if let Err(error) = clear_synlimit_rules_all_backends().await {
warn!(error = %error, "Failed to clear existing SYN limiter rules before reconcile");
}
let targets = synlimit_targets(cfg);
if targets.is_empty() {
return;
}
if !has_cap_net_admin() {
warn!(
"SYN limiter configured but CAP_NET_ADMIN is not available; netfilter rules not applied"
);
return;
}
if targets.has_iptables_targets()
&& let Err(error) = apply_iptables_synlimit_rules(&targets).await
{
warn!(error = %error, "Failed to apply iptables SYN limiter rules");
}
if targets.has_nft_targets()
&& let Err(error) = apply_nft_synlimit_rules(&targets).await
{
warn!(error = %error, "Failed to apply nftables SYN limiter rules");
}
}
pub(crate) async fn clear_synlimit_rules_all_backends() -> Result<(), String> {
let mut errors = Vec::new();
if let Err(error) = clear_nft_synlimit_rules_all_families().await {
errors.push(error);
}
if let Err(error) = clear_iptables_synlimit_rules_for_binary("iptables").await {
errors.push(error);
}
if let Err(error) = clear_iptables_synlimit_rules_for_binary("ip6tables").await {
errors.push(error);
}
if errors.is_empty() {
Ok(())
} else {
Err(errors.join("; "))
}
}
fn has_synlimit_config(cfg: &ProxyConfig) -> bool {
cfg.server
.listeners
.iter()
.any(|listener| !matches!(listener.synlimit, SynLimitMode::Off))
}
fn synlimit_targets(cfg: &ProxyConfig) -> SynLimitTargets {
let mut iptables_v4 = BTreeSet::new();
let mut iptables_v6 = BTreeSet::new();
let mut nft_v4 = BTreeSet::new();
let mut nft_v6 = BTreeSet::new();
for listener in &cfg.server.listeners {
let backend = listener.synlimit;
if matches!(backend, SynLimitMode::Off) {
continue;
}
let port = listener.port.unwrap_or(cfg.server.port);
let ip = (!listener.ip.is_unspecified()).then_some(listener.ip);
let seconds = listener.synlimit_seconds;
let hitcount = listener.synlimit_hitcount;
let burst = listener.synlimit_burst;
match (backend, listener.ip.is_ipv4()) {
(SynLimitMode::Iptables, true) => {
iptables_v4.insert((ip, port, seconds, hitcount, burst));
}
(SynLimitMode::Iptables, false) => {
iptables_v6.insert((ip, port, seconds, hitcount, burst));
}
(SynLimitMode::Nftables, true) => {
nft_v4.insert((ip, port, seconds, hitcount, burst));
}
(SynLimitMode::Nftables, false) => {
nft_v6.insert((ip, port, seconds, hitcount, burst));
}
(SynLimitMode::Off, _) => {}
}
}
SynLimitTargets {
iptables_v4: iptables_v4.into_iter().collect(),
iptables_v6: iptables_v6.into_iter().collect(),
nft_v4: nft_v4.into_iter().collect(),
nft_v6: nft_v6.into_iter().collect(),
}
}
async fn apply_iptables_synlimit_rules(targets: &SynLimitTargets) -> Result<(), String> {
apply_iptables_synlimit_rules_for_binary("iptables", &targets.iptables_v4).await?;
apply_iptables_synlimit_rules_for_binary("ip6tables", &targets.iptables_v6).await
}
async fn apply_iptables_synlimit_rules_for_binary(
binary: &str,
targets: &[SynLimitTarget],
) -> Result<(), String> {
if targets.is_empty() {
return Ok(());
}
let _ = run_command(binary, &["-t", "filter", "-N", IPTABLES_CHAIN], None).await;
run_command(binary, &["-t", "filter", "-F", IPTABLES_CHAIN], None).await?;
if run_command(
binary,
&["-t", "filter", "-C", "INPUT", "-j", IPTABLES_CHAIN],
None,
)
.await
.is_err()
{
run_command(
binary,
&["-t", "filter", "-A", "INPUT", "-j", IPTABLES_CHAIN],
None,
)
.await?;
}
for (idx, (ip, port, seconds, hitcount, burst)) in targets.iter().enumerate() {
let hashlimit_name = format!("{IPTABLES_HASHLIMIT_NAME}-{idx}");
let accept_args = iptables_hashlimit_accept_rule_args(
ip,
*port,
*seconds,
*hitcount,
*burst,
&hashlimit_name,
);
let drop_args = iptables_synlimit_drop_rule_args(ip, *port);
let drop_refs: Vec<&str> = drop_args.iter().map(String::as_str).collect();
let accept_refs: Vec<&str> = accept_args.iter().map(String::as_str).collect();
run_command(binary, &accept_refs, None).await?;
run_command(binary, &drop_refs, None).await?;
}
run_command(
binary,
&["-t", "filter", "-A", IPTABLES_CHAIN, "-j", "RETURN"],
None,
)
.await?;
Ok(())
}
fn iptables_hashlimit_accept_rule_args(
ip: &Option<IpAddr>,
port: u16,
seconds: u32,
hitcount: u32,
burst: u32,
hashlimit_name: &str,
) -> Vec<String> {
let mut args = vec![
"-t".to_string(),
"filter".to_string(),
"-A".to_string(),
IPTABLES_CHAIN.to_string(),
"-p".to_string(),
"tcp".to_string(),
"--syn".to_string(),
];
if let Some(ip) = ip {
args.push("-d".to_string());
args.push(ip.to_string());
}
let rate = synlimit_rate_arg(seconds, hitcount);
args.extend([
"--dport".to_string(),
port.to_string(),
"-m".to_string(),
"hashlimit".to_string(),
"--hashlimit-name".to_string(),
hashlimit_name.to_string(),
"--hashlimit-mode".to_string(),
"srcip".to_string(),
"--hashlimit-upto".to_string(),
rate,
"--hashlimit-burst".to_string(),
burst.to_string(),
"--hashlimit-htable-expire".to_string(),
"15000".to_string(),
"-j".to_string(),
"ACCEPT".to_string(),
]);
args
}
fn iptables_synlimit_drop_rule_args(ip: &Option<IpAddr>, port: u16) -> Vec<String> {
let mut args = vec![
"-t".to_string(),
"filter".to_string(),
"-A".to_string(),
IPTABLES_CHAIN.to_string(),
"-p".to_string(),
"tcp".to_string(),
"--syn".to_string(),
];
if let Some(ip) = ip {
args.push("-d".to_string());
args.push(ip.to_string());
}
args.extend([
"--dport".to_string(),
port.to_string(),
"-j".to_string(),
"DROP".to_string(),
]);
args
}
fn synlimit_rate_arg(seconds: u32, hitcount: u32) -> String {
let seconds = u64::from(seconds.max(1));
let hitcount = u64::from(hitcount.max(1));
for (unit_seconds, unit_name) in [
(1_u64, "second"),
(60_u64, "minute"),
(3_600_u64, "hour"),
(86_400_u64, "day"),
] {
let amount = hitcount.saturating_mul(unit_seconds);
if amount >= seconds && amount % seconds == 0 {
return format!("{}/{}", amount / seconds, unit_name);
}
}
let amount = hitcount.saturating_mul(86_400).saturating_add(seconds - 1) / seconds;
format!("{}/day", amount.max(1))
}
async fn clear_iptables_synlimit_rules_for_binary(binary: &str) -> Result<(), String> {
let mut errors = Vec::new();
for _ in 0..8 {
match run_command(
binary,
&["-t", "filter", "-D", "INPUT", "-j", IPTABLES_CHAIN],
None,
)
.await
{
Ok(()) => {}
Err(error) if is_missing_command_or_iptables_rule(&error) => break,
Err(error) => {
errors.push(format!("{binary} delete INPUT jump failed: {error}"));
break;
}
}
}
if let Err(error) = run_command(binary, &["-t", "filter", "-F", IPTABLES_CHAIN], None).await
&& !is_missing_command_or_iptables_rule(&error)
{
errors.push(format!("{binary} flush chain failed: {error}"));
}
if let Err(error) = run_command(binary, &["-t", "filter", "-X", IPTABLES_CHAIN], None).await
&& !is_missing_command_or_iptables_rule(&error)
{
errors.push(format!("{binary} delete chain failed: {error}"));
}
if errors.is_empty() {
Ok(())
} else {
Err(errors.join(", "))
}
}
async fn apply_nft_synlimit_rules(targets: &SynLimitTargets) -> Result<(), String> {
let families = detect_nft_table_families().await;
for plan in nft_apply_plan(families, &targets.nft_v4, &targets.nft_v6) {
let script = nft_synlimit_script(plan);
run_command("nft", &["-f", "-"], Some(script)).await?;
}
Ok(())
}
async fn detect_nft_table_families() -> NftTableFamilies {
let Ok(output) = run_command_stdout("nft", &["list", "tables"]).await else {
return NftTableFamilies {
inet: false,
ip: false,
ip6: false,
};
};
let mut families = NftTableFamilies {
inet: false,
ip: false,
ip6: false,
};
for line in output.lines() {
let mut fields = line.split_whitespace();
if fields.next() != Some("table") {
continue;
}
match fields.next() {
Some("inet") => families.inet = true,
Some("ip") => families.ip = true,
Some("ip6") => families.ip6 = true,
_ => {}
}
}
families
}
fn nft_apply_plan<'a>(
families: NftTableFamilies,
v4_targets: &'a [SynLimitTarget],
v6_targets: &'a [SynLimitTarget],
) -> Vec<NftApplyPlan<'a>> {
if !v4_targets.is_empty() && !v6_targets.is_empty() {
return vec![NftApplyPlan {
family: NftFamily::Inet,
v4_targets,
v6_targets,
}];
}
if !v4_targets.is_empty() {
return vec![NftApplyPlan {
family: if families.inet || !families.ip {
NftFamily::Inet
} else {
NftFamily::Ip
},
v4_targets,
v6_targets: &[],
}];
}
if !v6_targets.is_empty() {
return vec![NftApplyPlan {
family: if families.inet || !families.ip6 {
NftFamily::Inet
} else {
NftFamily::Ip6
},
v4_targets: &[],
v6_targets,
}];
}
Vec::new()
}
fn nft_synlimit_script(plan: NftApplyPlan<'_>) -> String {
let mut script = String::new();
script.push_str(&format!("table {} {NFT_TABLE} {{\n", plan.family.as_str()));
script.push_str(&format!(" chain {NFT_CHAIN} {{\n"));
script.push_str(" type filter hook input priority filter; policy accept;\n");
for (idx, (ip, port, seconds, hitcount, burst)) in plan.v4_targets.iter().enumerate() {
let daddr = ip
.map(|ip| format!(" ip daddr {ip}"))
.unwrap_or_else(String::new);
let rate = synlimit_rate_arg(*seconds, *hitcount);
script.push_str(&format!(
" tcp flags & (fin|syn|rst|ack) == syn{daddr} tcp dport {port} meter telemt_synlimit_v4_{idx} {{ ip saddr limit rate over {rate} burst {burst} packets }} drop\n"
));
script.push_str(&format!(
" tcp flags & (fin|syn|rst|ack) == syn{daddr} tcp dport {port} accept\n"
));
}
for (idx, (ip, port, seconds, hitcount, burst)) in plan.v6_targets.iter().enumerate() {
let daddr = ip
.map(|ip| format!(" ip6 daddr {ip}"))
.unwrap_or_else(String::new);
let rate = synlimit_rate_arg(*seconds, *hitcount);
script.push_str(&format!(
" tcp flags & (fin|syn|rst|ack) == syn{daddr} tcp dport {port} meter telemt_synlimit_v6_{idx} {{ ip6 saddr limit rate over {rate} burst {burst} packets }} drop\n"
));
script.push_str(&format!(
" tcp flags & (fin|syn|rst|ack) == syn{daddr} tcp dport {port} accept\n"
));
}
script.push_str(" }\n");
script.push_str("}\n");
script
}
async fn clear_nft_synlimit_rules_all_families() -> Result<(), String> {
let mut errors = Vec::new();
for family in [NftFamily::Inet, NftFamily::Ip, NftFamily::Ip6] {
if let Err(error) = run_command(
"nft",
&["delete", "table", family.as_str(), NFT_TABLE],
None,
)
.await
&& !is_missing_command_or_nft_table(&error)
{
errors.push(format!(
"nft delete table {} {NFT_TABLE} failed: {error}",
family.as_str()
));
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors.join(", "))
}
}
fn is_missing_command_or_iptables_rule(error: &str) -> bool {
error.contains("is not available")
|| error.contains("No chain/target/match by that name")
|| error.contains("does not exist")
}
fn is_missing_command_or_nft_table(error: &str) -> bool {
error.contains("is not available") || error.contains("No such file or directory")
}
async fn run_command(binary: &str, args: &[&str], stdin: Option<String>) -> Result<(), String> {
let Some(command_path) = resolve_command(binary) else {
return Err(format!("{binary} is not available"));
};
let mut command = Command::new(command_path);
command.args(args);
if stdin.is_some() {
command.stdin(std::process::Stdio::piped());
}
command.stdout(std::process::Stdio::null());
command.stderr(std::process::Stdio::piped());
let mut child = command
.spawn()
.map_err(|e| format!("spawn {binary} failed: {e}"))?;
if let Some(blob) = stdin
&& let Some(mut writer) = child.stdin.take()
{
writer
.write_all(blob.as_bytes())
.await
.map_err(|e| format!("stdin write {binary} failed: {e}"))?;
}
let output = child
.wait_with_output()
.await
.map_err(|e| format!("wait {binary} failed: {e}"))?;
if output.status.success() {
return Ok(());
}
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
Err(if stderr.is_empty() {
format!("{binary} exited with status {}", output.status)
} else {
stderr
})
}
async fn run_command_stdout(binary: &str, args: &[&str]) -> Result<String, String> {
let Some(command_path) = resolve_command(binary) else {
return Err(format!("{binary} is not available"));
};
let output = Command::new(command_path)
.args(args)
.output()
.await
.map_err(|e| format!("wait {binary} failed: {e}"))?;
if output.status.success() {
return Ok(String::from_utf8_lossy(&output.stdout).to_string());
}
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
Err(if stderr.is_empty() {
format!("{binary} exited with status {}", output.status)
} else {
stderr
})
}
fn resolve_command(binary: &str) -> Option<PathBuf> {
let mut dirs = std::env::var_os("PATH")
.map(|path| std::env::split_paths(&path).collect::<Vec<_>>())
.unwrap_or_default();
dirs.extend(["/usr/sbin", "/sbin", "/usr/bin", "/bin"].map(PathBuf::from));
dirs.into_iter()
.map(|dir| dir.join(binary))
.find(|candidate| candidate.exists() && candidate.is_file())
}
fn has_cap_net_admin() -> bool {
#[cfg(target_os = "linux")]
{
let Ok(status) = std::fs::read_to_string("/proc/self/status") else {
return false;
};
for line in status.lines() {
if let Some(raw) = line.strip_prefix("CapEff:") {
let caps = raw.trim();
if let Ok(bits) = u64::from_str_radix(caps, 16) {
const CAP_NET_ADMIN_BIT: u64 = 12;
return (bits & (1u64 << CAP_NET_ADMIN_BIT)) != 0;
}
}
}
false
}
#[cfg(not(target_os = "linux"))]
{
false
}
}
+58 -9
View File
@@ -12,7 +12,8 @@ use tokio::time::sleep;
use tracing::{debug, info, warn};
use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsFetchResult, TlsProfileSource,
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsFetchResult, TlsProfileQuality,
TlsProfileSource,
};
const FULL_CERT_SENT_SWEEP_INTERVAL_SECS: u64 = 30;
@@ -47,10 +48,14 @@ pub struct TlsFrontCache {
pub(crate) struct TlsFrontProfileHealth {
pub(crate) domain: String,
pub(crate) source: &'static str,
pub(crate) quality: &'static str,
pub(crate) key_share_group: &'static str,
pub(crate) age_seconds: u64,
pub(crate) is_default: bool,
pub(crate) has_cert_info: bool,
pub(crate) has_cert_payload: bool,
pub(crate) server_hello_record_len: usize,
pub(crate) server_hello_extensions: usize,
pub(crate) app_data_records: usize,
pub(crate) ticket_records: usize,
pub(crate) change_cipher_spec_count: u8,
@@ -66,6 +71,23 @@ fn profile_source_label(source: TlsProfileSource) -> &'static str {
}
}
fn profile_quality_label(quality: TlsProfileQuality) -> &'static str {
match quality {
TlsProfileQuality::Fallback => "fallback",
TlsProfileQuality::RawPartial => "raw_partial",
TlsProfileQuality::RawStrict => "raw_strict",
}
}
fn key_share_group_label(group: Option<u16>) -> &'static str {
match group {
Some(0x001d) => "x25519",
Some(0x11ec) => "x25519mlkem768",
Some(_) => "other",
None => "none",
}
}
#[allow(dead_code)]
impl TlsFrontCache {
pub fn new(domains: &[String], default_len: usize, disk_path: impl AsRef<Path>) -> Self {
@@ -137,7 +159,8 @@ impl TlsFrontCache {
.get(domain)
.cloned()
.unwrap_or_else(|| self.default.clone());
let behavior = &cached.behavior_profile;
let mut behavior = cached.behavior_profile.clone();
behavior.refresh_server_hello_summary(&cached.server_hello_template);
let age_seconds = now
.duration_since(cached.fetched_at)
.map(|duration| duration.as_secs())
@@ -146,10 +169,14 @@ impl TlsFrontCache {
snapshot.push(TlsFrontProfileHealth {
domain: domain.clone(),
source: profile_source_label(behavior.source),
quality: profile_quality_label(behavior.quality),
key_share_group: key_share_group_label(behavior.server_hello_key_share_group),
age_seconds,
is_default: cached.domain == "default",
has_cert_info: cached.cert_info.is_some(),
has_cert_payload: cached.cert_payload.is_some(),
server_hello_record_len: behavior.server_hello_record_len,
server_hello_extensions: behavior.server_hello_extension_types.len(),
app_data_records: cached
.app_data_records_sizes
.len()
@@ -337,6 +364,9 @@ impl TlsFrontCache {
warn!(domain = %cached.domain, "Skipping stale TLS cache entry (>72h)");
continue;
}
cached
.behavior_profile
.refresh_server_hello_summary(&cached.server_hello_template);
let domain = cached.domain.clone();
self.set(&domain, cached).await;
loaded += 1;
@@ -378,20 +408,39 @@ impl TlsFrontCache {
/// Replace cached entry from a fetch result.
pub async fn update_from_fetch(&self, domain: &str, fetched: TlsFetchResult) {
let TlsFetchResult {
server_hello_parsed,
app_data_records_sizes,
total_app_data_len,
mut behavior_profile,
cert_info,
cert_payload,
} = fetched;
behavior_profile.refresh_server_hello_summary(&server_hello_parsed);
let quality = behavior_profile.quality;
let data = CachedTlsData {
server_hello_template: fetched.server_hello_parsed,
cert_info: fetched.cert_info,
cert_payload: fetched.cert_payload,
app_data_records_sizes: fetched.app_data_records_sizes.clone(),
total_app_data_len: fetched.total_app_data_len,
behavior_profile: fetched.behavior_profile,
server_hello_template: server_hello_parsed,
cert_info,
cert_payload,
app_data_records_sizes: app_data_records_sizes.clone(),
total_app_data_len,
behavior_profile,
fetched_at: SystemTime::now(),
domain: domain.to_string(),
};
self.set(domain, data.clone()).await;
self.persist(domain, &data).await;
debug!(domain = %domain, len = fetched.total_app_data_len, "TLS cache updated");
if quality == TlsProfileQuality::RawStrict {
debug!(domain = %domain, len = total_app_data_len, "TLS cache updated");
} else {
warn!(
domain = %domain,
quality = profile_quality_label(quality),
len = total_app_data_len,
"TLS cache updated with non-strict front profile"
);
}
}
pub fn default_entry(&self) -> Arc<CachedTlsData> {
+319 -59
View File
@@ -6,7 +6,8 @@ use crate::protocol::constants::{
TLS_RECORD_HANDSHAKE, TLS_VERSION,
};
use crate::protocol::tls::{
ClientHelloTlsVersion, TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key,
ClientHelloTlsVersion, ServerHelloKeyShare, TLS_DIGEST_LEN, TLS_DIGEST_POS,
TLS_NAMED_GROUP_X25519, TLS_NAMED_GROUP_X25519MLKEM768,
};
use crate::tls_front::types::{
CachedTlsData, ParsedCertificateInfo, TlsExtension, TlsProfileSource,
@@ -20,6 +21,61 @@ const EXT_SUPPORTED_VERSIONS: u16 = 0x002b;
const EXT_KEY_SHARE: u16 = 0x0033;
const EXT_ALPN: u16 = 0x0010;
#[derive(Clone, Copy)]
enum FallbackShapeFamily {
NginxLike,
BoringSslLike,
RustlsLike,
}
fn parse_profiled_key_share_group(data: &[u8]) -> Option<u16> {
if data.len() < 4 {
return None;
}
let group = u16::from_be_bytes([data[0], data[1]]);
let key_exchange_len = u16::from_be_bytes([data[2], data[3]]) as usize;
if data.len() != 4 + key_exchange_len {
return None;
}
match group {
TLS_NAMED_GROUP_X25519 | TLS_NAMED_GROUP_X25519MLKEM768 => Some(group),
_ => None,
}
}
fn effective_profiled_server_hello_record_len(cached: &CachedTlsData) -> usize {
if cached.behavior_profile.server_hello_record_len == 0 {
cached.server_hello_template.record_body_len()
} else {
cached.behavior_profile.server_hello_record_len
}
}
fn should_replay_profiled_server_hello_shape(cached: &CachedTlsData) -> bool {
matches!(
cached.behavior_profile.source,
TlsProfileSource::Raw | TlsProfileSource::Merged
) && cached
.server_hello_template
.is_replay_safe_tls13_shape(effective_profiled_server_hello_record_len(cached))
}
/// Return the origin-profiled ServerHello key_share group when it is replay-safe.
pub(crate) fn profiled_server_hello_key_share_group(cached: &CachedTlsData) -> Option<u16> {
if !should_replay_profiled_server_hello_shape(cached) {
return None;
}
cached
.server_hello_template
.extensions
.iter()
.find(|ext| ext.ext_type == EXT_KEY_SHARE)
.and_then(|ext| parse_profiled_key_share_group(&ext.data))
}
fn jitter_and_clamp_sizes(sizes: &[usize], rng: &SecureRandom) -> Vec<usize> {
sizes
.iter()
@@ -70,31 +126,69 @@ fn ensure_payload_capacity(mut sizes: Vec<usize>, payload_len: usize) -> Vec<usi
sizes
}
fn fallback_shape_family(cached: &CachedTlsData) -> FallbackShapeFamily {
match cached.behavior_profile.source {
TlsProfileSource::Rustls => FallbackShapeFamily::RustlsLike,
TlsProfileSource::Default => {
let mut hasher = Hasher::new();
hasher.update(cached.domain.as_bytes());
hasher.update(&cached.total_app_data_len.to_le_bytes());
if hasher.finalize() & 1 == 0 {
FallbackShapeFamily::NginxLike
} else {
FallbackShapeFamily::BoringSslLike
}
}
TlsProfileSource::Raw | TlsProfileSource::Merged => FallbackShapeFamily::NginxLike,
}
}
fn fallback_total_app_data_len(cached: &CachedTlsData) -> usize {
cached
.total_app_data_len
.max(cached.app_data_records_sizes.iter().sum())
.max(1024)
}
fn push_fallback_size(sizes: &mut Vec<usize>, size: usize) {
sizes.push(size.clamp(MIN_APP_DATA, MAX_APP_DATA));
}
fn fallback_family_app_data_sizes(cached: &CachedTlsData) -> Vec<usize> {
let mut sizes = Vec::with_capacity(1);
let size = if matches!(cached.behavior_profile.source, TlsProfileSource::Rustls) {
cached
.app_data_records_sizes
.first()
.copied()
.unwrap_or_else(|| fallback_total_app_data_len(cached))
} else {
fallback_total_app_data_len(cached)
};
push_fallback_size(&mut sizes, size);
sizes
}
fn emulated_app_data_sizes(cached: &CachedTlsData) -> Vec<usize> {
match cached.behavior_profile.source {
TlsProfileSource::Raw | TlsProfileSource::Merged => {
return cached
.app_data_records_sizes
.first()
.copied()
.or_else(|| {
cached
.behavior_profile
.app_data_record_sizes
.first()
.copied()
})
.map(|size| vec![size])
.unwrap_or_else(|| vec![cached.total_app_data_len.max(1024)]);
if let Some(size) = cached.behavior_profile.app_data_record_sizes.first() {
return vec![(*size).clamp(MIN_APP_DATA, MAX_APP_DATA)];
}
if let Some(size) = cached.app_data_records_sizes.first() {
return vec![(*size).clamp(MIN_APP_DATA, MAX_APP_DATA)];
}
return vec![
cached
.total_app_data_len
.max(1024)
.clamp(MIN_APP_DATA, MAX_APP_DATA),
];
}
TlsProfileSource::Default | TlsProfileSource::Rustls => {
return fallback_family_app_data_sizes(cached);
}
TlsProfileSource::Default | TlsProfileSource::Rustls => {}
}
let mut sizes = cached.app_data_records_sizes.clone();
if sizes.is_empty() {
sizes.push(cached.total_app_data_len.max(1024));
}
sizes
}
fn emulated_change_cipher_spec_count(_cached: &CachedTlsData) -> usize {
@@ -122,7 +216,13 @@ fn emulated_ticket_record_sizes(
sizes.extend(profiled_sizes.iter().copied().take(target_count));
while sizes.len() < target_count {
sizes.push(rng.range(48) + 48);
let family = fallback_shape_family(cached);
let base = match family {
FallbackShapeFamily::NginxLike => 96,
FallbackShapeFamily::BoringSslLike => 80,
FallbackShapeFamily::RustlsLike => 112,
};
sizes.push(base + rng.range(64));
}
sizes
@@ -196,19 +296,36 @@ fn push_supported_versions_extension(extensions: &mut Vec<u8>) {
extensions.extend_from_slice(&0x0304u16.to_be_bytes());
}
fn push_key_share_extension(extensions: &mut Vec<u8>, rng: &SecureRandom) {
let key = gen_fake_x25519_key(rng);
fn push_key_share_entry(extensions: &mut Vec<u8>, group: u16, key_exchange: &[u8]) {
let Ok(key_exchange_len) = u16::try_from(key_exchange.len()) else {
return;
};
let Some(entry_len) = key_exchange.len().checked_add(4) else {
return;
};
let Ok(entry_len) = u16::try_from(entry_len) else {
return;
};
extensions.extend_from_slice(&EXT_KEY_SHARE.to_be_bytes());
extensions.extend_from_slice(&(2 + 2 + 32u16).to_be_bytes());
extensions.extend_from_slice(&0x001du16.to_be_bytes());
extensions.extend_from_slice(&(32u16).to_be_bytes());
extensions.extend_from_slice(&key);
extensions.extend_from_slice(&entry_len.to_be_bytes());
extensions.extend_from_slice(&group.to_be_bytes());
extensions.extend_from_slice(&key_exchange_len.to_be_bytes());
extensions.extend_from_slice(key_exchange);
}
fn push_key_share_extension(extensions: &mut Vec<u8>, server_key_share: &ServerHelloKeyShare) {
push_key_share_entry(
extensions,
server_key_share.group(),
server_key_share.key_exchange(),
);
}
fn replay_profiled_server_hello_extension(
ext: &TlsExtension,
extensions: &mut Vec<u8>,
rng: &SecureRandom,
server_key_share: &ServerHelloKeyShare,
saw_supported_versions: &mut bool,
saw_key_share: &mut bool,
) {
@@ -218,7 +335,7 @@ fn replay_profiled_server_hello_extension(
*saw_supported_versions = true;
}
EXT_KEY_SHARE if !*saw_key_share => {
push_key_share_extension(extensions, rng);
push_key_share_extension(extensions, server_key_share);
*saw_key_share = true;
}
EXT_ALPN => {}
@@ -226,7 +343,10 @@ fn replay_profiled_server_hello_extension(
}
}
fn build_profiled_server_hello_extensions(cached: &CachedTlsData, rng: &SecureRandom) -> Vec<u8> {
fn build_profiled_server_hello_extensions(
cached: &CachedTlsData,
server_key_share: &ServerHelloKeyShare,
) -> Vec<u8> {
let capacity = cached
.server_hello_template
.extensions
@@ -238,22 +358,24 @@ fn build_profiled_server_hello_extensions(cached: &CachedTlsData, rng: &SecureRa
let mut saw_supported_versions = false;
let mut saw_key_share = false;
for ext in &cached.server_hello_template.extensions {
replay_profiled_server_hello_extension(
ext,
&mut extensions,
rng,
&mut saw_supported_versions,
&mut saw_key_share,
);
if should_replay_profiled_server_hello_shape(cached) {
for ext in &cached.server_hello_template.extensions {
replay_profiled_server_hello_extension(
ext,
&mut extensions,
server_key_share,
&mut saw_supported_versions,
&mut saw_key_share,
);
}
}
if !saw_key_share {
push_key_share_extension(&mut extensions, rng);
}
if !saw_supported_versions {
push_supported_versions_extension(&mut extensions);
}
if !saw_key_share {
push_key_share_extension(&mut extensions, server_key_share);
}
extensions
}
@@ -268,12 +390,13 @@ pub fn build_emulated_server_hello(
serverhello_compact: bool,
client_tls_version: ClientHelloTlsVersion,
selected_cipher_suite: [u8; 2],
server_key_share: &ServerHelloKeyShare,
rng: &SecureRandom,
alpn: Option<Vec<u8>>,
new_session_tickets: u8,
) -> Vec<u8> {
// --- ServerHello ---
let extensions = build_profiled_server_hello_extensions(cached, rng);
// ServerHello carries the authenticated digest bytes that the client verifies.
let extensions = build_profiled_server_hello_extensions(cached, server_key_share);
let extensions_len = extensions.len() as u16;
let body_len = 2 + 32 + 1 + session_id.len() + 2 + 1 + 2 + extensions.len();
@@ -304,7 +427,7 @@ pub fn build_emulated_server_hello(
server_hello.extend_from_slice(&(message.len() as u16).to_be_bytes());
server_hello.extend_from_slice(&message);
// --- ChangeCipherSpec ---
// ChangeCipherSpec is part of the client-visible TLS shim prefix.
let change_cipher_spec_count = emulated_change_cipher_spec_count(cached);
let mut change_cipher_spec = Vec::with_capacity(change_cipher_spec_count * 6);
for _ in 0..change_cipher_spec_count {
@@ -318,7 +441,8 @@ pub fn build_emulated_server_hello(
]);
}
// --- ApplicationData (fake encrypted records) ---
// Telegram clients authenticate the hello prefix and then expose any later
// ApplicationData bytes to the MTProto packet parser.
let mut sizes = {
let base_sizes = emulated_app_data_sizes(cached);
match cached.behavior_profile.source {
@@ -368,6 +492,7 @@ pub fn build_emulated_server_hello(
// ALPN selection is encrypted inside EncryptedExtensions in real TLS 1.3.
// Keeping the FakeTLS record body opaque avoids a stable plaintext marker.
let _ = alpn;
let mut payload_offset = 0usize;
for size in sizes {
let mut rec = Vec::with_capacity(5 + size);
rec.push(TLS_RECORD_APPLICATION);
@@ -377,10 +502,11 @@ pub fn build_emulated_server_hello(
if let Some(payload) = selected_payload {
if size > 17 {
let body_len = size - 17;
let remaining = payload.len();
let remaining = payload.len().saturating_sub(payload_offset);
let copy_len = remaining.min(body_len);
if copy_len > 0 {
rec.extend_from_slice(&payload[..copy_len]);
rec.extend_from_slice(&payload[payload_offset..payload_offset + copy_len]);
payload_offset += copy_len;
}
if body_len > copy_len {
rec.extend_from_slice(&rng.bytes(body_len - copy_len));
@@ -403,8 +529,7 @@ pub fn build_emulated_server_hello(
app_data.extend_from_slice(&rec);
}
// --- Combine ---
// Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint).
// Optional NewSessionTicket mimic records are an explicit fingerprint opt-in.
let mut tickets = Vec::new();
for ticket_len in emulated_ticket_record_sizes(cached, new_session_tickets, rng) {
let mut rec = Vec::with_capacity(5 + ticket_len);
@@ -423,7 +548,7 @@ pub fn build_emulated_server_hello(
response.extend_from_slice(&app_data);
response.extend_from_slice(&tickets);
// --- HMAC ---
// The digest authenticates the server response bytes emitted by this builder.
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len());
hmac_input.extend_from_slice(client_digest);
hmac_input.extend_from_slice(&response);
@@ -452,13 +577,16 @@ mod tests {
use super::{
build_compact_cert_info_payload, build_emulated_server_hello,
hash_compact_cert_info_payload,
hash_compact_cert_info_payload, profiled_server_hello_key_share_group,
};
use crate::crypto::SecureRandom;
use crate::protocol::constants::{
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
};
use crate::protocol::tls::ClientHelloTlsVersion;
use crate::protocol::tls::{
ClientHelloTlsVersion, ServerHelloKeyShare, TLS_NAMED_GROUP_X25519,
TLS_NAMED_GROUP_X25519MLKEM768,
};
fn first_app_data_payload(response: &[u8]) -> &[u8] {
let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
@@ -523,6 +651,50 @@ mod tests {
}
}
fn test_server_key_share() -> ServerHelloKeyShare {
ServerHelloKeyShare::new(TLS_NAMED_GROUP_X25519MLKEM768, vec![0x42; 1120])
}
fn server_key_share_extension_data(group: u16, len: usize) -> Vec<u8> {
let mut data = Vec::new();
data.extend_from_slice(&group.to_be_bytes());
data.extend_from_slice(&(len as u16).to_be_bytes());
data.resize(4 + len, 0x42);
data
}
#[test]
fn profiled_server_hello_key_share_group_reads_raw_x25519_profile() {
let mut cached = make_cached(None);
cached.behavior_profile.source = TlsProfileSource::Raw;
cached.server_hello_template.extensions = vec![
TlsExtension {
ext_type: 0x002b,
data: vec![0x03, 0x04],
},
TlsExtension {
ext_type: 0x0033,
data: server_key_share_extension_data(TLS_NAMED_GROUP_X25519, 32),
},
];
assert_eq!(
profiled_server_hello_key_share_group(&cached),
Some(TLS_NAMED_GROUP_X25519)
);
}
#[test]
fn profiled_server_hello_key_share_group_ignores_default_profile() {
let mut cached = make_cached(None);
cached.server_hello_template.extensions = vec![TlsExtension {
ext_type: 0x0033,
data: server_key_share_extension_data(TLS_NAMED_GROUP_X25519, 32),
}];
assert_eq!(profiled_server_hello_key_share_group(&cached), None);
}
#[test]
fn test_build_emulated_server_hello_uses_cached_cert_payload() {
let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd];
@@ -540,6 +712,7 @@ mod tests {
true,
ClientHelloTlsVersion::Tls12,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
0,
@@ -569,6 +742,7 @@ mod tests {
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x03],
&test_server_key_share(),
&rng,
None,
0,
@@ -604,6 +778,7 @@ mod tests {
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
Some(b"h2".to_vec()),
0,
@@ -615,6 +790,82 @@ mod tests {
);
}
#[test]
fn test_build_emulated_server_hello_replays_safe_raw_extension_order() {
let mut cached = make_cached(None);
cached.behavior_profile.source = TlsProfileSource::Raw;
cached.server_hello_template.extensions = vec![
TlsExtension {
ext_type: 0x0033,
data: server_key_share_extension_data(TLS_NAMED_GROUP_X25519, 32),
},
TlsExtension {
ext_type: 0x002b,
data: vec![0x03, 0x04],
},
];
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x21; 32],
&[0x22; 16],
&cached,
false,
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
0,
);
assert_eq!(
server_hello_extension_types(&response),
vec![0x0033, 0x002b]
);
}
#[test]
fn test_build_emulated_server_hello_uses_canonical_order_for_unsafe_raw_shape() {
let mut cached = make_cached(None);
cached.behavior_profile.source = TlsProfileSource::Raw;
cached.server_hello_template.extensions = vec![
TlsExtension {
ext_type: 0x0010,
data: vec![0x00, 0x03, 0x02, b'h', b'2'],
},
TlsExtension {
ext_type: 0x0033,
data: server_key_share_extension_data(TLS_NAMED_GROUP_X25519, 32),
},
TlsExtension {
ext_type: 0x002b,
data: vec![0x03, 0x04],
},
];
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x21; 32],
&[0x22; 16],
&cached,
false,
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
0,
);
assert_eq!(
server_hello_extension_types(&response),
vec![0x002b, 0x0033]
);
}
#[test]
fn test_build_emulated_server_hello_random_fallback_when_no_cert_payload() {
let cached = make_cached(None);
@@ -628,6 +879,7 @@ mod tests {
true,
ClientHelloTlsVersion::Tls12,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
0,
@@ -663,6 +915,7 @@ mod tests {
true,
ClientHelloTlsVersion::Tls12,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
0,
@@ -704,6 +957,7 @@ mod tests {
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
0,
@@ -737,6 +991,7 @@ mod tests {
false,
ClientHelloTlsVersion::Tls12,
[0x13, 0x01],
&test_server_key_share(),
&rng,
Some(b"h2".to_vec()),
0,
@@ -769,6 +1024,7 @@ mod tests {
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
0,
@@ -776,11 +1032,15 @@ mod tests {
let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_start = 5 + hello_len;
let app_start = ccs_start + 6;
let app_len =
u16::from_be_bytes([response[app_start + 3], response[app_start + 4]]) as usize;
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
assert_eq!(app_len, 64);
assert_eq!(app_start + 5 + app_len, response.len());
let mut pos = ccs_start + 6;
let mut app_lens = Vec::new();
while pos + 5 <= response.len() {
let record_len = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
assert_eq!(response[pos], TLS_RECORD_APPLICATION);
app_lens.push(record_len);
pos += 5 + record_len;
}
assert_eq!(app_lens, vec![64]);
assert_eq!(pos, response.len());
}
}
+124 -29
View File
@@ -9,6 +9,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use anyhow::{Result, anyhow};
use ml_kem::{DecapsulationKey as MlKemDecapsulationKey, KeyExport, MlKem768, Seed as MlKemSeed};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
#[cfg(unix)]
@@ -33,6 +34,7 @@ use crate::network::dns_overrides::resolve_socket_addr;
use crate::protocol::constants::{
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
};
use crate::protocol::tls::{TLS_NAMED_GROUP_X25519, TLS_NAMED_GROUP_X25519MLKEM768};
use crate::tls_front::types::{
ParsedCertificateInfo, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsExtension,
TlsFetchResult, TlsProfileSource,
@@ -40,6 +42,10 @@ use crate::tls_front::types::{
use crate::transport::UpstreamStream;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
#[cfg(test)]
const X25519_KEY_SHARE_LEN: usize = 32;
const MLKEM768_CLIENT_ENCAPSULATION_KEY_LEN: usize = 1184;
/// No-op verifier: accept any certificate (we only need lengths and metadata).
#[derive(Debug)]
struct NoVerify;
@@ -393,8 +399,13 @@ fn profile_cipher_suites(profile: TlsFetchProfile) -> &'static [u16] {
}
fn profile_groups(profile: TlsFetchProfile) -> &'static [u16] {
const MODERN: &[u16] = &[0x001d, 0x0017, 0x0018]; // x25519, secp256r1, secp384r1
const COMPAT: &[u16] = &[0x001d, 0x0017];
const MODERN: &[u16] = &[
TLS_NAMED_GROUP_X25519MLKEM768,
TLS_NAMED_GROUP_X25519,
0x0017,
0x0018,
];
const COMPAT: &[u16] = &[TLS_NAMED_GROUP_X25519, 0x0017];
const LEGACY: &[u16] = &[0x0017];
match profile {
@@ -454,7 +465,9 @@ fn profile_supported_versions(profile: TlsFetchProfile) -> &'static [u16] {
fn profile_padding_target(profile: TlsFetchProfile) -> usize {
match profile {
TlsFetchProfile::ModernChromeLike => 220,
// X25519MLKEM768 makes the Chrome-like ClientHello much larger than
// legacy pre-hybrid profiles; keep enough headroom for padding.
TlsFetchProfile::ModernChromeLike => 1450,
TlsFetchProfile::ModernFirefoxLike => 200,
TlsFetchProfile::CompatTls12 => 180,
TlsFetchProfile::LegacyMinimal => 64,
@@ -475,6 +488,48 @@ fn grease_value(rng: &SecureRandom, deterministic: bool, seed: &str) -> u16 {
}
}
fn gen_mlkem768_client_encapsulation_key(
rng: &SecureRandom,
deterministic: bool,
seed: &str,
) -> Option<Vec<u8>> {
let seed_bytes = if deterministic {
deterministic_bytes(seed, 64)
} else {
rng.bytes(64)
};
let seed = MlKemSeed::try_from(seed_bytes.as_slice()).ok()?;
let decapsulation_key = MlKemDecapsulationKey::<MlKem768>::from_seed(seed);
let encapsulation_key = decapsulation_key.encapsulation_key().to_bytes();
let bytes = encapsulation_key.as_slice();
if bytes.len() == MLKEM768_CLIENT_ENCAPSULATION_KEY_LEN {
Some(bytes.to_vec())
} else {
None
}
}
fn gen_x25519mlkem768_client_key_share(
rng: &SecureRandom,
deterministic: bool,
seed: &str,
) -> Option<Vec<u8>> {
let mlkem_key =
gen_mlkem768_client_encapsulation_key(rng, deterministic, &format!("{seed}:mlkem768"))?;
let x25519_key = gen_key_share(rng, deterministic, &format!("{seed}:x25519"));
let mut key_share =
Vec::with_capacity(MLKEM768_CLIENT_ENCAPSULATION_KEY_LEN + x25519_key.len());
key_share.extend_from_slice(&mlkem_key);
key_share.extend_from_slice(&x25519_key);
Some(key_share)
}
fn push_client_key_share_entry(keyshare: &mut Vec<u8>, group: u16, key: &[u8]) {
keyshare.extend_from_slice(&group.to_be_bytes());
keyshare.extend_from_slice(&(key.len() as u16).to_be_bytes());
keyshare.extend_from_slice(key);
}
fn build_client_hello(
sni: &str,
rng: &SecureRandom,
@@ -597,16 +652,20 @@ fn build_client_hello(
push_extension(0x002d, &[0x01, 0x01]);
}
// key_share (x25519)
let key = gen_key_share(
rng,
deterministic,
&format!("tls-fetch-keyshare:{sni}:{}", profile.as_str()),
);
let mut keyshare = Vec::with_capacity(4 + key.len());
keyshare.extend_from_slice(&0x001du16.to_be_bytes());
keyshare.extend_from_slice(&(key.len() as u16).to_be_bytes());
keyshare.extend_from_slice(&key);
// key_share
let key_share_seed = format!("tls-fetch-keyshare:{sni}:{}", profile.as_str());
let mut keyshare = Vec::new();
if matches!(
profile,
TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike
) {
if let Some(key) = gen_x25519mlkem768_client_key_share(rng, deterministic, &key_share_seed)
{
push_client_key_share_entry(&mut keyshare, TLS_NAMED_GROUP_X25519MLKEM768, &key);
}
}
let key = gen_key_share(rng, deterministic, &key_share_seed);
push_client_key_share_entry(&mut keyshare, TLS_NAMED_GROUP_X25519, &key);
let mut keyshare_ext = Vec::with_capacity(2 + keyshare.len());
keyshare_ext.extend_from_slice(&(keyshare.len() as u16).to_be_bytes());
keyshare_ext.extend_from_slice(&keyshare);
@@ -776,6 +835,7 @@ fn derive_behavior_profile(records: &[(u8, Vec<u8>)]) -> TlsBehaviorProfile {
app_data_record_sizes,
ticket_record_sizes,
source: TlsProfileSource::Raw,
..TlsBehaviorProfile::default()
}
}
@@ -1025,24 +1085,26 @@ where
}
let mut server_hello = None;
let mut server_hello_record_len = 0usize;
for (t, body) in &records {
if *t == TLS_RECORD_HANDSHAKE && server_hello.is_none() {
server_hello = parse_server_hello(body);
server_hello_record_len = body.len();
}
}
let parsed = server_hello.ok_or_else(|| anyhow!("ServerHello not received"))?;
let behavior_profile = derive_behavior_profile(&records);
let mut behavior_profile = derive_behavior_profile(&records);
behavior_profile.server_hello_record_len = server_hello_record_len;
behavior_profile.refresh_server_hello_summary(&parsed);
let mut app_sizes = behavior_profile.app_data_record_sizes.clone();
app_sizes.extend_from_slice(&behavior_profile.ticket_record_sizes);
let total_app_data_len = app_sizes.iter().sum::<usize>().max(1024);
let app_data_records_sizes = behavior_profile
.app_data_record_sizes
.first()
.copied()
.or_else(|| behavior_profile.ticket_record_sizes.first().copied())
.map(|size| vec![size])
.unwrap_or_else(|| vec![total_app_data_len]);
let app_data_records_sizes = if app_sizes.is_empty() {
vec![total_app_data_len]
} else {
app_sizes
};
Ok(TlsFetchResult {
server_hello_parsed: parsed,
@@ -1212,6 +1274,7 @@ where
app_data_record_sizes: app_data_records_sizes,
ticket_record_sizes: Vec::new(),
source: TlsProfileSource::Rustls,
..TlsBehaviorProfile::default()
},
cert_info,
cert_payload,
@@ -1411,6 +1474,8 @@ pub async fn fetch_real_tls_with_strategy(
raw.cert_info = rustls.cert_info;
raw.cert_payload = rustls.cert_payload;
raw.behavior_profile.source = TlsProfileSource::Merged;
raw.behavior_profile
.refresh_server_hello_summary(&raw.server_hello_parsed);
debug!(sni = %sni, "Fetched TLS metadata via adaptive raw probe + rustls cert chain");
Ok(raw)
} else {
@@ -1462,9 +1527,10 @@ mod tests {
use std::time::{Duration, Instant};
use super::{
ProfileCacheValue, TlsFetchStrategy, build_client_hello, build_tls_fetch_proxy_header,
derive_behavior_profile, encode_tls13_certificate_message, fetch_via_rustls_stream,
order_profiles, profile_alpn, profile_cache, profile_cache_key,
MLKEM768_CLIENT_ENCAPSULATION_KEY_LEN, ProfileCacheValue, TLS_NAMED_GROUP_X25519,
TLS_NAMED_GROUP_X25519MLKEM768, TlsFetchStrategy, X25519_KEY_SHARE_LEN, build_client_hello,
build_tls_fetch_proxy_header, derive_behavior_profile, encode_tls13_certificate_message,
fetch_via_rustls_stream, order_profiles, profile_alpn, profile_cache, profile_cache_key,
};
use crate::config::TlsFetchProfile;
use crate::crypto::SecureRandom;
@@ -1790,11 +1856,40 @@ mod tests {
key_share_data.len() - 2,
"key_share list length mismatch"
);
let group = u16::from_be_bytes([key_share_data[2], key_share_data[3]]);
let key_len = u16::from_be_bytes([key_share_data[4], key_share_data[5]]) as usize;
let key = &key_share_data[6..6 + key_len];
assert_eq!(group, 0x001d, "key_share group must be x25519");
assert_eq!(key_len, 32, "x25519 key length must be 32");
let mut pos = 2usize;
let hybrid_group = u16::from_be_bytes([key_share_data[pos], key_share_data[pos + 1]]);
let hybrid_len =
u16::from_be_bytes([key_share_data[pos + 2], key_share_data[pos + 3]]) as usize;
pos += 4;
let hybrid_key = &key_share_data[pos..pos + hybrid_len];
pos += hybrid_len;
assert_eq!(
hybrid_group, TLS_NAMED_GROUP_X25519MLKEM768,
"first key_share group must be X25519MLKEM768"
);
assert_eq!(
hybrid_len,
MLKEM768_CLIENT_ENCAPSULATION_KEY_LEN + X25519_KEY_SHARE_LEN,
"hybrid key length must match X25519MLKEM768"
);
assert!(
hybrid_key.iter().any(|b| *b != 0),
"hybrid key must not be all zero"
);
let group = u16::from_be_bytes([key_share_data[pos], key_share_data[pos + 1]]);
let key_len =
u16::from_be_bytes([key_share_data[pos + 2], key_share_data[pos + 3]]) as usize;
pos += 4;
let key = &key_share_data[pos..pos + key_len];
assert_eq!(
group, TLS_NAMED_GROUP_X25519,
"second key_share group must be x25519"
);
assert_eq!(
key_len, X25519_KEY_SHARE_LEN,
"x25519 key length must be 32"
);
assert!(
key.iter().any(|b| *b != 0),
"x25519 key must not be all zero"
@@ -4,7 +4,9 @@ use crate::crypto::SecureRandom;
use crate::protocol::constants::{
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
};
use crate::protocol::tls::ClientHelloTlsVersion;
use crate::protocol::tls::{
ClientHelloTlsVersion, ServerHelloKeyShare, TLS_NAMED_GROUP_X25519MLKEM768,
};
use crate::tls_front::emulator::build_emulated_server_hello;
use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource,
@@ -29,6 +31,7 @@ fn make_cached() -> CachedTlsData {
app_data_record_sizes: vec![1200, 900],
ticket_record_sizes: vec![220, 180],
source: TlsProfileSource::Merged,
..TlsBehaviorProfile::default()
},
fetched_at: SystemTime::now(),
domain: "example.com".to_string(),
@@ -52,6 +55,10 @@ fn record_lengths_by_type(response: &[u8], wanted_type: u8) -> Vec<usize> {
out
}
fn test_server_key_share() -> ServerHelloKeyShare {
ServerHelloKeyShare::new(TLS_NAMED_GROUP_X25519MLKEM768, vec![0x42; 1120])
}
#[test]
fn emulated_server_hello_keeps_single_change_cipher_spec_for_client_compatibility() {
let cached = make_cached();
@@ -66,6 +73,7 @@ fn emulated_server_hello_keeps_single_change_cipher_spec_for_client_compatibilit
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
0,
@@ -91,6 +99,7 @@ fn emulated_server_hello_does_not_emit_profile_ticket_tail_when_disabled() {
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
0,
@@ -100,6 +109,36 @@ fn emulated_server_hello_does_not_emit_profile_ticket_tail_when_disabled() {
assert_eq!(app_records, vec![1200]);
}
#[test]
fn emulated_server_hello_keeps_default_profile_primary_app_data_single() {
let mut cached = make_cached();
cached.behavior_profile.source = TlsProfileSource::Default;
cached.behavior_profile.app_data_record_sizes.clear();
cached.behavior_profile.ticket_record_sizes.clear();
cached.app_data_records_sizes = vec![2048, 1024];
cached.total_app_data_len = 5000;
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x85; 32],
&[0x86; 16],
&cached,
false,
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
0,
);
let app_records = record_lengths_by_type(&response, TLS_RECORD_APPLICATION);
assert_eq!(app_records.len(), 1);
assert!(app_records[0] >= 64);
}
#[test]
fn emulated_server_hello_uses_profile_ticket_lengths_when_enabled() {
let cached = make_cached();
@@ -114,6 +153,7 @@ fn emulated_server_hello_uses_profile_ticket_lengths_when_enabled() {
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
None,
2,
+11 -1
View File
@@ -4,7 +4,9 @@ use crate::crypto::SecureRandom;
use crate::protocol::constants::{
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
};
use crate::protocol::tls::ClientHelloTlsVersion;
use crate::protocol::tls::{
ClientHelloTlsVersion, ServerHelloKeyShare, TLS_NAMED_GROUP_X25519MLKEM768,
};
use crate::tls_front::emulator::build_emulated_server_hello;
use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource,
@@ -29,6 +31,7 @@ fn make_cached(cert_payload: Option<crate::tls_front::types::TlsCertPayload>) ->
app_data_record_sizes: vec![64],
ticket_record_sizes: Vec::new(),
source: TlsProfileSource::Default,
..TlsBehaviorProfile::default()
},
fetched_at: SystemTime::now(),
domain: "example.com".to_string(),
@@ -44,6 +47,10 @@ fn first_app_data_payload(response: &[u8]) -> &[u8] {
&response[app_start + 5..app_start + 5 + app_len]
}
fn test_server_key_share() -> ServerHelloKeyShare {
ServerHelloKeyShare::new(TLS_NAMED_GROUP_X25519MLKEM768, vec![0x42; 1120])
}
#[test]
fn emulated_server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
let cached = make_cached(None);
@@ -59,6 +66,7 @@ fn emulated_server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
Some(oversized_alpn),
0,
@@ -98,6 +106,7 @@ fn emulated_server_hello_keeps_alpn_marker_out_of_appdata() {
true,
ClientHelloTlsVersion::Tls13,
[0x13, 0x01],
&test_server_key_share(),
&rng,
Some(b"h2".to_vec()),
0,
@@ -129,6 +138,7 @@ fn emulated_server_hello_prefers_cert_payload_over_alpn_marker() {
true,
ClientHelloTlsVersion::Tls12,
[0x13, 0x01],
&test_server_key_share(),
&rng,
Some(b"h2".to_vec()),
0,
+249
View File
@@ -1,6 +1,14 @@
use serde::{Deserialize, Serialize};
use std::time::SystemTime;
const EXT_ALPN: u16 = 0x0010;
const EXT_SUPPORTED_VERSIONS: u16 = 0x002b;
const EXT_KEY_SHARE: u16 = 0x0033;
const TLS_LEGACY_SERVER_HELLO_VERSION: [u8; 2] = [0x03, 0x03];
const TLS_VERSION_13: [u8; 2] = [0x03, 0x04];
const TLS_NAMED_GROUP_X25519: u16 = 0x001d;
const TLS_NAMED_GROUP_X25519MLKEM768: u16 = 0x11ec;
/// Parsed representation of an unencrypted TLS ServerHello.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedServerHello {
@@ -19,6 +27,96 @@ pub struct TlsExtension {
pub data: Vec<u8>,
}
impl ParsedServerHello {
/// Return the TLS record body length that would contain this ServerHello.
pub(crate) fn record_body_len(&self) -> usize {
let extensions_len = self
.extensions
.iter()
.map(|extension| 4 + extension.data.len())
.sum::<usize>();
4 + 2 + 32 + 1 + self.session_id.len() + 2 + 1 + 2 + extensions_len
}
/// Return visible ServerHello extension types in wire order.
pub(crate) fn extension_types(&self) -> Vec<u16> {
self.extensions
.iter()
.map(|extension| extension.ext_type)
.collect()
}
/// Return a replay-safe ServerHello key_share group when the extension is well-formed.
pub(crate) fn key_share_group(&self) -> Option<u16> {
self.extensions
.iter()
.find(|extension| extension.ext_type == EXT_KEY_SHARE)
.and_then(|extension| parse_key_share_group(&extension.data))
}
/// Return true when the cached ServerHello can safely drive visible TLS 1.3 replay.
pub(crate) fn is_replay_safe_tls13_shape(&self, record_body_len: usize) -> bool {
if self.version != TLS_LEGACY_SERVER_HELLO_VERSION
|| self.compression != 0
|| self.session_id.len() > 32
|| !is_supported_tls13_cipher_suite(self.cipher_suite)
{
return false;
}
if record_body_len != 0 && record_body_len != self.record_body_len() {
return false;
}
let mut saw_supported_versions = false;
let mut saw_key_share = false;
for extension in &self.extensions {
match extension.ext_type {
EXT_SUPPORTED_VERSIONS => {
if saw_supported_versions || extension.data.as_slice() != TLS_VERSION_13 {
return false;
}
saw_supported_versions = true;
}
EXT_KEY_SHARE => {
if saw_key_share || parse_key_share_group(&extension.data).is_none() {
return false;
}
saw_key_share = true;
}
EXT_ALPN => {
return false;
}
_ => {}
}
}
saw_supported_versions && saw_key_share
}
}
fn is_supported_tls13_cipher_suite(cipher_suite: [u8; 2]) -> bool {
matches!(u16::from_be_bytes(cipher_suite), 0x1301 | 0x1302 | 0x1303)
}
fn parse_key_share_group(data: &[u8]) -> Option<u16> {
if data.len() < 4 {
return None;
}
let group = u16::from_be_bytes([data[0], data[1]]);
let key_exchange_len = u16::from_be_bytes([data[2], data[3]]) as usize;
if data.len() != 4 + key_exchange_len {
return None;
}
match group {
TLS_NAMED_GROUP_X25519 | TLS_NAMED_GROUP_X25519MLKEM768 => Some(group),
_ => None,
}
}
/// Basic certificate metadata (optional, informative).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedCertificateInfo {
@@ -54,6 +152,19 @@ pub enum TlsProfileSource {
Merged,
}
/// DPI-facing quality class of a cached TLS front profile.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum TlsProfileQuality {
/// No raw origin ServerHello shape is available.
#[default]
Fallback,
/// Raw origin ServerHello was captured, but encrypted flight shape is incomplete.
RawPartial,
/// Raw origin ServerHello and encrypted flight record sizes were captured.
RawStrict,
}
/// Coarse-grained TLS response behavior captured per SNI.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsBehaviorProfile {
@@ -69,6 +180,18 @@ pub struct TlsBehaviorProfile {
/// Source of this behavior profile.
#[serde(default)]
pub source: TlsProfileSource,
/// DPI-facing quality of this profile.
#[serde(default)]
pub quality: TlsProfileQuality,
/// Captured ServerHello TLS record body length.
#[serde(default)]
pub server_hello_record_len: usize,
/// Captured visible ServerHello extension types in wire order.
#[serde(default)]
pub server_hello_extension_types: Vec<u16>,
/// Captured ServerHello key_share group when replay-safe.
#[serde(default)]
pub server_hello_key_share_group: Option<u16>,
}
fn default_change_cipher_spec_count() -> u8 {
@@ -82,10 +205,54 @@ impl Default for TlsBehaviorProfile {
app_data_record_sizes: Vec::new(),
ticket_record_sizes: Vec::new(),
source: TlsProfileSource::Default,
quality: TlsProfileQuality::Fallback,
server_hello_record_len: 0,
server_hello_extension_types: Vec::new(),
server_hello_key_share_group: None,
}
}
}
impl TlsBehaviorProfile {
/// Refresh cached visible ServerHello summary fields and quality.
pub(crate) fn refresh_server_hello_summary(&mut self, server_hello: &ParsedServerHello) {
let mut has_replay_safe_server_hello = false;
if matches!(
self.source,
TlsProfileSource::Raw | TlsProfileSource::Merged
) {
if self.server_hello_record_len == 0 {
self.server_hello_record_len = server_hello.record_body_len();
}
self.server_hello_extension_types = server_hello.extension_types();
self.server_hello_key_share_group = server_hello.key_share_group();
has_replay_safe_server_hello =
server_hello.is_replay_safe_tls13_shape(self.server_hello_record_len);
} else {
self.server_hello_record_len = 0;
self.server_hello_extension_types.clear();
self.server_hello_key_share_group = None;
}
self.refresh_quality(has_replay_safe_server_hello);
}
/// Recompute the profile quality from current source and record-size evidence.
fn refresh_quality(&mut self, has_replay_safe_server_hello: bool) {
let has_raw_server_hello = matches!(
self.source,
TlsProfileSource::Raw | TlsProfileSource::Merged
) && has_replay_safe_server_hello;
self.quality = if has_raw_server_hello && !self.app_data_record_sizes.is_empty() {
TlsProfileQuality::RawStrict
} else if has_raw_server_hello {
TlsProfileQuality::RawPartial
} else {
TlsProfileQuality::Fallback
};
}
}
/// Cached data per SNI used by the emulator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedTlsData {
@@ -122,6 +289,34 @@ pub struct TlsFetchResult {
mod tests {
use super::*;
fn tls13_key_share_extension() -> TlsExtension {
let mut data = Vec::new();
data.extend_from_slice(&TLS_NAMED_GROUP_X25519.to_be_bytes());
data.extend_from_slice(&32u16.to_be_bytes());
data.resize(36, 0x42);
TlsExtension {
ext_type: EXT_KEY_SHARE,
data,
}
}
fn replay_safe_server_hello() -> ParsedServerHello {
ParsedServerHello {
version: TLS_LEGACY_SERVER_HELLO_VERSION,
random: [0u8; 32],
session_id: vec![0x11; 32],
cipher_suite: [0x13, 0x01],
compression: 0,
extensions: vec![
TlsExtension {
ext_type: EXT_SUPPORTED_VERSIONS,
data: TLS_VERSION_13.to_vec(),
},
tls13_key_share_extension(),
],
}
}
#[test]
fn cached_tls_data_deserializes_without_behavior_profile() {
let json = r#"
@@ -147,5 +342,59 @@ mod tests {
assert!(cached.behavior_profile.app_data_record_sizes.is_empty());
assert!(cached.behavior_profile.ticket_record_sizes.is_empty());
assert_eq!(cached.behavior_profile.source, TlsProfileSource::Default);
assert_eq!(cached.behavior_profile.quality, TlsProfileQuality::Fallback);
}
#[test]
fn replay_safe_raw_server_hello_with_app_data_is_raw_strict() {
let server_hello = replay_safe_server_hello();
let mut behavior = TlsBehaviorProfile {
source: TlsProfileSource::Raw,
app_data_record_sizes: vec![1200],
..TlsBehaviorProfile::default()
};
behavior.refresh_server_hello_summary(&server_hello);
assert_eq!(behavior.quality, TlsProfileQuality::RawStrict);
assert_eq!(
behavior.server_hello_extension_types,
vec![EXT_SUPPORTED_VERSIONS, EXT_KEY_SHARE]
);
assert_eq!(
behavior.server_hello_key_share_group,
Some(TLS_NAMED_GROUP_X25519)
);
}
#[test]
fn replay_safe_raw_server_hello_without_app_data_is_raw_partial() {
let server_hello = replay_safe_server_hello();
let mut behavior = TlsBehaviorProfile {
source: TlsProfileSource::Raw,
..TlsBehaviorProfile::default()
};
behavior.refresh_server_hello_summary(&server_hello);
assert_eq!(behavior.quality, TlsProfileQuality::RawPartial);
}
#[test]
fn malformed_raw_server_hello_is_fallback_quality() {
let mut server_hello = replay_safe_server_hello();
server_hello.extensions.push(TlsExtension {
ext_type: EXT_ALPN,
data: vec![0x00, 0x03, 0x02, b'h', b'2'],
});
let mut behavior = TlsBehaviorProfile {
source: TlsProfileSource::Raw,
app_data_record_sizes: vec![1200],
..TlsBehaviorProfile::default()
};
behavior.refresh_server_hello_summary(&server_hello);
assert_eq!(behavior.quality, TlsProfileQuality::Fallback);
}
}
+17 -81
View File
@@ -18,7 +18,7 @@ use tokio::time::timeout;
use tracing::{debug, info, warn};
use crate::config::MeSocksKdfPolicy;
use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256};
use crate::crypto::{SecureRandom, derive_middleproxy_keys};
use crate::error::{ProxyError, Result};
use crate::network::IpFamily;
use crate::network::probe::is_bogon;
@@ -292,14 +292,17 @@ impl MePool {
BndPortStatus::Error
};
record_bnd_status(bnd_addr_status, bnd_port_status, raw_socks_bound_addr);
let reflected = if let Some(bound) = socks_bound_addr {
let socks_bound_kdf_addr = socks_bound_addr.filter(|bound| bound.port() != 0);
// SOCKS BND is the only reflected source that can supply both KDF IP and
// port. Direct STUN reflection is IP-only and keeps the TCP local port.
let reflected = if let Some(bound) = socks_bound_kdf_addr {
Some(bound)
} else if is_socks_route {
match self.socks_kdf_policy() {
MeSocksKdfPolicy::Strict => {
self.stats.increment_me_socks_kdf_strict_reject();
return Err(ProxyError::InvalidHandshake(
"SOCKS route returned no valid BND.ADDR for ME KDF (strict policy)"
"SOCKS route returned no valid BND tuple for ME KDF (strict policy)"
.to_string(),
));
}
@@ -323,16 +326,14 @@ impl MePool {
let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected);
let peer_addr_nat =
SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port());
let client_addr_for_kdf = socks_bound_kdf_addr.unwrap_or(local_addr_nat);
if let Some(upstream_info) = upstream_egress {
let client_ip_for_kdf = socks_bound_addr
.map(|value| value.ip())
.unwrap_or(local_addr_nat.ip());
record_upstream_bnd_status(
upstream_info.upstream_id,
bnd_addr_status,
bnd_port_status,
raw_socks_bound_addr,
Some(client_ip_for_kdf),
Some(client_addr_for_kdf.ip()),
);
}
let (mut rd, mut wr) = tokio::io::split(stream);
@@ -409,6 +410,7 @@ impl MePool {
info!(
%local_addr,
%local_addr_nat,
%client_addr_for_kdf,
reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string),
%peer_addr,
%transport_peer_addr,
@@ -417,21 +419,20 @@ impl MePool {
key_selector = format_args!("0x{ks:08x}"),
crypto_schema = format_args!("0x{schema:08x}"),
skew_secs = skew,
socks_kdf_policy = ?self.socks_kdf_policy(),
"ME key derivation parameters"
);
let ts_bytes = crypto_ts.to_le_bytes();
let server_port_bytes = peer_addr_nat.port().to_le_bytes();
let socks_bound_port = socks_bound_addr
.map(|bound| bound.port())
.filter(|port| *port != 0);
let client_port_for_kdf = socks_bound_port.unwrap_or(local_addr_nat.port());
let socks_bound_port = socks_bound_kdf_addr.map(|bound| bound.port());
let client_port_for_kdf = client_addr_for_kdf.port();
let client_port_source = KdfClientPortSource::from_socks_bound_port(socks_bound_port);
let kdf_fingerprint = Self::kdf_material_fingerprint(
local_addr_nat.ip(),
client_addr_for_kdf.ip(),
peer_addr_nat,
reflected.map(|value| value.ip()),
socks_bound_addr.map(|value| value.ip()),
socks_bound_kdf_addr.map(|value| value.ip()),
client_port_source,
);
let previous_kdf_fingerprint = {
@@ -473,7 +474,7 @@ impl MePool {
let client_port_bytes = client_port_for_kdf.to_le_bytes();
let server_ip = extract_ip_material(peer_addr_nat);
let client_ip = extract_ip_material(local_addr_nat);
let client_ip = extract_ip_material(client_addr_for_kdf);
let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) =
match (server_ip, client_ip) {
@@ -494,38 +495,6 @@ impl MePool {
}
};
let diag_level: u8 = std::env::var("ME_DIAG")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(0);
let prekey_client = build_middleproxy_prekey(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"CLIENT",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
&secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let prekey_server = build_middleproxy_prekey(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"SERVER",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
&secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let (wk, wi) = derive_middleproxy_keys(
&srv_nonce,
&my_nonce,
@@ -556,47 +525,14 @@ impl MePool {
let requested_crc_mode = RpcChecksumMode::Crc32c;
let hs_payload = build_handshake_payload(
hs_our_ip,
local_addr.port(),
client_port_for_kdf,
hs_peer_ip,
peer_addr.port(),
peer_addr_nat.port(),
requested_crc_mode.advertised_flags(),
);
let hs_frame = build_rpc_frame(-1, &hs_payload, RpcChecksumMode::Crc32);
if diag_level >= 1 {
info!(
write_key = %hex_dump(&wk),
write_iv = %hex_dump(&wi),
read_key = %hex_dump(&rk),
read_iv = %hex_dump(&ri),
srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
srv_port = %hex_dump(&server_port_bytes),
clt_port = %hex_dump(&client_port_bytes),
crypto_ts = %hex_dump(&ts_bytes),
nonce_srv = %hex_dump(&srv_nonce),
nonce_clt = %hex_dump(&my_nonce),
prekey_sha256_client = %hex_dump(&sha256(&prekey_client)),
prekey_sha256_server = %hex_dump(&sha256(&prekey_server)),
hs_plain = %hex_dump(&hs_frame),
proxy_secret_sha256 = %hex_dump(&sha256(&secret)),
"ME diag: derived keys and handshake plaintext"
);
}
if diag_level >= 2 {
info!(
prekey_client = %hex_dump(&prekey_client),
prekey_server = %hex_dump(&prekey_server),
"ME diag: full prekey buffers"
);
}
let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
if diag_level >= 1 {
info!(
hs_cipher = %hex_dump(&encrypted_hs),
"ME diag: handshake ciphertext"
);
}
wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?;
wr.flush().await.map_err(ProxyError::Io)?;
+2
View File
@@ -1728,6 +1728,8 @@ mod tests {
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,
+6
View File
@@ -336,6 +336,8 @@ pub(super) struct NatRuntimeCore {
pub(super) nat_probe: bool,
pub(super) nat_stun: Option<String>,
pub(super) nat_stun_servers: Vec<String>,
pub(super) stun_tcp_fallback: bool,
pub(super) http_ip_detect_urls: Vec<String>,
pub(super) nat_stun_live_servers: Arc<RwLock<Vec<String>>>,
pub(super) nat_probe_concurrency: usize,
pub(super) detected_ipv6: Option<Ipv6Addr>,
@@ -484,6 +486,8 @@ impl MePool {
nat_probe: bool,
nat_stun: Option<String>,
nat_stun_servers: Vec<String>,
stun_tcp_fallback: bool,
http_ip_detect_urls: Vec<String>,
nat_probe_concurrency: usize,
detected_ipv6: Option<Ipv6Addr>,
me_one_retry: u8,
@@ -706,6 +710,8 @@ impl MePool {
nat_probe,
nat_stun,
nat_stun_servers,
stun_tcp_fallback,
http_ip_detect_urls,
nat_stun_live_servers: Arc::new(RwLock::new(Vec::new())),
nat_probe_concurrency: nat_probe_concurrency.max(1),
detected_ipv6,
+35 -54
View File
@@ -1,19 +1,22 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::net::IpAddr;
use std::time::Duration;
use tokio::task::JoinSet;
use tokio::time::timeout;
use tracing::{debug, info, warn};
use tracing::{debug, info};
use crate::error::{ProxyError, Result};
use crate::network::probe::is_bogon;
use crate::network::stun::{IpFamily, stun_probe_dual, stun_probe_family_with_bind};
use crate::network::probe::{detect_public_ipv4_http, is_bogon};
use crate::network::stun::{
IpFamily, stun_probe_dual_with_tcp_fallback, stun_probe_family_with_bind_and_tcp_fallback,
};
use super::MePool;
use std::time::Instant;
const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5);
const STUN_BATCH_TCP_FALLBACK_TIMEOUT: Duration = Duration::from_secs(12);
#[allow(dead_code)]
pub async fn stun_probe(stun_addr: Option<String>) -> Result<crate::network::stun::DualStunResult> {
@@ -28,16 +31,13 @@ pub async fn stun_probe(stun_addr: Option<String>) -> Result<crate::network::stu
"STUN server is not configured".to_string(),
));
}
stun_probe_dual(&stun_addr).await
stun_probe_dual_with_tcp_fallback(&stun_addr, false).await
}
#[allow(dead_code)]
pub async fn detect_public_ip() -> Option<IpAddr> {
fetch_public_ipv4_with_retry()
.await
.ok()
.flatten()
.map(IpAddr::V4)
let urls = crate::config::defaults::default_http_ip_detect_urls();
detect_public_ipv4_http(&urls).await.map(IpAddr::V4)
}
impl MePool {
@@ -65,15 +65,26 @@ impl MePool {
let mut live_servers = Vec::new();
let mut best_by_ip: HashMap<IpAddr, (usize, std::net::SocketAddr)> = HashMap::new();
let concurrency = self.nat_runtime.nat_probe_concurrency.max(1);
let tcp_fallback = self.nat_runtime.stun_tcp_fallback;
while next_idx < servers.len() || !join_set.is_empty() {
while next_idx < servers.len() && join_set.len() < concurrency {
let stun_addr = servers[next_idx].clone();
next_idx += 1;
join_set.spawn(async move {
let batch_timeout = if tcp_fallback {
STUN_BATCH_TCP_FALLBACK_TIMEOUT
} else {
STUN_BATCH_TIMEOUT
};
let res = timeout(
STUN_BATCH_TIMEOUT,
stun_probe_family_with_bind(&stun_addr, family, bind_ip),
batch_timeout,
stun_probe_family_with_bind_and_tcp_fallback(
&stun_addr,
family,
bind_ip,
tcp_fallback,
),
)
.await;
(stun_addr, res)
@@ -193,6 +204,10 @@ impl MePool {
return self.nat_runtime.nat_ip_cfg;
}
if !self.nat_runtime.nat_probe {
return None;
}
if !(is_bogon(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) {
return None;
}
@@ -201,21 +216,15 @@ impl MePool {
return Some(ip);
}
match fetch_public_ipv4_with_retry().await {
Ok(Some(ip)) => {
{
let mut guard = self.nat_runtime.nat_ip_detected.write().await;
*guard = Some(IpAddr::V4(ip));
}
info!(public_ip = %ip, "Auto-detected public IP for NAT translation");
Some(IpAddr::V4(ip))
}
Ok(None) => None,
Err(e) => {
warn!(error = %e, "Failed to auto-detect public IP");
None
}
let Some(ip) = detect_public_ipv4_http(&self.nat_runtime.http_ip_detect_urls).await else {
return None;
};
{
let mut guard = self.nat_runtime.nat_ip_detected.write().await;
*guard = Some(IpAddr::V4(ip));
}
info!(public_ip = %ip, "Auto-detected public IP for NAT translation");
Some(IpAddr::V4(ip))
}
pub(super) async fn maybe_reflect_public_addr(
@@ -365,31 +374,3 @@ impl MePool {
None
}
}
async fn fetch_public_ipv4_with_retry() -> Result<Option<Ipv4Addr>> {
let providers = [
"https://checkip.amazonaws.com",
"http://v4.ident.me",
"http://ipv4.icanhazip.com",
];
for url in providers {
if let Ok(Some(ip)) = fetch_public_ipv4_once(url).await {
return Ok(Some(ip));
}
}
Ok(None)
}
async fn fetch_public_ipv4_once(url: &str) -> Result<Option<Ipv4Addr>> {
let res = reqwest::get(url)
.await
.map_err(|e| ProxyError::Proxy(format!("public IP detection request failed: {e}")))?;
let text = res
.text()
.await
.map_err(|e| ProxyError::Proxy(format!("public IP detection read failed: {e}")))?;
let ip = text.trim().parse().ok();
Ok(ip)
}
+2 -4
View File
@@ -464,8 +464,7 @@ impl MePool {
if !self.writer_accepts_new_binding(w) {
continue;
}
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
let (payload, meta) = build_routed_payload(effective_our_addr);
let (payload, meta) = build_routed_payload(our_addr);
match w.tx.clone().try_reserve_owned() {
Ok(permit) => {
if !self.registry.bind_writer(conn_id, w.id, meta).await {
@@ -520,8 +519,7 @@ impl MePool {
}
self.stats
.increment_me_writer_pick_blocking_fallback_total();
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
let (payload, meta) = build_routed_payload(effective_our_addr);
let (payload, meta) = build_routed_payload(our_addr);
let reserve_result =
if let Some(timeout) = self.route_runtime.me_route_blocking_send_timeout {
match tokio::time::timeout(timeout, w.tx.clone().reserve_owned()).await {
@@ -38,6 +38,8 @@ async fn make_pool(
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,
@@ -36,6 +36,8 @@ async fn make_pool(
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,
@@ -31,6 +31,8 @@ async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> {
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,
@@ -20,6 +20,8 @@ async fn make_pool() -> Arc<MePool> {
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,
@@ -25,6 +25,8 @@ async fn make_pool() -> Arc<MePool> {
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,
@@ -31,6 +31,8 @@ async fn make_pool() -> (Arc<MePool>, Arc<SecureRandom>) {
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,
@@ -175,6 +177,37 @@ async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duratio
data_count
}
async fn recv_first_data_payload(
rx: &mut mpsc::Receiver<WriterCommand>,
budget: Duration,
) -> Option<Vec<u8>> {
let start = Instant::now();
while Instant::now().duration_since(start) < budget {
let remaining = budget.saturating_sub(Instant::now().duration_since(start));
match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await {
Ok(Some(WriterCommand::Data(payload))) => return Some(payload.to_vec()),
Ok(Some(WriterCommand::DataAndFlush(payload))) => return Some(payload.to_vec()),
Ok(Some(_)) => {}
Ok(None) => break,
Err(_) => break,
}
}
None
}
fn proxy_req_our_addr_from_payload(payload: &[u8]) -> SocketAddr {
const CLIENT_ADDR_WIRE_LEN: usize = 20;
const OUR_ADDR_OFFSET: usize = 4 + 4 + 8 + CLIENT_ADDR_WIRE_LEN;
let our_addr = &payload[OUR_ADDR_OFFSET..OUR_ADDR_OFFSET + CLIENT_ADDR_WIRE_LEN];
let ip = Ipv4Addr::new(our_addr[12], our_addr[13], our_addr[14], our_addr[15]);
let port = u32::from_le_bytes([our_addr[16], our_addr[17], our_addr[18], our_addr[19]]);
SocketAddr::new(
IpAddr::V4(ip),
u16::try_from(port).expect("test port must fit u16"),
)
}
#[tokio::test]
async fn send_proxy_req_does_not_replay_when_first_bind_commit_fails() {
let (pool, _rng) = make_pool().await;
@@ -288,3 +321,47 @@ async fn send_proxy_req_prunes_iterative_stale_bind_failures_without_data_replay
drop(writers);
assert_eq!(writer_ids, vec![23]);
}
#[tokio::test]
async fn send_proxy_req_preserves_client_facing_our_addr_when_writer_source_ip_differs() {
let (pool, _rng) = make_pool().await;
pool.rr.store(0, Ordering::Relaxed);
let (conn_id, _rx) = pool.registry.register().await;
let mut live_rx = insert_writer(
&pool,
31,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 2, 31)), 443),
true,
)
.await;
{
let mut writers = pool.writers.write().await;
let writer = writers
.iter_mut()
.find(|writer| writer.id == 31)
.expect("test writer must exist");
writer.source_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 31));
}
let our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7)), 8443);
let result = pool
.send_proxy_req(
conn_id,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 7)), 30002),
our_addr,
b"route",
0,
None,
)
.await;
assert!(result.is_ok());
let payload = recv_first_data_payload(&mut live_rx, Duration::from_millis(50))
.await
.expect("writer must receive routed payload");
assert_eq!(proxy_req_our_addr_from_payload(&payload), our_addr);
}
+73 -1
View File
@@ -9,7 +9,7 @@ use std::io::Result;
use std::net::{IpAddr, SocketAddr};
use std::time::Duration;
use tokio::net::TcpStream;
use tracing::debug;
use tracing::{debug, warn};
const DEFAULT_SOCKET_BUFFER_BYTES: usize = 256 * 1024;
@@ -125,6 +125,39 @@ pub fn clear_linger_fd(fd: std::os::unix::io::RawFd) -> Result<()> {
Ok(())
}
/// Raise the TCP MSS on an already-accepted connection's fd. Used to fragment
/// ONLY the TLS handshake (via a low listener MSS) and then restore a normal MSS
/// for the bulk (post-handshake) data phase — cuts packets-per-second ~10x without losing the
/// DPI evasion that the fragmented ServerHello provides. No-op safe: errors are
/// returned to the caller, which logs and continues with the handshake MSS.
#[cfg(target_os = "linux")]
pub fn set_tcp_mss_fd(fd: std::os::unix::io::RawFd, mss: u32) -> Result<()> {
use std::io::Error;
let mss = i32::try_from(mss)
.map_err(|_| Error::new(std::io::ErrorKind::InvalidInput, "bulk MSS out of range"))?;
// Direct setsockopt(TCP_MAXSEG) — same pattern as the TCP_USER_TIMEOUT call
// above; avoids socket2 method-name drift across versions.
let rc = unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_MAXSEG,
&mss as *const libc::c_int as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
)
};
if rc != 0 {
return Err(Error::last_os_error());
}
Ok(())
}
/// Non-Linux stub: MSS shaping only on Linux (TCP_MAXSEG).
#[cfg(all(unix, not(target_os = "linux")))]
pub fn set_tcp_mss_fd(_fd: std::os::unix::io::RawFd, _mss: u32) -> Result<()> {
Ok(())
}
/// Create a new TCP socket for outgoing connections
#[allow(dead_code)]
pub fn create_outgoing_socket(addr: SocketAddr) -> Result<Socket> {
@@ -283,6 +316,8 @@ pub struct ListenOptions {
pub backlog: u32,
/// IPv6 only (disable dual-stack)
pub ipv6_only: bool,
/// Client-facing TCP MSS to announce on accepted TCP sessions.
pub client_mss: Option<u16>,
}
impl Default for ListenOptions {
@@ -292,6 +327,7 @@ impl Default for ListenOptions {
reuse_port: true,
backlog: 1024,
ipv6_only: false,
client_mss: None,
}
}
}
@@ -319,6 +355,19 @@ pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result<Sock
socket.set_only_v6(true)?;
}
if let Some(client_mss) = options.client_mss {
if let Err(error) = socket.set_tcp_mss(u32::from(client_mss)) {
warn!(
addr = %addr,
client_mss,
error = %error,
"Failed to apply listener client MSS; continuing with kernel default"
);
} else {
debug!(addr = %addr, client_mss, "Applied listener client MSS");
}
}
socket.set_nonblocking(true)?;
socket.bind(&addr.into())?;
socket.listen(options.backlog as i32)?;
@@ -637,5 +686,28 @@ mod tests {
assert!(opts.reuse_addr);
assert!(opts.reuse_port);
assert_eq!(opts.backlog, 1024);
assert_eq!(opts.client_mss, None);
}
#[cfg(target_os = "linux")]
#[test]
fn test_create_listener_applies_client_mss() {
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let options = ListenOptions {
reuse_port: false,
client_mss: Some(256),
..Default::default()
};
let socket = match create_listener(addr, &options) {
Ok(socket) => socket,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("create_listener failed: {e}"),
};
let mss = match socket.tcp_mss() {
Ok(mss) => mss,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("tcp_mss failed: {e}"),
};
assert_eq!(mss, 256);
}
}