Compare commits

...

119 Commits

Author SHA1 Message Date
Vladimir Krivopalov 50452b22c7
Merge 95685adba7 into 2d3c2807ab 2026-03-21 19:09:47 +00:00
Vladimir Krivopalov 95685adba7
Add multi-destination logging: syslog and file support
Implement logging infrastructure for non-systemd platforms:

- Add src/logging.rs with syslog and file logging support
- New CLI flags: --syslog, --log-file, --log-file-daily
- Syslog uses libc directly with LOG_DAEMON facility
- File logging via tracing-appender with optional daily rotation

Update service scripts:
- OpenRC and FreeBSD rc.d now use --syslog by default
- Ensures logs are captured on platforms without journald

Default (stderr) behavior unchanged for systemd compatibility.
Log destination is selected at startup based on CLI flags.

Signed-off-by: Vladimir Krivopalov <argenet@yandex.ru>
2026-03-21 21:09:29 +02:00
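The commit above says the log destination is selected once at startup from the CLI flags. A minimal sketch of that selection, assuming a precedence (explicit file wins over syslog, stderr as the default) that the real telemt code may order differently:

```rust
// Hypothetical sketch of startup log-destination selection from the flags
// named in the commit (--syslog, --log-file, --log-file-daily). The enum
// and precedence are assumptions, not telemt's actual API.
#[derive(Debug, PartialEq)]
enum LogDest {
    Stderr,                             // default; systemd/journald-friendly
    Syslog,                             // via libc with LOG_DAEMON
    File { path: String, daily: bool }, // via tracing-appender
}

fn select_log_dest(syslog: bool, log_file: Option<String>, daily: bool) -> LogDest {
    if let Some(path) = log_file {
        LogDest::File { path, daily }
    } else if syslog {
        LogDest::Syslog
    } else {
        LogDest::Stderr
    }
}
```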
Vladimir Krivopalov 909714af31
Add multi-platform service manager integration
Implement automatic init system detection and service file generation
for systemd, OpenRC (Alpine/Gentoo), and FreeBSD rc.d:

- Add src/service module with init system detection and generators
- Auto-detect init system via filesystem probes
- Generate platform-appropriate service files during --init

systemd enhancements:
- ExecReload for SIGHUP config reload
- PIDFile directive
- Comprehensive security hardening (ProtectKernelTunables,
  RestrictAddressFamilies, MemoryDenyWriteExecute, etc.)
- CAP_NET_BIND_SERVICE for privileged ports

OpenRC support:
- Standard openrc-run script with depend/reload functions
- Directory setup in start_pre

FreeBSD rc.d support:
- rc.subr integration with rc.conf variables
- reload extra command

The --init command now detects the init system and runs the
appropriate enable/start commands (systemctl, rc-update, sysrc).

Signed-off-by: Vladimir Krivopalov <argenet@yandex.ru>
2026-03-21 21:09:29 +02:00
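The init-system auto-detection above works by probing the filesystem. A std-only sketch of the idea; the marker paths below are common conventions and are assumptions, not necessarily the exact probes in src/service:

```rust
use std::path::Path;

// Hypothetical probe order: systemd, then OpenRC, then FreeBSD rc.d.
// The paths are common markers for each init system, chosen for the sketch.
fn detect_init_system() -> &'static str {
    if Path::new("/run/systemd/system").is_dir() {
        "systemd"
    } else if Path::new("/sbin/openrc-run").exists() {
        "openrc"
    } else if Path::new("/etc/rc.subr").exists() {
        "freebsd-rcd"
    } else {
        "unknown"
    }
}
```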
Vladimir Krivopalov dc2b4395bd
Add daemon lifecycle subcommands: start, stop, reload, status
Implement CLI subcommands for managing telemt as a daemon:

- `start [config.toml]` - Start as background daemon (implies --daemon)
- `stop` - Stop running daemon by sending SIGTERM
- `reload` - Reload configuration by sending SIGHUP
- `status` - Check if daemon is running via PID file

Subcommands use the PID file (default /var/run/telemt.pid) to locate
the running daemon. Stop command waits up to 10 seconds for graceful
shutdown. Status cleans up stale PID files automatically.

Updated help text with subcommand documentation and usage examples.
Exit codes follow Unix convention: 0 for success, 1 for not running
or error.

Signed-off-by: Vladimir Krivopalov <argenet@yandex.ru>
2026-03-21 21:09:29 +02:00
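The stop subcommand above waits up to 10 seconds for graceful shutdown. The wait loop can be sketched as a polling helper; `is_running` here stands in for a kill(pid, 0)-style liveness probe, and the names are illustrative only:

```rust
use std::time::{Duration, Instant};

// Polls until the daemon is gone or the deadline passes.
// Returns true if the process exited within the timeout.
fn wait_for_exit(mut is_running: impl FnMut() -> bool, timeout: Duration) -> bool {
    let deadline = Instant::now() + timeout;
    loop {
        if !is_running() {
            return true; // graceful shutdown observed
        }
        if Instant::now() >= deadline {
            return false; // timed out; caller may escalate
        }
        std::thread::sleep(Duration::from_millis(50));
    }
}
```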
Vladimir Krivopalov 39875afbff
Add comprehensive Unix signal handling for daemon mode
Enhance signal handling to support proper daemon operation:

- SIGTERM: Graceful shutdown (same behavior as SIGINT)
- SIGQUIT: Graceful shutdown with full statistics dump
- SIGUSR1: Log rotation acknowledgment for external tools
- SIGUSR2: Dump runtime status to log without stopping

Statistics dump includes connection counts, ME keepalive metrics,
and relay adaptive tuning counters. SIGHUP config reload unchanged
(handled in hot_reload.rs).

Signals are handled via tokio::signal::unix with async select!
to avoid blocking the runtime. Non-shutdown signals (USR1/USR2)
run in a background task spawned at startup.

Signed-off-by: Vladimir Krivopalov <argenet@yandex.ru>
2026-03-21 21:09:29 +02:00
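The signal-to-action mapping described above can be sketched as a pure dispatch table; the real code wires these arms into `tokio::signal::unix` streams inside a `select!` loop, and the enum names here are hypothetical:

```rust
// Pure-logic sketch of the signal dispatch described in the commit.
#[derive(Debug, PartialEq)]
enum SignalAction {
    Shutdown,          // SIGINT / SIGTERM
    ShutdownWithStats, // SIGQUIT: dump full statistics first
    AckLogRotation,    // SIGUSR1
    DumpStatus,        // SIGUSR2: log runtime status, keep running
    ReloadConfig,      // SIGHUP (handled in hot_reload.rs)
}

fn dispatch(signal: &str) -> Option<SignalAction> {
    match signal {
        "SIGINT" | "SIGTERM" => Some(SignalAction::Shutdown),
        "SIGQUIT" => Some(SignalAction::ShutdownWithStats),
        "SIGUSR1" => Some(SignalAction::AckLogRotation),
        "SIGUSR2" => Some(SignalAction::DumpStatus),
        "SIGHUP" => Some(SignalAction::ReloadConfig),
        _ => None,
    }
}
```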
Vladimir Krivopalov 2ea7813ed4
Add Unix daemon mode with PID file and privilege dropping
Implement core daemon infrastructure for running telemt as a background
service on Unix platforms (Linux, FreeBSD, etc.):

- Add src/daemon module with classic double-fork daemonization
- Implement flock-based PID file management to prevent duplicate instances
- Add privilege dropping (setuid/setgid) after socket binding
- New CLI flags: --daemon, --foreground, --pid-file, --run-as-user,
  --run-as-group, --working-dir

Daemonization occurs before tokio runtime starts to ensure clean fork.
PID file uses exclusive locking to detect already-running instances.
Privilege dropping happens after bind_listeners() to allow binding
to privileged ports (< 1024) before switching to unprivileged user.

Signed-off-by: Vladimir Krivopalov <argenet@yandex.ru>
2026-03-21 21:09:29 +02:00
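The commit above uses flock() on the PID file to reject duplicate instances. A dependency-free sketch of the same "single instance" idea using exclusive creation instead of flock (note: unlike flock, create_new does not self-clean if the process dies, so stale files need separate handling; names here are illustrative):

```rust
use std::fs::OpenOptions;
use std::io::Write;
use std::path::Path;

// Writes the PID file, failing if one already exists.
// The real commit uses flock-based locking; create_new is a std-only stand-in.
fn write_pid_file(path: &Path, pid: u32) -> std::io::Result<()> {
    let mut f = OpenOptions::new()
        .write(true)
        .create_new(true) // O_EXCL-style: errors if the file exists
        .open(path)?;
    writeln!(f, "{pid}")
}
```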
Alexey 2d3c2807ab
Merge pull request #531 from DavidOsipov/flow
Small brittle test fix
2026-03-21 21:51:33 +03:00
David Osipov 50ae16ddf7
Add interval_gap_usize function and enhance integration test assertions for class separability 2026-03-21 22:49:39 +04:00
Alexey de5c26b7d7
Merge branch 'main' into flow 2026-03-21 21:46:45 +03:00
Alexey a95678988a
Merge pull request #530 from telemt/workflow
Update release.yml
2026-03-21 21:45:23 +03:00
Alexey b17482ede3
Update release.yml 2026-03-21 21:45:01 +03:00
Alexey a059de9191
Merge pull request #529 from DavidOsipov/flow
DPI evasion hardening (Shape/Timing Hardening), timing-attack protection, and large-scale test coverage
2026-03-21 21:31:05 +03:00
David Osipov e7e763888b
Implement aggressive shape hardening mode and related tests 2026-03-21 22:25:29 +04:00
David Osipov c0a3e43aa8
Add comprehensive security tests for proxy functionality
- Introduced client TLS record wrapping tests to ensure correct handling of empty and oversized payloads.
- Added integration tests for middle relay to validate quota saturation behavior under concurrent pressure.
- Implemented high-risk security tests covering various payload scenarios, including alignment checks and boundary conditions.
- Developed length cast hardening tests to verify proper handling of wire lengths and overflow conditions.
- Created quota overflow lock tests to ensure stable behavior under saturation and reclaim scenarios.
- Refactored existing middle relay security tests for improved clarity and consistency in lock handling.
2026-03-21 20:54:13 +04:00
David Osipov 4c32370b25
Refactor proxy and transport modules for improved safety and performance
- Enhanced linting rules in `src/proxy/mod.rs` to enforce stricter code quality checks in production.
- Updated hash functions in `src/proxy/middle_relay.rs` for better efficiency.
- Added new security tests in `src/proxy/tests/middle_relay_stub_completion_security_tests.rs` to validate desynchronization behavior.
- Removed ignored test stubs in `src/proxy/tests/middle_relay_security_tests.rs` to clean up the test suite.
- Improved error handling and code readability in various transport modules, including `src/transport/middle_proxy/config_updater.rs` and `src/transport/middle_proxy/pool.rs`.
- Introduced new padding functions in `src/stream/frame_stream_padding_security_tests.rs` to ensure consistent behavior across different implementations.
- Adjusted TLS stream validation in `src/stream/tls_stream.rs` for better boundary checking.
- General code cleanup and dead code elimination across multiple files to enhance maintainability.
2026-03-21 20:05:07 +04:00
Alexey a6c298b633
Merge branch 'main' into flow 2026-03-21 16:54:47 +03:00
Alexey e7a1d26e6e
Merge pull request #526 from telemt/workflow
Update release.yml
2026-03-21 16:48:53 +03:00
Alexey b91c6cb339
Update release.yml 2026-03-21 16:48:42 +03:00
Alexey e676633dcd
Merge branch 'main' into flow 2026-03-21 16:32:24 +03:00
Alexey c4e7f54cbe
Merge pull request #524 from telemt/workflow
Update release.yml
2026-03-21 16:31:15 +03:00
Alexey f85205d48d
Update release.yml 2026-03-21 16:31:05 +03:00
Alexey d767ec02ee
Update release.yml 2026-03-21 16:24:06 +03:00
Alexey 51835c33f2
Merge branch 'main' into flow 2026-03-21 16:19:02 +03:00
Alexey 88a4c652b6
Merge pull request #523 from telemt/workflow
Update release.yml
2026-03-21 16:18:48 +03:00
Alexey ea2d964502
Update release.yml 2026-03-21 16:18:24 +03:00
Alexey bd7218c39c
Merge branch 'main' into flow 2026-03-21 16:06:03 +03:00
Alexey 3055637571
Merge pull request #522 from telemt/workflow
Update release.yml
2026-03-21 16:04:56 +03:00
Alexey 19b84b9d73
Update release.yml 2026-03-21 16:03:54 +03:00
Alexey 165a1ede57
Merge branch 'main' into flow 2026-03-21 15:58:53 +03:00
Alexey 6ead8b1922
Merge pull request #521 from telemt/workflow
Update release.yml
2026-03-21 15:58:36 +03:00
Alexey 63aa1038c0
Update release.yml 2026-03-21 15:58:25 +03:00
Alexey 4473826303
Update crypto_bench.rs 2026-03-21 15:48:28 +03:00
Alexey d7bbb376c9
Format 2026-03-21 15:45:29 +03:00
Alexey 7a8f946029
Update Cargo.lock 2026-03-21 15:35:03 +03:00
Alexey f2e6dc1774
Update Cargo.toml 2026-03-21 15:27:21 +03:00
Alexey 54d65dd124
Merge branch 'main' into flow 2026-03-21 15:22:49 +03:00
Alexey 24594e648e
Merge pull request #519 from telemt/workflow
Update release.yml
2026-03-21 15:21:47 +03:00
Alexey e8b38ea860
Update release.yml 2026-03-21 15:21:25 +03:00
Alexey b14c2b0a9b
Merge pull request #517 from DavidOsipov/test/main-into-flow-sec
DPI evasion hardening (Shape/Timing Hardening), timing-attack protection, and large-scale test coverage
2026-03-21 15:03:05 +03:00
David Osipov c1ee43fbac
Add stress testing for quota-lock and refactor test guard usage 2026-03-21 15:54:14 +04:00
David Osipov c8632de5b6
Update dependencies and refactor random number generation
- Bump versions of several dependencies in Cargo.toml for improved functionality and security, including:
  - socket2 to 0.6
  - nix to 0.31
  - toml to 1.0
  - x509-parser to 0.18
  - dashmap to 6.1
  - rand to 0.10
  - reqwest to 0.13
  - notify to 8.2
  - ipnetwork to 0.21
  - webpki-roots to 1.0
  - criterion to 0.8
- Introduce `OnceLock` for secure random number generation in multiple modules to ensure thread safety and reduce overhead.
- Refactor random number generation calls to use the new `RngExt` trait methods for consistency and clarity.
- Add new PNG files for architectural documentation.
2026-03-21 15:43:07 +04:00
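The `OnceLock` pattern mentioned above initializes a value exactly once, thread-safely, and lets later callers reuse it without locking overhead. A minimal sketch; the entropy source here is a placeholder, not telemt's actual RNG seeding:

```rust
use std::sync::OnceLock;

// First caller runs the initializer; everyone else gets the cached value.
static PROCESS_SEED: OnceLock<u64> = OnceLock::new();

fn process_seed() -> u64 {
    *PROCESS_SEED.get_or_init(|| {
        // Placeholder entropy for the sketch only.
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(0)
    })
}
```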
David Osipov b930ea1ec5
Add regression and security tests for relay quota and TLS stream handling
- Introduced regression tests for relay quota wake liveness to ensure proper handling of contention and wake events.
- Added adversarial tests to validate the behavior of the quota system under stress and contention scenarios.
- Implemented security tests for the TLS stream to verify the preservation of pending plaintext during state transitions.
- Enhanced the pool writer tests to ensure proper quarantine behavior and validate the removal of writers from the registry.
- Included fuzz testing to assess the robustness of the quota and TLS handling mechanisms against unexpected inputs and states.
2026-03-21 15:16:20 +04:00
David Osipov 3b86a883b9
Add comprehensive tests for relay quota management and adversarial scenarios
- Introduced `relay_quota_boundary_blackhat_tests.rs` to validate behavior under quota limits, including edge cases and adversarial conditions.
- Added `relay_quota_model_adversarial_tests.rs` to ensure quota management maintains integrity during bidirectional communication and various load scenarios.
- Created `relay_quota_overflow_regression_tests.rs` to address overflow issues and ensure that quota limits are respected during aggressive data transmission.
- Implemented `route_mode_coherence_adversarial_tests.rs` to verify the consistency of route mode transitions and timestamp management across different relay modes.
2026-03-21 14:14:58 +04:00
David Osipov 5933b5e821
Refactor and enhance tests for proxy and relay functionality
- Renamed test functions in `client_tls_clienthello_truncation_adversarial_tests.rs` to remove "but_leaks" suffix for clarity.
- Added new tests in `direct_relay_business_logic_tests.rs` to validate business logic for data center resolution and scope hints.
- Introduced tests in `direct_relay_common_mistakes_tests.rs` to cover common mistakes in direct relay configurations.
- Added security tests in `direct_relay_security_tests.rs` to ensure proper handling of symlink and parent swap scenarios.
- Created `direct_relay_subtle_adversarial_tests.rs` to stress test concurrent logging and validate scope hint behavior.
- Implemented `relay_quota_lock_pressure_adversarial_tests.rs` to test quota lock behavior under high contention and stress.
- Updated `relay_security_tests.rs` to include quota lock contention tests ensuring proper behavior under concurrent access.
- Introduced `ip_tracker_hotpath_adversarial_tests.rs` to validate the performance and correctness of the IP tracking logic under various scenarios.
2026-03-21 13:38:17 +04:00
David Osipov 8188fedf6a
Add masking shape classifier and guard tests for adversarial resistance
- Implemented tests for masking shape classifier resistance against threshold attacks, ensuring that blurring reduces accuracy and increases overlap between classes.
- Added tests for masking shape guard functionality, verifying that it maintains expected behavior under various conditions, including timeout paths and clean EOF scenarios.
- Introduced helper functions for calculating accuracy and handling timing samples to support the new tests.
- Ensured that the masking shape hardening configuration is properly utilized in tests to validate its effectiveness.
2026-03-21 12:43:25 +04:00
David Osipov f2335c211c
Version change before PR 2026-03-21 11:19:51 +04:00
David Osipov 246ca11b88
Crates update 2026-03-21 11:18:43 +04:00
David Osipov bb355e916f
Add comprehensive security tests for masking and shape hardening features
- Introduced red-team expected-fail tests for client masking shape hardening.
- Added integration tests for masking AB envelope blur to improve obfuscation.
- Implemented masking security tests to validate the behavior of masking under various conditions.
- Created tests for masking shape above-cap blur to ensure proper functionality.
- Developed adversarial tests for masking shape hardening to evaluate robustness against attacks.
- Added timing normalization security tests to assess the effectiveness of timing obfuscation.
- Implemented red-team expected-fail tests for timing side-channel vulnerabilities.
2026-03-21 00:30:51 +04:00
David Osipov 8814854ae4
actually, it's a better one 2026-03-20 23:27:56 +04:00
David Osipov 44c65f9c60
changed version 2026-03-20 23:27:29 +04:00
David Osipov 1260217be9
Normalize Cargo.lock after upstream merge 2026-03-20 23:22:29 +04:00
David Osipov ebd37932c5
Merge latest upstream/main into test/main-into-flow-sec 2026-03-20 23:21:22 +04:00
David Osipov 43d7e6e991
moved tests to subdirs 2026-03-20 22:55:19 +04:00
David Osipov 0eca535955
Refactor TLS fallback tests to remove unnecessary client hello assertions
- Removed assertions for expected client hello messages in multiple TLS fallback tests to streamline the test logic.
- Updated the tests to focus on verifying the trailing TLS records received after the fallback.
- Enhanced the masking functionality by adding shape hardening features, including dynamic padding based on sent data size.
- Modified the relay_to_mask function to accommodate new parameters for shape hardening.
- Updated masking security tests to reflect changes in the relay_to_mask function signature.
2026-03-20 22:44:39 +04:00
David Osipov 3abde52de8
refactor: update TLS record size constants and related validations
- Rename MAX_TLS_RECORD_SIZE to MAX_TLS_PLAINTEXT_SIZE for clarity.
- Rename MAX_TLS_CHUNK_SIZE to MAX_TLS_CIPHERTEXT_SIZE to reflect its purpose.
- Deprecate old constants in favor of new ones.
- Update various parts of the codebase to use the new constants, including validation checks and tests.
- Add new tests to ensure compliance with RFC 8446 regarding TLS record sizes.
2026-03-20 21:00:36 +04:00
David Osipov 801f670827
Add comprehensive TLS ClientHello size validation and adversarial tests
- Refactor existing tests to improve clarity and specificity in naming.
- Introduce new tests for minimum and maximum TLS ClientHello sizes, ensuring proper masking behavior for malformed probes.
- Implement differential timing tests to compare latency between malformed TLS and plain web requests, ensuring similar performance characteristics.
- Add adversarial tests for truncated TLS ClientHello probes, verifying that even malformed traffic is masked as legitimate responses.
- Enhance the overall test suite for robustness against probing attacks, focusing on edge cases and potential vulnerabilities in TLS handling.
2026-03-20 20:30:02 +04:00
David Osipov 1689b8a5dc
Changed version 2026-03-20 18:49:17 +04:00
David Osipov babd902d95
Add adversarial tests for MTProto handshake and enhance masking functionality
- Introduced multiple adversarial tests for MTProto handshake to ensure robustness against replay attacks, invalid mutations, and concurrent flooding.
- Implemented a function to build proxy headers based on the specified version, improving the handling of masking protocols.
- Added tests to validate the behavior of the masking functionality under various conditions, including unknown proxy protocol versions and oversized payloads.
- Enhanced relay tests to ensure stability and performance under high load and half-close scenarios.
2026-03-20 18:48:19 +04:00
David Osipov 9dce748679
changed version 2026-03-20 18:04:37 +04:00
David Osipov 79093679ab
Merge latest upstream/main into test/main-into-flow-sec 2026-03-20 18:00:20 +04:00
David Osipov 35a8f5b2e5
Add method to retrieve inner reader with pending plaintext
This commit introduces the `into_inner_with_pending_plaintext` method to the `FakeTlsReader` struct. This method allows users to extract the underlying reader along with any pending plaintext data that may have been buffered during the TLS reading process. The method handles the state transition and ensures that any buffered data is returned as a vector, facilitating easier management of plaintext data in TLS streams.
2026-03-20 17:56:37 +04:00
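The point of the method above is that unwrapping a decrypting reader must hand back both the raw transport and any plaintext already buffered, or those bytes would be silently lost. A minimal illustration of the shape of that API; the struct and field names are hypothetical:

```rust
// Stand-in for a FakeTlsReader-like wrapper around a raw transport.
struct TlsLikeReader<R> {
    inner: R,
    pending_plaintext: Vec<u8>,
}

impl<R> TlsLikeReader<R> {
    // Returns the underlying reader plus any buffered, already-decrypted
    // bytes, so the caller can replay them before reading from `inner`.
    fn into_inner_with_pending_plaintext(self) -> (R, Vec<u8>) {
        (self.inner, self.pending_plaintext)
    }
}
```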
David Osipov 456c433875
Updated version 2026-03-20 17:34:09 +04:00
David Osipov 8f1ffe8c25
fix(proxy): fix wire-transparency on fallback and strengthen security
Fixed a critical logic error in the Fake TLS -> MTProto chain.
Previously, on a valid TLS handshake with an invalid MTProto packet,
the proxy mistakenly passed the wrapped (FakeTls) stream to the masking
relay. The transport is now correctly unwrapped down to the raw socket
via .into_inner(), ensuring full wire-transparency toward DPI and the
masking backend.

Security & Hardening:
- Brought the logic in line with OWASP ASVS L2 requirements (V5: Validation, Sanitization and Encoding).
- Implemented fail-closed behavior: on any verification error, the proxy mimics an ordinary web server without revealing its role.
- Improved diagnostics and logging of authentication states to protect against active probing.

Adversarial Testing (black-hat mindset):
- Added a dedicated suite, `client_tls_mtproto_fallback_security_tests.rs` (18+ tests).
- Covered scenarios: chaos fragmentation (byte-by-byte slicing of TLS records), record splitting,
  half-close states, backend resets, and replay pressure.
- Added 10+ tests to `client_adversarial_tests.rs` for hostile race conditions,
  per-IP limit leaks, and isolation of parallel sessions' state.
- All 832 tests pass in locked mode.
2026-03-20 17:33:46 +04:00
David Osipov a78c3e3ebd
One more small test fix 2026-03-20 16:48:14 +04:00
David Osipov a4b70405b8
Add adversarial tests module for client security testing 2026-03-20 16:47:26 +04:00
David Osipov 3afc3e1775
Changed version 2026-03-20 16:46:09 +04:00
David Osipov 512bee6a8d
Add security tests for middle relay idle policy and enhance stats tracking
- Introduced a new test module for middle relay idle policy security tests, covering various scenarios including soft mark, hard close, and grace periods.
- Implemented functions to create crypto readers and encrypt data for testing.
- Enhanced the Stats struct to include counters for relay idle soft marks, hard closes, pressure evictions, and protocol desync closes.
- Added corresponding increment and retrieval methods for the new stats fields.
2026-03-20 16:43:50 +04:00
David Osipov 5c5fdcb124
Updated cargo 2026-03-20 15:03:42 +04:00
David Osipov 0ded366199
Changed version 2026-03-20 14:29:45 +04:00
David Osipov 84a34cea3d
Merge latest upstream/main into test/main-into-flow-sec 2026-03-20 14:26:49 +04:00
David Osipov 7dc3c3666d
Merge upstream/main into test/main-into-flow-sec 2026-03-20 14:20:20 +04:00
David Osipov 6ea8ba25c4
Refactor OpenBSD build workflow for clarity 2026-03-20 02:27:21 +04:00
David Osipov 3f3bf5bbd2
Update build-openbsd.yml 2026-03-20 01:27:11 +04:00
David Osipov ec793f3065
Added cargo.toml 2026-03-20 01:06:00 +04:00
David Osipov e83d366518
Fixed issues with an action 2026-03-20 00:58:11 +04:00
David Osipov 5a4209fe00
Changed version 2026-03-20 00:53:32 +04:00
David Osipov e7daf51193
Added runner for Openbsd 2026-03-20 00:43:05 +04:00
David Osipov 754e4db8a9
Add security tests for pool writer and pool refill functionality 2026-03-20 00:07:41 +04:00
David Osipov 7416829e89
Merge remote-tracking branch 'upstream/main' into test/main-into-flow-sec
# Conflicts:
#	Cargo.toml
#	src/api/model.rs
#	src/api/runtime_stats.rs
#	src/transport/middle_proxy/health.rs
#	src/transport/middle_proxy/health_regression_tests.rs
#	src/transport/middle_proxy/pool_status.rs
2026-03-19 23:48:40 +04:00
David Osipov c07b600acb
Integration hardening: reconcile main+flow-sec API drift and restore green suite 2026-03-19 20:24:44 +04:00
David Osipov 7b44496706
Integration test merge: upstream/main into flow-sec security branch (prefer flow-sec on conflicts) 2026-03-19 19:42:04 +04:00
David Osipov e6ad9e4c7f
Add security tests for connection limits and handshake integrity
- Implement a test to ensure that exceeding the user connection limit does not leak the current connections counter.
- Add tests for direct relay connection refusal and adversarial scenarios to verify proper error handling.
- Introduce fuzz testing for MTProto handshake to ensure robustness against malformed inputs and replay attacks.
- Remove obsolete short TLS probe throttle tests and integrate their functionality into existing security tests.
- Enhance middle relay tests to validate behavior during connection drops and cutovers, ensuring graceful error handling.
- Add a test for half-close scenarios in relay to confirm bidirectional data flow continues as expected.
2026-03-19 17:31:19 +04:00
David Osipov 2a01ca2d6f
Add adversarial tests for client, handshake, masking, and relay modules
- Introduced `client_adversarial_tests.rs` to stress test connection limits and IP tracker race conditions.
- Added `handshake_adversarial_tests.rs` for mutational bit-flipping tests and timing neutrality checks.
- Created `masking_adversarial_tests.rs` to validate probing indistinguishability and SSRF prevention.
- Implemented `relay_adversarial_tests.rs` to ensure HOL blocking prevention and data quota enforcement.
- Updated respective modules to include new test paths.
2026-03-19 17:31:19 +04:00
Alexey 44376b5652
Merge pull request #463 from DavidOsipov/pr-sec-1
[WIP] Enhance metrics configuration, add health monitoring tests, security hardening, perf optimizations & loads of tests
2026-03-18 23:02:58 +03:00
David Osipov c7cf37898b
feat: enhance quota user lock management and testing
- Adjusted QUOTA_USER_LOCKS_MAX based on test and non-test configurations to improve flexibility.
- Implemented logic to retain existing locks when the maximum quota is reached, ensuring efficient memory usage.
- Added comprehensive tests for quota user lock functionality, including cache reuse, saturation behavior, and race conditions.
- Enhanced StatsIo struct to manage wake scheduling for read and write operations, preventing unnecessary self-wakes.
- Introduced separate replay checker domains for handshake and TLS to ensure isolation and prevent cross-pollution of keys.
- Added security tests for replay checker to validate domain separation and window clamping behavior.
2026-03-18 23:55:08 +04:00
David Osipov 20e205189c
Enhance TLS Emulator with ALPN Support and Add Adversarial Tests
- Modified `build_emulated_server_hello` to accept ALPN (Application-Layer Protocol Negotiation) as an optional parameter, allowing for the embedding of ALPN markers in the application data payload.
- Implemented logic to handle oversized ALPN values and ensure they do not interfere with the application data payload.
- Added new security tests in `emulator_security_tests.rs` to validate the behavior of the ALPN embedding, including scenarios for oversized ALPN and preference for certificate payloads over ALPN markers.
- Introduced `send_adversarial_tests.rs` to cover edge cases and potential issues in the middle proxy's send functionality, ensuring robustness against various failure modes.
- Updated `middle_proxy` module to include new test modules and ensure proper handling of writer commands during data transmission.
2026-03-18 17:04:50 +04:00
David Osipov 97d4a1c5c8
Refactor and enhance security in proxy and handshake modules
- Updated `direct_relay_security_tests.rs` to ensure sanitized paths are correctly validated against resolved paths.
- Added tests for symlink handling in `unknown_dc_log_path_revalidation` to prevent symlink target escape vulnerabilities.
- Modified `handshake.rs` to use a more robust hashing strategy for eviction offsets, improving the eviction logic in `auth_probe_record_failure_with_state`.
- Introduced new tests in `handshake_security_tests.rs` to validate eviction logic under various conditions, ensuring low fail streak entries are prioritized for eviction.
- Simplified `route_mode.rs` by removing unnecessary atomic mode tracking, streamlining the transition logic in `RouteRuntimeController`.
- Enhanced `route_mode_security_tests.rs` with comprehensive tests for mode transitions and their effects on session states, ensuring consistency under concurrent modifications.
- Cleaned up `emulator.rs` by removing unused ALPN extension handling, improving code clarity and maintainability.
2026-03-18 01:40:38 +04:00
David Osipov c2443e6f1a
Refactor auth probe eviction logic and improve performance
- Simplified eviction candidate selection in `auth_probe_record_failure_with_state` by tracking the oldest candidate directly.
- Enhanced the handling of stale entries to ensure newcomers are tracked even under capacity constraints.
- Added tests to verify behavior under stress conditions and ensure newcomers are correctly managed.
- Updated `decode_user_secrets` to prioritize preferred users based on SNI hints.
- Introduced new tests for TLS SNI handling and replay protection mechanisms.
- Improved deduplication hash stability and collision resistance in middle relay logic.
- Refined cutover handling in route mode to ensure consistent error messaging and session management.
2026-03-18 00:38:59 +04:00
David Osipov a7cffb547e
Implement idle timeout for masking relay and add corresponding tests
- Introduced `copy_with_idle_timeout` function to handle reading and writing with an idle timeout.
- Updated the proxy masking logic to use the new idle timeout function.
- Added tests to verify that idle relays are closed by the idle timeout before the global relay timeout.
- Ensured that connect refusal paths respect the masking budget and that responses followed by silence are cut off by the idle timeout.
- Added tests for adversarial scenarios where clients may attempt to drip-feed data beyond the idle timeout.
2026-03-17 22:48:13 +04:00
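The idle-timeout relay above hinges on deadline bookkeeping: every successful transfer pushes the deadline out, so only a fully silent connection expires, independently of the global relay timeout. A sketch of that bookkeeping (the real `copy_with_idle_timeout` is async and tokio-based; names here are illustrative):

```rust
use std::time::{Duration, Instant};

// Tracks an idle deadline that is refreshed on activity.
struct IdleTimer {
    idle: Duration,
    deadline: Instant,
}

impl IdleTimer {
    fn new(idle: Duration) -> Self {
        Self { idle, deadline: Instant::now() + idle }
    }
    // Call after each successful read/write to push the deadline out.
    fn touch(&mut self) {
        self.deadline = Instant::now() + self.idle;
    }
    fn expired(&self) -> bool {
        Instant::now() >= self.deadline
    }
}
```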
David Osipov f0c37f233e
Refactor health management: implement remove_writer_if_empty method for cleaner writer removal logic and update related functions to enhance efficiency in handling closed writers. 2026-03-17 21:38:15 +04:00
David Osipov 60953bcc2c
Refactor user connection limit checks and enhance health monitoring tests: update warning messages, add new tests for draining writers, and improve state management 2026-03-17 20:53:37 +04:00
David Osipov 2c06288b40
Enhance UserConnectionReservation: add runtime handle for cross-thread IP cleanup and implement tests for user expiration and connection limits 2026-03-17 20:21:01 +04:00
David Osipov 0284b9f9e3
Refactor health integration tests to use wait_for_pool_empty for improved readability and timeout handling 2026-03-17 20:14:07 +04:00
David Osipov 4e3f42dce3
Add must_use attribute to UserConnectionReservation and RouteConnectionLease structs for better resource management 2026-03-17 19:55:55 +04:00
David Osipov 50a827e7fd
Merge upstream/flow-sec into pr-sec-1 2026-03-17 19:48:53 +04:00
David Osipov d81140ccec
Enhance UserConnectionReservation management: add active state and release method, improve cleanup on drop, and implement tests for immediate release and concurrent handling 2026-03-17 19:39:29 +04:00
David Osipov c540a6657f
Implement user connection reservation management and enhance relay task handling in proxy 2026-03-17 19:05:26 +04:00
David Osipov 4808a30185
Merge upstream/main into flow-sec rehearsal: resolve config and middle-proxy health conflicts 2026-03-17 18:35:54 +04:00
David Osipov 1357f3cc4c
bump version to 3.3.20 and implement connection lease management for direct and middle relays 2026-03-17 18:16:17 +04:00
David Osipov d9aa6f4956
Merge upstream/main into pr-sec-1 2026-03-17 17:49:10 +04:00
Alexey 4f55d08c51
Merge pull request #454 from DavidOsipov/pr-sec-1
PR-SEC-1: Additional hardening and masking
2026-03-17 15:35:08 +03:00
David Osipov 93caab1aec
feat(proxy): refactor auth probe failure handling and add concurrent failure tests 2026-03-17 16:25:29 +04:00
David Osipov 0c6bb3a641
feat(proxy): implement auth probe eviction logic and corresponding tests 2026-03-17 15:43:07 +04:00
David Osipov b2e15327fe
feat(proxy): enhance auth probe handling with IPv6 normalization and eviction logic 2026-03-17 15:15:12 +04:00
Alexey 2e8be87ccf
ME Writer Draining-state fixes 2026-03-17 13:58:01 +03:00
Alexey d78360982c
Hot-Reload fixes 2026-03-17 13:02:12 +03:00
Alexey 822bcbf7a5
Update Cargo.toml 2026-03-17 11:21:35 +03:00
Alexey b25ec97a43
Merge pull request #447 from DavidOsipov/pr-sec-1
PR-SEC-1 (WIP): First PR with a narrow batch of security and masking fixes. Focus is on /src/proxy
2026-03-17 11:20:36 +03:00
David Osipov 8821e38013
feat(proxy): enhance auth probe capacity with stale entry pruning and new tests 2026-03-17 02:19:14 +04:00
David Osipov a1caebbe6f
feat(proxy): implement timeout handling for client payload reads and add corresponding tests 2026-03-17 01:53:44 +04:00
David Osipov e0d821c6b6
Merge remote-tracking branch 'upstream/main' into pr-sec-1 2026-03-17 01:51:35 +04:00
David Osipov 205fc88718
feat(proxy): enhance logging and deduplication for unknown datacenters
- Implemented a mechanism to log unknown datacenter indices with a distinct limit to avoid excessive logging.
- Introduced tests to ensure that logging is deduplicated per datacenter index and respects the distinct limit.
- Updated the fallback logic for datacenter resolution to prevent panics when only a single datacenter is available.

feat(proxy): add authentication probe throttling

- Added a pre-authentication probe throttling mechanism to limit the rate of invalid TLS and MTProto handshake attempts.
- Introduced a backoff strategy for repeated failures and ensured that successful handshakes reset the failure count.
- Implemented tests to validate the behavior of the authentication probe under various conditions.

fix(proxy): ensure proper flushing of masked writes

- Added a flush operation after writing initial data to the mask writer to ensure data integrity.

refactor(proxy): optimize desynchronization deduplication

- Replaced the Mutex-based deduplication structure with a DashMap for improved concurrency and performance.
- Implemented a bounded cache for deduplication to limit memory usage and prevent stale entries from persisting.

test(proxy): enhance security tests for middle relay and handshake

- Added comprehensive tests for the middle relay and handshake processes, including scenarios for deduplication and authentication probe behavior.
- Ensured that the tests cover edge cases and validate the expected behavior of the system under load.
2026-03-17 01:29:30 +04:00
David Osipov e4a50f9286
feat(tls): add boot time timestamp constant and validation for SNI hostnames
- Introduced `BOOT_TIME_MAX_SECS` constant to define the maximum accepted boot-time timestamp.
- Updated `validate_tls_handshake_at_time` to utilize the new boot time constant for timestamp validation.
- Enhanced `extract_sni_from_client_hello` to validate SNI hostnames against specified criteria, rejecting invalid hostnames.
- Added tests to ensure proper handling of boot time timestamps and SNI validation.

feat(handshake): improve user secret decoding and ALPN enforcement

- Refactored user secret decoding to provide better error handling and logging for invalid secrets.
- Added tests for concurrent identical handshakes to ensure replay protection works as expected.
- Implemented ALPN enforcement in handshake processing, rejecting unsupported protocols and allowing valid ones.

fix(masking): implement timeout handling for masking operations

- Added timeout handling for writing proxy headers and consuming client data in masking.
- Adjusted timeout durations for testing to ensure faster feedback during unit tests.
- Introduced tests to verify behavior when masking is disabled and when proxy header writes exceed the timeout.

test(masking): add tests for slowloris connections and proxy header timeouts

- Created tests to validate that slowloris connections are closed by consume timeout when masking is disabled.
- Added a test for proxy header write timeout to ensure it returns false when the write operation does not complete.
2026-03-16 21:37:59 +04:00
David Osipov 213ce4555a
Merge remote-tracking branch 'upstream/main' into pr-sec-1 2026-03-16 20:51:53 +04:00
David Osipov 5a16e68487
Enhance TLS record handling and security tests
- Enforce TLS record length constraints in client handling to comply with RFC 8446, rejecting records outside the range of 512 to 16,384 bytes.
- Update security tests to validate behavior for oversized and undersized TLS records, ensuring they are correctly masked or rejected.
- Introduce new tests to verify the handling of TLS records in both generic and client handler pipelines.
- Refactor handshake logic to enforce mode restrictions based on transport type, preventing misuse of secure tags.
- Add tests for nonce generation and encryption consistency, ensuring correct behavior for different configurations.
- Improve masking tests to ensure proper logging and detection of client types, including SSH and unknown probes.
2026-03-16 20:43:49 +04:00
David Osipov 6ffbc51fb0
security: harden handshake/masking flows and add adversarial regressions
- forward valid-TLS/invalid-MTProto clients to mask backend in both client paths
- harden TLS validation against timing and clock edge cases
- move replay tracking behind successful authentication to avoid cache pollution
- tighten secret decoding and key-material handling paths
- add dedicated security test modules for tls/client/handshake/masking
- include production-path regression for ClientHandler fallback behavior
2026-03-16 20:04:41 +04:00
David Osipov dcab19a64f
ci: remove CI workflow changes (deferred to later PR) 2026-03-16 13:56:46 +04:00
David Osipov f10ca192fa
chore: merge upstream/main (92972ab) into pr-sec-1 2026-03-16 13:50:46 +04:00
David Osipov 2bd9036908
ci: add security policy, cargo-deny configuration, and audit workflow
- Add deny.toml with license/advisory policy for cargo-deny
- Add security.yml GitHub Actions workflow for automated audit
- Update rust.yml with hardened clippy lint enforcement
- Update Cargo.toml/Cargo.lock with audit-related dependency additions
- Fix clippy lint placement in config.toml (Clippy lints must not live in rustflags)

Part of PR-SEC-1: no Rust source changes, establishes CI gates for all subsequent PRs.
2026-03-15 00:30:36 +04:00
190 changed files with 52299 additions and 6790 deletions

15
.cargo/deny.toml Normal file

@ -0,0 +1,15 @@
[bans]
multiple-versions = "deny"
wildcards = "allow"
highlight = "all"
# Explicitly flag the weak cryptography so the agent is forced to justify its existence
[[bans.skip]]
name = "md-5"
version = "*"
reason = "MUST VERIFY: Only allowed for legacy checksums, never for security."
[[bans.skip]]
name = "sha1"
version = "*"
reason = "MUST VERIFY: Only allowed for backwards compatibility."


@ -4,136 +4,213 @@ on:
push:
tags:
- '[0-9]+.[0-9]+.[0-9]+'
- '[0-9]+.[0-9]+.[0-9]+-*'
workflow_dispatch:
concurrency:
group: release-${{ github.ref }}
cancel-in-progress: true
permissions:
contents: read
packages: write
env:
CARGO_TERM_COLOR: always
RUST_BACKTRACE: "1"
BINARY_NAME: telemt
jobs:
build:
name: Build ${{ matrix.target }}
prepare:
runs-on: ubuntu-latest
permissions:
contents: read
outputs:
version: ${{ steps.meta.outputs.version }}
prerelease: ${{ steps.meta.outputs.prerelease }}
release_enabled: ${{ steps.meta.outputs.release_enabled }}
steps:
- id: meta
run: |
set -euo pipefail
if [[ "${GITHUB_REF}" == refs/tags/* ]]; then
VERSION="${GITHUB_REF#refs/tags/}"
RELEASE_ENABLED=true
else
VERSION="manual-${GITHUB_SHA::7}"
RELEASE_ENABLED=false
fi
if [[ "$VERSION" == *"-alpha"* || "$VERSION" == *"-beta"* || "$VERSION" == *"-rc"* ]]; then
PRERELEASE=true
else
PRERELEASE=false
fi
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
echo "prerelease=$PRERELEASE" >> "$GITHUB_OUTPUT"
echo "release_enabled=$RELEASE_ENABLED" >> "$GITHUB_OUTPUT"
checks:
runs-on: ubuntu-latest
container:
image: debian:trixie
steps:
- run: |
apt-get update
apt-get install -y build-essential clang llvm pkg-config curl git
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy
- uses: actions/cache@v4
with:
path: |
/github/home/.cargo/registry
/github/home/.cargo/git
target
key: checks-${{ hashFiles('**/Cargo.lock') }}
- run: cargo fetch --locked
- run: cargo fmt --all -- --check
- run: cargo clippy
- run: cargo test
build-binaries:
needs: [prepare, checks]
runs-on: ubuntu-latest
container:
image: debian:trixie
strategy:
fail-fast: false
matrix:
include:
- target: x86_64-unknown-linux-gnu
artifact_name: telemt
- rust_target: x86_64-unknown-linux-gnu
zig_target: x86_64-unknown-linux-gnu.2.28
asset_name: telemt-x86_64-linux-gnu
- target: aarch64-unknown-linux-gnu
artifact_name: telemt
- rust_target: aarch64-unknown-linux-gnu
zig_target: aarch64-unknown-linux-gnu.2.28
asset_name: telemt-aarch64-linux-gnu
- target: x86_64-unknown-linux-musl
artifact_name: telemt
- rust_target: x86_64-unknown-linux-musl
zig_target: x86_64-unknown-linux-musl
asset_name: telemt-x86_64-linux-musl
- target: aarch64-unknown-linux-musl
artifact_name: telemt
- rust_target: aarch64-unknown-linux-musl
zig_target: aarch64-unknown-linux-musl
asset_name: telemt-aarch64-linux-musl
steps:
- run: |
apt-get update
apt-get install -y clang llvm pkg-config curl git python3 python3-pip file tar xz-utils
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@v1
- uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
targets: ${{ matrix.target }}
- name: Install cross-compilation tools
run: |
sudo apt-get update
sudo apt-get install -y gcc-aarch64-linux-gnu
targets: ${{ matrix.rust_target }}
- uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
/github/home/.cargo/registry
/github/home/.cargo/git
target
key: ${{ runner.os }}-${{ matrix.target }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-${{ matrix.target }}-cargo-
key: build-${{ matrix.zig_target }}-${{ hashFiles('**/Cargo.lock') }}
- name: Install cross
run: cargo install cross --git https://github.com/cross-rs/cross
- run: |
python3 -m pip install --user --break-system-packages cargo-zigbuild
echo "/github/home/.local/bin" >> "$GITHUB_PATH"
- name: Build Release
env:
RUSTFLAGS: ${{ contains(matrix.target, 'musl') && '-C target-feature=+crt-static' || '' }}
run: cross build --release --target ${{ matrix.target }}
- run: cargo fetch --locked
- name: Package binary
run: |
cd target/${{ matrix.target }}/release
tar -czvf ${{ matrix.asset_name }}.tar.gz ${{ matrix.artifact_name }}
sha256sum ${{ matrix.asset_name }}.tar.gz > ${{ matrix.asset_name }}.sha256
- run: |
cargo zigbuild --release --locked --target "${{ matrix.zig_target }}"
- run: |
BIN="target/${{ matrix.rust_target }}/release/${BINARY_NAME}"
llvm-strip "$BIN" || true
- run: |
BIN="target/${{ matrix.rust_target }}/release/${BINARY_NAME}"
OUT="$RUNNER_TEMP/${{ matrix.asset_name }}"
mkdir -p "$OUT"
install -m755 "$BIN" "$OUT/${BINARY_NAME}"
tar -C "$RUNNER_TEMP" -czf "${{ matrix.asset_name }}.tar.gz" "${{ matrix.asset_name }}"
sha256sum "${{ matrix.asset_name }}.tar.gz" > "${{ matrix.asset_name }}.sha256"
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.asset_name }}
path: |
target/${{ matrix.target }}/release/${{ matrix.asset_name }}.tar.gz
target/${{ matrix.target }}/release/${{ matrix.asset_name }}.sha256
${{ matrix.asset_name }}.tar.gz
${{ matrix.asset_name }}.sha256
build-docker-image:
needs: build
docker-image:
name: Docker ${{ matrix.platform }}
needs: [prepare, build-binaries]
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
strategy:
matrix:
include:
- platform: linux/amd64
artifact: telemt-x86_64-linux-gnu
- platform: linux/arm64
artifact: telemt-aarch64-linux-gnu
steps:
- uses: actions/checkout@v4
- uses: docker/setup-qemu-action@v3
- uses: actions/download-artifact@v4
with:
name: ${{ matrix.artifact }}
path: dist
- run: |
mkdir docker-build
tar -xzf dist/*.tar.gz -C docker-build --strip-components=1
- uses: docker/setup-buildx-action@v3
- name: Login to GHCR
- name: Login
if: ${{ needs.prepare.outputs.release_enabled == 'true' }}
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract version
id: vars
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
- name: Build and push
uses: docker/build-push-action@v6
- uses: docker/build-push-action@v6
with:
context: .
push: true
tags: |
ghcr.io/${{ github.repository }}:${{ steps.vars.outputs.VERSION }}
ghcr.io/${{ github.repository }}:latest
context: ./docker-build
platforms: ${{ matrix.platform }}
push: ${{ needs.prepare.outputs.release_enabled == 'true' }}
tags: ghcr.io/${{ github.repository }}:${{ needs.prepare.outputs.version }}
cache-from: type=gha,scope=telemt-${{ matrix.platform }}
cache-to: type=gha,mode=max,scope=telemt-${{ matrix.platform }}
provenance: false
sbom: false
release:
name: Create Release
needs: build
if: ${{ needs.prepare.outputs.release_enabled == 'true' }}
needs: [prepare, build-binaries]
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/download-artifact@v4
with:
path: artifacts
path: release-artifacts
pattern: telemt-*
- name: Create Release
uses: softprops/action-gh-release@v2
- run: |
mkdir upload
find release-artifacts -type f \( -name '*.tar.gz' -o -name '*.sha256' \) -exec cp {} upload/ \;
- uses: softprops/action-gh-release@v2
with:
files: artifacts/**/*
files: upload/*
generate_release_notes: true
draft: false
prerelease: ${{ contains(github.ref, '-rc') || contains(github.ref, '-beta') || contains(github.ref, '-alpha') }}
prerelease: ${{ needs.prepare.outputs.prerelease == 'true' }}


@ -45,6 +45,18 @@ jobs:
- name: Run tests
run: cargo test --verbose
- name: Stress quota-lock suites (PR only)
if: github.event_name == 'pull_request'
env:
RUST_TEST_THREADS: 16
run: |
set -euo pipefail
for i in $(seq 1 12); do
echo "[quota-lock-stress] iteration ${i}/12"
cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16
cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16
done
# clippy: don't fail on warnings while telemt is under active development
# and the codebase still carries many warnings
- name: Run clippy

1
.gitignore vendored

@ -21,3 +21,4 @@ target
#.idea/
proxy-secret
coverage-html/


@ -5,6 +5,22 @@ Your responses are precise, minimal, and architecturally sound. You are working
---
### Context: The Telemt Project
You are working on **Telemt**, a high-performance, production-grade Telegram MTProxy implementation written in Rust. It is explicitly designed to operate in highly hostile network environments and evade advanced network censorship.
**Adversarial Threat Model:**
The proxy operates under constant surveillance by DPI (Deep Packet Inspection) systems and active scanners (state firewalls, mobile operator fraud controls). These entities actively probe IPs, analyze protocol handshakes, and look for known proxy signatures to block or throttle traffic.
**Core Architectural Pillars:**
1. **TLS-Fronting (TLS-F) & TCP-Splitting (TCP-S):** To the outside world, Telemt looks like a standard TLS server. If a client presents a valid MTProxy key, the connection is handled internally. If a censor's scanner, web browser, or unauthorized crawler connects, Telemt seamlessly splices the TCP connection (L4) to a real, legitimate HTTPS fallback server (e.g., Nginx) without modifying the `ClientHello` or terminating the TLS handshake.
2. **Middle-End (ME) Orchestration:** A highly concurrent, generation-based pool managing upstream connections to Telegram Datacenters (DCs). It utilizes an **Adaptive Floor** (dynamically scaling writer connections based on traffic), **Hardswaps** (zero-downtime pool reconfiguration), and **STUN/NAT** reflection mechanisms.
3. **Strict KDF Routing:** Cryptographic Key Derivation Functions (KDF) in this protocol strictly rely on the exact pairing of Source IP/Port and Destination IP/Port. Deviations or missing port logic will silently break the MTProto handshake.
4. **Data Plane vs. Control Plane Isolation:** The Data Plane (readers, writers, payload relay, TCP splicing) must remain strictly non-blocking, zero-allocation in hot paths, and highly resilient to network backpressure. The Control Plane (API, metrics, pool generation swaps, config reloads) orchestrates the state asynchronously without stalling the Data Plane.
Any modification you make must preserve Telemt's invisibility to censors, its strict memory-safety invariants, and its hot-path throughput.
### 0. Priority Resolution — Scope Control
This section resolves conflicts between code quality enforcement and scope limitation.
@ -374,6 +390,12 @@ you MUST explain why existing invariants remain valid.
- Do not modify existing tests unless the task explicitly requires it.
- Do not weaken assertions.
- Preserve determinism in testable components.
- Bug-first forces the discipline of proving you understand a bug before you fix it. Tests written after a fix almost always pass trivially and catch nothing new.
- Invariants over scenarios is the core shift. The route_mode table alone would have caught both BUG-1 and BUG-2 before they were written — "snapshot equals watch state after any transition burst" is a two-line property test that fails immediately on the current diverged-atomics code.
- Differential/model catches logic drift over time.
- Scheduler pressure is specifically aimed at the concurrent state bugs that keep reappearing. A single-threaded happy-path test of set_mode will never find subtle bugs; 10,000 concurrent calls will find them on the first run.
- Mutation gate answers your original complaint directly. It measures test power. If you can remove a bounds check and nothing breaks, the suite isn't covering that branch yet — it just says so explicitly.
- Dead parameter is a code-smell rule: a parameter no code path reads signals drifted or incomplete design.
### 15. Security Constraints

785
Cargo.lock generated

File diff suppressed because it is too large


@ -1,6 +1,6 @@
[package]
name = "telemt"
version = "3.3.28"
version = "3.3.29"
edition = "2024"
[dependencies]
@ -22,28 +22,37 @@ hmac = "0.12"
crc32fast = "1.4"
crc32c = "0.6"
zeroize = { version = "1.8", features = ["derive"] }
subtle = "2.6"
static_assertions = "1.1"
# Network
socket2 = { version = "0.5", features = ["all"] }
nix = { version = "0.28", default-features = false, features = ["net"] }
socket2 = { version = "0.6", features = ["all"] }
nix = { version = "0.31", default-features = false, features = [
"net",
"user",
"process",
"fs",
"signal",
] }
shadowsocks = { version = "1.24", features = ["aead-cipher-2022"] }
# Serialization
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
toml = "0.8"
x509-parser = "0.15"
toml = "1.0"
x509-parser = "0.18"
# Utils
bytes = "1.9"
thiserror = "2.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-appender = "0.2"
parking_lot = "0.12"
dashmap = "5.5"
dashmap = "6.1"
arc-swap = "1.7"
lru = "0.16"
rand = "0.9"
rand = "0.10"
chrono = { version = "0.4", features = ["serde"] }
hex = "0.4"
base64 = "0.22"
@ -52,23 +61,30 @@ regex = "1.11"
crossbeam-queue = "0.3"
num-bigint = "0.4"
num-traits = "0.2"
x25519-dalek = "2"
anyhow = "1.0"
# HTTP
reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false }
notify = { version = "6", features = ["macos_fsevent"] }
ipnetwork = "0.20"
reqwest = { version = "0.13", features = ["rustls"], default-features = false }
notify = "8.2"
ipnetwork = { version = "0.21", features = ["serde"] }
hyper = { version = "1", features = ["server", "http1"] }
hyper-util = { version = "0.1", features = ["tokio", "server-auto"] }
http-body-util = "0.1"
httpdate = "1.0"
tokio-rustls = { version = "0.26", default-features = false, features = ["tls12"] }
rustls = { version = "0.23", default-features = false, features = ["std", "tls12", "ring"] }
webpki-roots = "0.26"
tokio-rustls = { version = "0.26", default-features = false, features = [
"tls12",
] }
rustls = { version = "0.23", default-features = false, features = [
"std",
"tls12",
"ring",
] }
webpki-roots = "1.0"
[dev-dependencies]
tokio-test = "0.4"
criterion = "0.5"
criterion = "0.8"
proptest = "1.4"
futures = "0.3"


@ -1,5 +1,5 @@
// Cryptobench
use criterion::{black_box, criterion_group, Criterion};
use criterion::{Criterion, black_box, criterion_group};
fn bench_aes_ctr(c: &mut Criterion) {
c.bench_function("aes_ctr_encrypt_64kb", |b| {


@ -260,6 +260,129 @@ This document lists all configuration keys accepted by `config.toml`.
| tls_full_cert_ttl_secs | `u64` | `90` | — | TTL for sending full cert payload per (domain, client IP) tuple. |
| alpn_enforce | `bool` | `true` | — | Enforces ALPN echo behavior based on client preference. |
| mask_proxy_protocol | `u8` | `0` | — | PROXY protocol mode for mask backend (`0` disabled, `1` v1, `2` v2). |
| mask_shape_hardening | `bool` | `true` | — | Enables client->mask shape-channel hardening by applying controlled tail padding to bucket boundaries on mask relay shutdown. |
| mask_shape_hardening_aggressive_mode | `bool` | `false` | Requires `mask_shape_hardening = true`. | Opt-in aggressive shaping profile: allows shaping on backend-silent non-EOF paths and switches above-cap blur to strictly positive random tail. |
| mask_shape_bucket_floor_bytes | `usize` | `512` | Must be `> 0`; should be `<= mask_shape_bucket_cap_bytes`. | Minimum bucket size used by shape-channel hardening. |
| mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. |
| mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. |
| mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. |
| mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. |
| mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. |
| mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. |
### Shape-channel hardening notes (`[censorship]`)
These parameters are designed to reduce one specific fingerprint source during masking: the exact number of bytes sent from proxy to `mask_host` for invalid or probing traffic.
Without hardening, a censor can often correlate probe input length with backend-observed length very precisely (for example: `5 + body_sent` on early TLS reject paths). That creates a length-based classifier signal.
When `mask_shape_hardening = true`, Telemt pads the **client->mask** stream tail to a bucket boundary at relay shutdown:
- Total bytes sent to mask are first measured.
- A bucket is selected using powers of two starting from `mask_shape_bucket_floor_bytes`.
- Padding is added only if total bytes are below `mask_shape_bucket_cap_bytes`.
- If bytes already exceed cap, no extra padding is added.
This means multiple nearby probe sizes collapse into the same backend-observed size class, making active classification harder.
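The bucket selection described above can be sketched as a small pure function. This is a hypothetical illustration under the stated rules (powers of two from the floor, no padding at or above cap); the actual helper names in src/proxy are assumptions:

```rust
/// Hypothetical sketch of the bucket ladder described above.
/// Returns the padded target size, or None when no bucket padding applies.
fn shape_bucket(total: usize, floor: usize, cap: usize) -> Option<usize> {
    if total >= cap {
        // Already at or above cap: bucket logic adds no further padding.
        return None;
    }
    // Walk powers of two starting from the configured floor.
    let mut bucket = floor.max(1);
    while bucket < total {
        bucket = bucket.saturating_mul(2);
    }
    Some(bucket.min(cap))
}
```

With the recommended profile (floor = 512, cap = 4096), a 37-byte probe is padded to 512 bytes and an 1800-byte probe to 2048 under this simplified ladder, while anything at or above 4096 passes through unpadded.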
What each parameter changes in practice:
- `mask_shape_hardening`
Enables or disables this entire length-shaping stage on the fallback path.
When `false`, backend-observed length stays close to the real forwarded probe length.
When `true`, clean relay shutdown can append random padding bytes to move the total into a bucket.
- `mask_shape_bucket_floor_bytes`
Sets the first bucket boundary used for small probes.
Example: with floor `512`, a malformed probe that would otherwise forward `37` bytes can be expanded to `512` bytes on clean EOF.
Larger floor values hide very small probes better, but increase egress cost.
- `mask_shape_bucket_cap_bytes`
Sets the largest bucket Telemt will pad up to with bucket logic.
Example: with cap `4096`, a forwarded total of `1800` bytes may be padded to `2048` or `4096` depending on the bucket ladder, but a total already above `4096` will not be bucket-padded further.
Larger cap values increase the range over which size classes are collapsed, but also increase worst-case overhead.
- Clean EOF matters in conservative mode
In the default profile, shape padding is intentionally conservative: it is applied on clean relay shutdown, not on every timeout/drip path.
This avoids introducing new timeout-tail artifacts that some backends or tests interpret as a separate fingerprint.
Practical trade-offs:
- Better anti-fingerprinting on size/shape channel.
- Slightly higher egress overhead for small probes due to padding.
- Behavior is intentionally conservative and enabled by default.
Recommended starting profile:
- `mask_shape_hardening = true` (default)
- `mask_shape_bucket_floor_bytes = 512`
- `mask_shape_bucket_cap_bytes = 4096`
### Aggressive mode notes (`[censorship]`)
`mask_shape_hardening_aggressive_mode` is an opt-in profile for higher anti-classifier pressure.
- Default is `false` to preserve conservative timeout/no-tail behavior.
- Requires `mask_shape_hardening = true`.
- When enabled, backend-silent non-EOF masking paths may be shaped.
- When enabled together with above-cap blur, the random extra tail uses `[1, max]` instead of `[0, max]`.
What changes when aggressive mode is enabled:
- Backend-silent timeout paths can be shaped
In default mode, a client that keeps the socket half-open and times out will usually not receive shape padding on that path.
In aggressive mode, Telemt may still shape that backend-silent session if no backend bytes were returned.
This is specifically aimed at active probes that try to avoid EOF in order to preserve an exact backend-observed length.
- Above-cap blur always adds at least one byte
In default mode, above-cap blur may choose `0`, so some oversized probes still land on their exact base forwarded length.
In aggressive mode, that exact-base sample is removed by construction.
- Tradeoff
Aggressive mode improves resistance to active length classifiers, but it is more opinionated and less conservative.
If your deployment prioritizes strict compatibility with timeout/no-tail semantics, leave it disabled.
If your threat model includes repeated active probing by a censor, this mode is the stronger profile.
Use this mode only when your threat model prioritizes classifier resistance over strict compatibility with conservative masking semantics.
### Above-cap blur notes (`[censorship]`)
`mask_shape_above_cap_blur` adds a second-stage blur for very large probes that are already above `mask_shape_bucket_cap_bytes`.
- A random tail in `[0, mask_shape_above_cap_blur_max_bytes]` is appended in default mode.
- In aggressive mode, the random tail becomes strictly positive: `[1, mask_shape_above_cap_blur_max_bytes]`.
- This reduces exact-size leakage above cap at bounded overhead.
- Keep `mask_shape_above_cap_blur_max_bytes` conservative to avoid unnecessary egress growth.
Operational meaning:
- Without above-cap blur
A probe that forwards `5005` bytes will still look like `5005` bytes to the backend if it is already above cap.
- With above-cap blur enabled
That same probe may look like any value in a bounded window above its base length.
Example with `mask_shape_above_cap_blur_max_bytes = 64`:
backend-observed size becomes `5005..5069` in default mode, or `5006..5069` in aggressive mode.
- Choosing `mask_shape_above_cap_blur_max_bytes`
Small values reduce cost but preserve more separability between far-apart oversized classes.
Larger values blur oversized classes more aggressively, but add more egress overhead and more output variance.
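The backend-observed window for an oversized probe can be computed as a concrete sketch (helper names are illustrative, not the actual API):

```rust
/// Bounds of the randomized above-cap tail: [0, max] in default mode,
/// [1, max] when aggressive mode removes the exact-base sample.
fn blur_bounds(aggressive: bool, max_bytes: usize) -> (usize, usize) {
    let lo = if aggressive { 1 } else { 0 };
    (lo, max_bytes)
}

/// Window of total sizes the mask backend may observe for a given base length.
fn observed_window(base: usize, aggressive: bool, max_bytes: usize) -> (usize, usize) {
    let (lo, hi) = blur_bounds(aggressive, max_bytes);
    (base + lo, base + hi)
}
```

This reproduces the example above: with `mask_shape_above_cap_blur_max_bytes = 64`, a 5005-byte base lands anywhere in `(5005, 5069)` in default mode and `(5006, 5069)` in aggressive mode.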
### Timing normalization envelope notes (`[censorship]`)
`mask_timing_normalization_enabled` smooths timing differences between masking outcomes by applying a target duration envelope.
- A random target is selected in `[mask_timing_normalization_floor_ms, mask_timing_normalization_ceiling_ms]`.
- Fast paths are delayed up to the selected target.
- Slow paths are not forced to finish by the ceiling (the envelope is best-effort shaping, not truncation).
Recommended starting profile for timing shaping:
- `mask_timing_normalization_enabled = true`
- `mask_timing_normalization_floor_ms = 180`
- `mask_timing_normalization_ceiling_ms = 320`
If your backend or network is very bandwidth-constrained, reduce cap first. If probes are still too distinguishable in your environment, increase floor gradually.
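The best-effort envelope semantics can be illustrated with a small helper (hypothetical; the real implementation lives in the masking path, and the target is sampled per outcome from `[floor_ms, ceiling_ms]`):

```rust
use std::time::Duration;

/// Extra delay to apply once a masking outcome completes, given a target
/// sampled from [floor_ms, ceiling_ms]. Fast paths are held back until the
/// target; slow paths get zero extra delay (shaping, not truncation).
fn normalization_delay(elapsed_ms: u64, target_ms: u64) -> Duration {
    Duration::from_millis(target_ms.saturating_sub(elapsed_ms))
}
```

With the recommended 180/320 ms profile, an outcome that finished in 50 ms with a sampled target of 180 ms is delayed another 130 ms, while a 400 ms outcome is left untouched.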
## [access]

BIN
docs/model/FakeTLS.png Normal file (binary, 650 KiB)

BIN
docs/model/architecture.png Normal file (binary, 838 KiB)


@ -24,10 +24,7 @@ pub(super) fn success_response<T: Serialize>(
.unwrap()
}
pub(super) fn error_response(
request_id: u64,
failure: ApiFailure,
) -> hyper::Response<Full<Bytes>> {
pub(super) fn error_response(request_id: u64, failure: ApiFailure) -> hyper::Response<Full<Bytes>> {
let payload = ErrorResponse {
ok: false,
error: ErrorBody {


@ -1,3 +1,5 @@
#![allow(clippy::too_many_arguments)]
use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
@ -19,8 +21,8 @@ use crate::ip_tracker::UserIpTracker;
use crate::proxy::route_mode::RouteRuntimeController;
use crate::startup::StartupTracker;
use crate::stats::Stats;
use crate::transport::middle_proxy::MePool;
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::MePool;
mod config_store;
mod events;
@ -36,8 +38,8 @@ mod runtime_zero;
mod users;
use config_store::{current_revision, parse_if_match};
use http_utils::{error_response, read_json, read_optional_json, success_response};
use events::ApiEventStore;
use http_utils::{error_response, read_json, read_optional_json, success_response};
use model::{
ApiFailure, CreateUserRequest, HealthData, PatchUserRequest, RotateSecretRequest, SummaryData,
};
@ -55,11 +57,11 @@ use runtime_stats::{
MinimalCacheEntry, build_dcs_data, build_me_writers_data, build_minimal_all_data,
build_upstreams_data, build_zero_all_data,
};
use runtime_watch::spawn_runtime_watchers;
use runtime_zero::{
build_limits_effective_data, build_runtime_gates_data, build_security_posture_data,
build_system_info_data,
};
use runtime_watch::spawn_runtime_watchers;
use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config};
pub(super) struct ApiRuntimeState {
@ -208,15 +210,15 @@ async fn handle(
));
}
if !api_cfg.whitelist.is_empty()
&& !api_cfg
.whitelist
.iter()
.any(|net| net.contains(peer.ip()))
if !api_cfg.whitelist.is_empty() && !api_cfg.whitelist.iter().any(|net| net.contains(peer.ip()))
{
return Ok(error_response(
request_id,
ApiFailure::new(StatusCode::FORBIDDEN, "forbidden", "Source IP is not allowed"),
ApiFailure::new(
StatusCode::FORBIDDEN,
"forbidden",
"Source IP is not allowed",
),
));
}
@ -347,7 +349,8 @@ async fn handle(
}
("GET", "/v1/runtime/connections/summary") => {
let revision = current_revision(&shared.config_path).await?;
let data = build_runtime_connections_summary_data(shared.as_ref(), cfg.as_ref()).await;
let data =
build_runtime_connections_summary_data(shared.as_ref(), cfg.as_ref()).await;
Ok(success_response(StatusCode::OK, data, revision))
}
("GET", "/v1/runtime/events/recent") => {
@ -389,13 +392,16 @@ async fn handle(
let (data, revision) = match result {
Ok(ok) => ok,
Err(error) => {
shared.runtime_events.record("api.user.create.failed", error.code);
shared
.runtime_events
.record("api.user.create.failed", error.code);
return Err(error);
}
};
shared
.runtime_events
.record("api.user.create.ok", format!("username={}", data.user.username));
shared.runtime_events.record(
"api.user.create.ok",
format!("username={}", data.user.username),
);
Ok(success_response(StatusCode::CREATED, data, revision))
}
_ => {
@ -414,7 +420,8 @@ async fn handle(
detected_ip_v6,
)
.await;
if let Some(user_info) = users.into_iter().find(|entry| entry.username == user)
if let Some(user_info) =
users.into_iter().find(|entry| entry.username == user)
{
return Ok(success_response(StatusCode::OK, user_info, revision));
}
@ -435,7 +442,8 @@ async fn handle(
));
}
let expected_revision = parse_if_match(req.headers());
let body = read_json::<PatchUserRequest>(req.into_body(), body_limit).await?;
let body =
read_json::<PatchUserRequest>(req.into_body(), body_limit).await?;
let result = patch_user(user, body, expected_revision, &shared).await;
let (data, revision) = match result {
Ok(ok) => ok,
@ -475,10 +483,9 @@ async fn handle(
return Err(error);
}
};
shared.runtime_events.record(
"api.user.delete.ok",
format!("username={}", deleted_user),
);
shared
.runtime_events
.record("api.user.delete.ok", format!("username={}", deleted_user));
return Ok(success_response(StatusCode::OK, deleted_user, revision));
}
if method == Method::POST


@ -1,10 +1,12 @@
use std::net::IpAddr;
use std::sync::OnceLock;
use chrono::{DateTime, Utc};
use hyper::StatusCode;
use rand::Rng;
use serde::{Deserialize, Serialize};
use crate::crypto::SecureRandom;
const MAX_USERNAME_LEN: usize = 64;
#[derive(Debug)]
@ -196,8 +198,6 @@ pub(super) struct ZeroPoolData {
pub(super) pool_swap_total: u64,
pub(super) pool_drain_active: u64,
pub(super) pool_force_close_total: u64,
pub(super) pool_drain_soft_evict_total: u64,
pub(super) pool_drain_soft_evict_writer_total: u64,
pub(super) pool_stale_pick_total: u64,
pub(super) writer_removed_total: u64,
pub(super) writer_removed_unexpected_total: u64,
@ -206,16 +206,6 @@ pub(super) struct ZeroPoolData {
pub(super) refill_failed_total: u64,
pub(super) writer_restored_same_endpoint_total: u64,
pub(super) writer_restored_fallback_total: u64,
pub(super) teardown_attempt_total_normal: u64,
pub(super) teardown_attempt_total_hard_detach: u64,
pub(super) teardown_success_total_normal: u64,
pub(super) teardown_success_total_hard_detach: u64,
pub(super) teardown_timeout_total: u64,
pub(super) teardown_escalation_total: u64,
pub(super) teardown_noop_total: u64,
pub(super) teardown_cleanup_side_effect_failures_total: u64,
pub(super) teardown_duration_count_total: u64,
pub(super) teardown_duration_sum_seconds_total: f64,
}
#[derive(Serialize, Clone)]
@@ -248,7 +238,6 @@ pub(super) struct MeWritersSummary {
pub(super) available_pct: f64,
pub(super) required_writers: usize,
pub(super) alive_writers: usize,
pub(super) coverage_ratio: f64,
pub(super) coverage_pct: f64,
pub(super) fresh_alive_writers: usize,
pub(super) fresh_coverage_pct: f64,
@@ -297,7 +286,6 @@ pub(super) struct DcStatus {
pub(super) floor_max: usize,
pub(super) floor_capped: bool,
pub(super) alive_writers: usize,
pub(super) coverage_ratio: f64,
pub(super) coverage_pct: f64,
pub(super) fresh_alive_writers: usize,
pub(super) fresh_coverage_pct: f64,
@@ -375,12 +363,6 @@ pub(super) struct MinimalMeRuntimeData {
pub(super) me_reconnect_backoff_cap_ms: u64,
pub(super) me_reconnect_fast_retry_count: u32,
pub(super) me_pool_drain_ttl_secs: u64,
pub(super) me_instadrain: bool,
pub(super) me_pool_drain_soft_evict_enabled: bool,
pub(super) me_pool_drain_soft_evict_grace_secs: u64,
pub(super) me_pool_drain_soft_evict_per_writer: u8,
pub(super) me_pool_drain_soft_evict_budget_per_core: u16,
pub(super) me_pool_drain_soft_evict_cooldown_ms: u64,
pub(super) me_pool_force_close_secs: u64,
pub(super) me_pool_min_fresh_ratio: f32,
pub(super) me_bind_stale_mode: &'static str,
@@ -502,7 +484,9 @@ pub(super) fn is_valid_username(user: &str) -> bool {
}
pub(super) fn random_user_secret() -> String {
static API_SECRET_RNG: OnceLock<SecureRandom> = OnceLock::new();
let rng = API_SECRET_RNG.get_or_init(SecureRandom::new);
let mut bytes = [0u8; 16];
rand::rng().fill(&mut bytes);
rng.fill(&mut bytes);
hex::encode(bytes)
}
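The new `random_user_secret` above swaps a per-call `rand::rng()` for a process-wide RNG initialized once behind a `OnceLock`. A minimal, standard-library-only sketch of that lazy-singleton pattern — the `SecureRandom` here is a deterministic stand-in for illustration, not the crate's real CSPRNG, and `hex_lower` stands in for `hex::encode`:

```rust
use std::sync::OnceLock;

// Stand-in for the crate's SecureRandom; real code wraps an OS CSPRNG.
struct SecureRandom(u64);

impl SecureRandom {
    fn new() -> Self {
        SecureRandom(0x9E37_79B9_7F4A_7C15)
    }
    // Deterministic filler for illustration only -- NOT cryptographic.
    fn fill(&self, buf: &mut [u8]) {
        let mut x = self.0;
        for b in buf.iter_mut() {
            x = x.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            *b = (x >> 56) as u8;
        }
    }
}

// Stand-in for hex::encode.
fn hex_lower(bytes: &[u8]) -> String {
    bytes.iter().map(|b| format!("{:02x}", b)).collect()
}

fn random_user_secret() -> String {
    // The OnceLock constructs the RNG exactly once, on first use,
    // and every later call reuses the same instance.
    static API_SECRET_RNG: OnceLock<SecureRandom> = OnceLock::new();
    let rng = API_SECRET_RNG.get_or_init(SecureRandom::new);
    let mut bytes = [0u8; 16];
    rng.fill(&mut bytes);
    hex_lower(&bytes)
}

fn main() {
    let secret = random_user_secret();
    assert_eq!(secret.len(), 32); // 16 bytes -> 32 hex chars
    println!("{}", secret);
}
```

The `OnceLock` keeps initialization cost off every call without needing `lazy_static` or unsafe statics.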

View File

@@ -167,11 +167,7 @@ async fn current_me_pool_stage_progress(shared: &ApiShared) -> Option<f64> {
let pool = shared.me_pool.read().await.clone()?;
let status = pool.api_status_snapshot().await;
let configured_dc_groups = status.configured_dc_groups;
let covered_dc_groups = status
.dcs
.iter()
.filter(|dc| dc.alive_writers > 0)
.count();
let covered_dc_groups = status.dcs.iter().filter(|dc| dc.alive_writers > 0).count();
let dc_coverage = ratio_01(covered_dc_groups, configured_dc_groups);
let writer_coverage = ratio_01(status.alive_writers, status.required_writers);
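`ratio_01` itself is not shown in this hunk; from its two uses above it plausibly divides the counts and clamps the result into `[0, 1]`, guarding the zero-denominator case. A hedged sketch of that assumed behavior:

```rust
// Assumed behavior of ratio_01 (not shown in the diff): numerator over
// denominator, clamped to [0.0, 1.0], with a zero denominator read as
// zero coverage rather than a division error.
fn ratio_01(num: usize, den: usize) -> f64 {
    if den == 0 {
        return 0.0;
    }
    (num as f64 / den as f64).clamp(0.0, 1.0)
}

fn main() {
    assert_eq!(ratio_01(3, 4), 0.75);
    assert_eq!(ratio_01(9, 4), 1.0); // over-coverage clamps to 1
    assert_eq!(ratio_01(1, 0), 0.0); // zero-denominator guard
}
```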

View File

@@ -4,9 +4,6 @@ use std::time::{SystemTime, UNIX_EPOCH};
use serde::Serialize;
use crate::config::ProxyConfig;
use crate::stats::{
MeWriterCleanupSideEffectStep, MeWriterTeardownMode, MeWriterTeardownReason, Stats,
};
use super::ApiShared;
@@ -101,50 +98,6 @@ pub(super) struct RuntimeMeQualityCountersData {
pub(super) reconnect_success_total: u64,
}
#[derive(Serialize)]
pub(super) struct RuntimeMeQualityTeardownAttemptData {
pub(super) reason: &'static str,
pub(super) mode: &'static str,
pub(super) total: u64,
}
#[derive(Serialize)]
pub(super) struct RuntimeMeQualityTeardownSuccessData {
pub(super) mode: &'static str,
pub(super) total: u64,
}
#[derive(Serialize)]
pub(super) struct RuntimeMeQualityTeardownSideEffectData {
pub(super) step: &'static str,
pub(super) total: u64,
}
#[derive(Serialize)]
pub(super) struct RuntimeMeQualityTeardownDurationBucketData {
pub(super) le_seconds: &'static str,
pub(super) total: u64,
}
#[derive(Serialize)]
pub(super) struct RuntimeMeQualityTeardownDurationData {
pub(super) mode: &'static str,
pub(super) count: u64,
pub(super) sum_seconds: f64,
pub(super) buckets: Vec<RuntimeMeQualityTeardownDurationBucketData>,
}
#[derive(Serialize)]
pub(super) struct RuntimeMeQualityTeardownData {
pub(super) attempts: Vec<RuntimeMeQualityTeardownAttemptData>,
pub(super) success: Vec<RuntimeMeQualityTeardownSuccessData>,
pub(super) timeout_total: u64,
pub(super) escalation_total: u64,
pub(super) noop_total: u64,
pub(super) cleanup_side_effect_failures: Vec<RuntimeMeQualityTeardownSideEffectData>,
pub(super) duration: Vec<RuntimeMeQualityTeardownDurationData>,
}
#[derive(Serialize)]
pub(super) struct RuntimeMeQualityRouteDropData {
pub(super) no_conn_total: u64,
@@ -179,14 +132,12 @@ pub(super) struct RuntimeMeQualityDcRttData {
pub(super) rtt_ema_ms: Option<f64>,
pub(super) alive_writers: usize,
pub(super) required_writers: usize,
pub(super) coverage_ratio: f64,
pub(super) coverage_pct: f64,
}
#[derive(Serialize)]
pub(super) struct RuntimeMeQualityPayload {
pub(super) counters: RuntimeMeQualityCountersData,
pub(super) teardown: RuntimeMeQualityTeardownData,
pub(super) route_drops: RuntimeMeQualityRouteDropData,
pub(super) family_states: Vec<RuntimeMeQualityFamilyStateData>,
pub(super) drain_gate: RuntimeMeQualityDrainGateData,
@@ -457,7 +408,6 @@ pub(super) async fn build_runtime_me_quality_data(shared: &ApiShared) -> Runtime
reconnect_attempt_total: shared.stats.get_me_reconnect_attempts(),
reconnect_success_total: shared.stats.get_me_reconnect_success(),
},
teardown: build_runtime_me_teardown_data(shared),
route_drops: RuntimeMeQualityRouteDropData {
no_conn_total: shared.stats.get_me_route_drop_no_conn(),
channel_closed_total: shared.stats.get_me_route_drop_channel_closed(),
@@ -480,7 +430,6 @@ pub(super) async fn build_runtime_me_quality_data(shared: &ApiShared) -> Runtime
rtt_ema_ms: dc.rtt_ms,
alive_writers: dc.alive_writers,
required_writers: dc.required_writers,
coverage_ratio: dc.coverage_ratio,
coverage_pct: dc.coverage_pct,
})
.collect(),
@@ -488,81 +437,6 @@ pub(super) async fn build_runtime_me_quality_data(shared: &ApiShared) -> Runtime
}
}
fn build_runtime_me_teardown_data(shared: &ApiShared) -> RuntimeMeQualityTeardownData {
let attempts = MeWriterTeardownReason::ALL
.iter()
.copied()
.flat_map(|reason| {
MeWriterTeardownMode::ALL
.iter()
.copied()
.map(move |mode| RuntimeMeQualityTeardownAttemptData {
reason: reason.as_str(),
mode: mode.as_str(),
total: shared.stats.get_me_writer_teardown_attempt_total(reason, mode),
})
})
.collect();
let success = MeWriterTeardownMode::ALL
.iter()
.copied()
.map(|mode| RuntimeMeQualityTeardownSuccessData {
mode: mode.as_str(),
total: shared.stats.get_me_writer_teardown_success_total(mode),
})
.collect();
let cleanup_side_effect_failures = MeWriterCleanupSideEffectStep::ALL
.iter()
.copied()
.map(|step| RuntimeMeQualityTeardownSideEffectData {
step: step.as_str(),
total: shared
.stats
.get_me_writer_cleanup_side_effect_failures_total(step),
})
.collect();
let duration = MeWriterTeardownMode::ALL
.iter()
.copied()
.map(|mode| {
let count = shared.stats.get_me_writer_teardown_duration_count(mode);
let mut buckets: Vec<RuntimeMeQualityTeardownDurationBucketData> = Stats::me_writer_teardown_duration_bucket_labels()
.iter()
.enumerate()
.map(|(bucket_idx, label)| RuntimeMeQualityTeardownDurationBucketData {
le_seconds: label,
total: shared
.stats
.get_me_writer_teardown_duration_bucket_total(mode, bucket_idx),
})
.collect();
buckets.push(RuntimeMeQualityTeardownDurationBucketData {
le_seconds: "+Inf",
total: count,
});
RuntimeMeQualityTeardownDurationData {
mode: mode.as_str(),
count,
sum_seconds: shared.stats.get_me_writer_teardown_duration_sum_seconds(mode),
buckets,
}
})
.collect();
RuntimeMeQualityTeardownData {
attempts,
success,
timeout_total: shared.stats.get_me_writer_teardown_timeout_total(),
escalation_total: shared.stats.get_me_writer_teardown_escalation_total(),
noop_total: shared.stats.get_me_writer_teardown_noop_total(),
cleanup_side_effect_failures,
duration,
}
}
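The removed duration builder above follows the Prometheus histogram convention: a list of cumulative buckets by upper bound, closed by a `+Inf` bucket whose total equals the overall observation count. A reduced sketch of that convention (names are illustrative, not the crate's API):

```rust
struct Bucket {
    le_seconds: &'static str,
    total: u64,
}

// Pairs each upper-bound label with its cumulative total, then appends the
// implicit +Inf bucket, which by definition holds the full observation count.
fn with_inf_bucket(labels: &[&'static str], totals: &[u64], count: u64) -> Vec<Bucket> {
    let mut buckets: Vec<Bucket> = labels
        .iter()
        .zip(totals)
        .map(|(&label, &total)| Bucket { le_seconds: label, total })
        .collect();
    buckets.push(Bucket { le_seconds: "+Inf", total: count });
    buckets
}

fn main() {
    let b = with_inf_bucket(&["0.1", "0.5", "1"], &[3, 7, 9], 10);
    assert_eq!(b.last().unwrap().le_seconds, "+Inf");
    assert_eq!(b.last().unwrap().total, 10); // +Inf bucket == total count
}
```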
pub(super) async fn build_runtime_upstream_quality_data(
shared: &ApiShared,
) -> RuntimeUpstreamQualityData {

View File

@@ -1,9 +1,9 @@
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use crate::config::ApiConfig;
use crate::stats::{MeWriterTeardownMode, Stats};
use crate::transport::upstream::IpPreference;
use crate::stats::Stats;
use crate::transport::UpstreamRouteKind;
use crate::transport::upstream::IpPreference;
use super::ApiShared;
use super::model::{
@@ -96,8 +96,6 @@ pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer
pool_swap_total: stats.get_pool_swap_total(),
pool_drain_active: stats.get_pool_drain_active(),
pool_force_close_total: stats.get_pool_force_close_total(),
pool_drain_soft_evict_total: stats.get_pool_drain_soft_evict_total(),
pool_drain_soft_evict_writer_total: stats.get_pool_drain_soft_evict_writer_total(),
pool_stale_pick_total: stats.get_pool_stale_pick_total(),
writer_removed_total: stats.get_me_writer_removed_total(),
writer_removed_unexpected_total: stats.get_me_writer_removed_unexpected_total(),
@@ -106,29 +104,6 @@ pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer
refill_failed_total: stats.get_me_refill_failed_total(),
writer_restored_same_endpoint_total: stats.get_me_writer_restored_same_endpoint_total(),
writer_restored_fallback_total: stats.get_me_writer_restored_fallback_total(),
teardown_attempt_total_normal: stats
.get_me_writer_teardown_attempt_total_by_mode(MeWriterTeardownMode::Normal),
teardown_attempt_total_hard_detach: stats
.get_me_writer_teardown_attempt_total_by_mode(MeWriterTeardownMode::HardDetach),
teardown_success_total_normal: stats
.get_me_writer_teardown_success_total(MeWriterTeardownMode::Normal),
teardown_success_total_hard_detach: stats
.get_me_writer_teardown_success_total(MeWriterTeardownMode::HardDetach),
teardown_timeout_total: stats.get_me_writer_teardown_timeout_total(),
teardown_escalation_total: stats.get_me_writer_teardown_escalation_total(),
teardown_noop_total: stats.get_me_writer_teardown_noop_total(),
teardown_cleanup_side_effect_failures_total: stats
.get_me_writer_cleanup_side_effect_failures_total_all(),
teardown_duration_count_total: stats
.get_me_writer_teardown_duration_count(MeWriterTeardownMode::Normal)
.saturating_add(
stats.get_me_writer_teardown_duration_count(MeWriterTeardownMode::HardDetach),
),
teardown_duration_sum_seconds_total: stats
.get_me_writer_teardown_duration_sum_seconds(MeWriterTeardownMode::Normal)
+ stats.get_me_writer_teardown_duration_sum_seconds(
MeWriterTeardownMode::HardDetach,
),
},
desync: ZeroDesyncData {
secure_padding_invalid_total: stats.get_secure_padding_invalid(),
@@ -340,7 +315,6 @@ async fn get_minimal_payload_cached(
available_pct: status.available_pct,
required_writers: status.required_writers,
alive_writers: status.alive_writers,
coverage_ratio: status.coverage_ratio,
coverage_pct: status.coverage_pct,
fresh_alive_writers: status.fresh_alive_writers,
fresh_coverage_pct: status.fresh_coverage_pct,
@@ -398,7 +372,6 @@ async fn get_minimal_payload_cached(
floor_max: entry.floor_max,
floor_capped: entry.floor_capped,
alive_writers: entry.alive_writers,
coverage_ratio: entry.coverage_ratio,
coverage_pct: entry.coverage_pct,
fresh_alive_writers: entry.fresh_alive_writers,
fresh_coverage_pct: entry.fresh_coverage_pct,
@@ -452,12 +425,6 @@ async fn get_minimal_payload_cached(
me_reconnect_backoff_cap_ms: runtime.me_reconnect_backoff_cap_ms,
me_reconnect_fast_retry_count: runtime.me_reconnect_fast_retry_count,
me_pool_drain_ttl_secs: runtime.me_pool_drain_ttl_secs,
me_instadrain: runtime.me_instadrain,
me_pool_drain_soft_evict_enabled: runtime.me_pool_drain_soft_evict_enabled,
me_pool_drain_soft_evict_grace_secs: runtime.me_pool_drain_soft_evict_grace_secs,
me_pool_drain_soft_evict_per_writer: runtime.me_pool_drain_soft_evict_per_writer,
me_pool_drain_soft_evict_budget_per_core: runtime.me_pool_drain_soft_evict_budget_per_core,
me_pool_drain_soft_evict_cooldown_ms: runtime.me_pool_drain_soft_evict_cooldown_ms,
me_pool_force_close_secs: runtime.me_pool_force_close_secs,
me_pool_min_fresh_ratio: runtime.me_pool_min_fresh_ratio,
me_bind_stale_mode: runtime.me_bind_stale_mode,
@@ -526,7 +493,6 @@ fn disabled_me_writers(now_epoch_secs: u64, reason: &'static str) -> MeWritersDa
available_pct: 0.0,
required_writers: 0,
alive_writers: 0,
coverage_ratio: 0.0,
coverage_pct: 0.0,
fresh_alive_writers: 0,
fresh_coverage_pct: 0.0,

View File

@@ -128,7 +128,8 @@ pub(super) fn build_system_info_data(
.runtime_state
.last_config_reload_epoch_secs
.load(Ordering::Relaxed);
let last_config_reload_epoch_secs = (last_reload_epoch_secs > 0).then_some(last_reload_epoch_secs);
let last_config_reload_epoch_secs =
(last_reload_epoch_secs > 0).then_some(last_reload_epoch_secs);
let git_commit = option_env!("TELEMT_GIT_COMMIT")
.or(option_env!("VERGEN_GIT_SHA"))
@@ -153,7 +154,10 @@ pub(super) fn build_system_info_data(
uptime_seconds: shared.stats.uptime_secs(),
config_path: shared.config_path.display().to_string(),
config_hash: revision.to_string(),
config_reload_count: shared.runtime_state.config_reload_count.load(Ordering::Relaxed),
config_reload_count: shared
.runtime_state
.config_reload_count
.load(Ordering::Relaxed),
last_config_reload_epoch_secs,
}
}
@@ -233,9 +237,7 @@ pub(super) fn build_limits_effective_data(cfg: &ProxyConfig) -> EffectiveLimitsD
adaptive_floor_writers_per_core_total: cfg
.general
.me_adaptive_floor_writers_per_core_total,
adaptive_floor_cpu_cores_override: cfg
.general
.me_adaptive_floor_cpu_cores_override,
adaptive_floor_cpu_cores_override: cfg.general.me_adaptive_floor_cpu_cores_override,
adaptive_floor_max_extra_writers_single_per_core: cfg
.general
.me_adaptive_floor_max_extra_writers_single_per_core,

View File

@@ -46,7 +46,9 @@ pub(super) async fn create_user(
None => random_user_secret(),
};
if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) {
if let Some(ad_tag) = body.user_ad_tag.as_ref()
&& !is_valid_ad_tag(ad_tag)
{
return Err(ApiFailure::bad_request(
"user_ad_tag must be exactly 32 hex characters",
));
@@ -65,12 +67,18 @@
));
}
cfg.access.users.insert(body.username.clone(), secret.clone());
cfg.access
.users
.insert(body.username.clone(), secret.clone());
if let Some(ad_tag) = body.user_ad_tag {
cfg.access.user_ad_tags.insert(body.username.clone(), ad_tag);
cfg.access
.user_ad_tags
.insert(body.username.clone(), ad_tag);
}
if let Some(limit) = body.max_tcp_conns {
cfg.access.user_max_tcp_conns.insert(body.username.clone(), limit);
cfg.access
.user_max_tcp_conns
.insert(body.username.clone(), limit);
}
if let Some(expiration) = expiration {
cfg.access
@@ -78,7 +86,9 @@
.insert(body.username.clone(), expiration);
}
if let Some(quota) = body.data_quota_bytes {
cfg.access.user_data_quota.insert(body.username.clone(), quota);
cfg.access
.user_data_quota
.insert(body.username.clone(), quota);
}
let updated_limit = body.max_unique_ips;
@@ -108,11 +118,15 @@ pub(super) async fn create_user(
touched_sections.push(AccessSection::UserMaxUniqueIps);
}
let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
let revision =
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
drop(_guard);
if let Some(limit) = updated_limit {
shared.ip_tracker.set_user_limit(&body.username, limit).await;
shared
.ip_tracker
.set_user_limit(&body.username, limit)
.await;
}
let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips();
@@ -140,12 +154,7 @@ pub(super) async fn create_user(
recent_unique_ips: 0,
recent_unique_ips_list: Vec::new(),
total_octets: 0,
links: build_user_links(
&cfg,
&secret,
detected_ip_v4,
detected_ip_v6,
),
links: build_user_links(&cfg, &secret, detected_ip_v4, detected_ip_v6),
});
Ok((CreateUserResponse { user, secret }, revision))
@@ -157,12 +166,16 @@
expected_revision: Option<String>,
shared: &ApiShared,
) -> Result<(UserInfo, String), ApiFailure> {
if let Some(secret) = body.secret.as_ref() && !is_valid_user_secret(secret) {
if let Some(secret) = body.secret.as_ref()
&& !is_valid_user_secret(secret)
{
return Err(ApiFailure::bad_request(
"secret must be exactly 32 hex characters",
));
}
if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) {
if let Some(ad_tag) = body.user_ad_tag.as_ref()
&& !is_valid_ad_tag(ad_tag)
{
return Err(ApiFailure::bad_request(
"user_ad_tag must be exactly 32 hex characters",
));
@@ -187,10 +200,14 @@ pub(super) async fn patch_user(
cfg.access.user_ad_tags.insert(user.to_string(), ad_tag);
}
if let Some(limit) = body.max_tcp_conns {
cfg.access.user_max_tcp_conns.insert(user.to_string(), limit);
cfg.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
}
if let Some(expiration) = expiration {
cfg.access.user_expirations.insert(user.to_string(), expiration);
cfg.access
.user_expirations
.insert(user.to_string(), expiration);
}
if let Some(quota) = body.data_quota_bytes {
cfg.access.user_data_quota.insert(user.to_string(), quota);
@@ -198,7 +215,9 @@
let mut updated_limit = None;
if let Some(limit) = body.max_unique_ips {
cfg.access.user_max_unique_ips.insert(user.to_string(), limit);
cfg.access
.user_max_unique_ips
.insert(user.to_string(), limit);
updated_limit = Some(limit);
}
@@ -263,7 +282,8 @@ pub(super) async fn rotate_secret(
AccessSection::UserDataQuota,
AccessSection::UserMaxUniqueIps,
];
let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
let revision =
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
drop(_guard);
let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips();
@@ -330,7 +350,8 @@ pub(super) async fn delete_user(
AccessSection::UserDataQuota,
AccessSection::UserMaxUniqueIps,
];
let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
let revision =
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
drop(_guard);
shared.ip_tracker.remove_user_limit(user).await;
shared.ip_tracker.clear_user_ips(user).await;
@@ -365,12 +386,7 @@ pub(super) async fn users_from_config(
.users
.get(&username)
.map(|secret| {
build_user_links(
cfg,
secret,
startup_detected_ip_v4,
startup_detected_ip_v6,
)
build_user_links(cfg, secret, startup_detected_ip_v4, startup_detected_ip_v6)
})
.unwrap_or(UserLinks {
classic: Vec::new(),
@@ -392,10 +408,8 @@ pub(super) async fn users_from_config(
.get(&username)
.copied()
.filter(|limit| *limit > 0)
.or(
(cfg.access.user_max_unique_ips_global_each > 0)
.then_some(cfg.access.user_max_unique_ips_global_each),
),
.or((cfg.access.user_max_unique_ips_global_each > 0)
.then_some(cfg.access.user_max_unique_ips_global_each)),
current_connections: stats.get_user_curr_connects(&username),
active_unique_ips: active_ip_list.len(),
active_unique_ips_list: active_ip_list,
@@ -481,12 +495,12 @@ fn resolve_link_hosts(
push_unique_host(&mut hosts, host);
continue;
}
if let Some(ip) = listener.announce_ip {
if !ip.is_unspecified() {
if let Some(ip) = listener.announce_ip
&& !ip.is_unspecified()
{
push_unique_host(&mut hosts, &ip.to_string());
continue;
}
}
if listener.ip.is_unspecified() {
let detected_ip = if listener.ip.is_ipv4() {
startup_detected_ip_v4

View File

@@ -1,11 +1,270 @@
//! CLI commands: --init (fire-and-forget setup)
//! CLI commands: --init (fire-and-forget setup), daemon options, subcommands
//!
//! Subcommands:
//! - `start [OPTIONS] [config.toml]` - Start the daemon
//! - `stop [--pid-file PATH]` - Stop a running daemon
//! - `reload [--pid-file PATH]` - Reload configuration (SIGHUP)
//! - `status [--pid-file PATH]` - Check daemon status
//! - `run [OPTIONS] [config.toml]` - Run in foreground (default behavior)
use rand::RngExt;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use rand::Rng;
#[cfg(unix)]
use crate::daemon::{self, DaemonOptions, DEFAULT_PID_FILE};
/// CLI subcommand to execute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Subcommand {
/// Run the proxy (default, or explicit `run` subcommand).
Run,
/// Start as daemon (`start` subcommand).
Start,
/// Stop a running daemon (`stop` subcommand).
Stop,
/// Reload configuration (`reload` subcommand).
Reload,
/// Check daemon status (`status` subcommand).
Status,
/// Fire-and-forget setup (`--init`).
Init,
}
/// Parsed subcommand with its options.
#[derive(Debug)]
pub struct ParsedCommand {
pub subcommand: Subcommand,
pub pid_file: PathBuf,
pub config_path: String,
#[cfg(unix)]
pub daemon_opts: DaemonOptions,
pub init_opts: Option<InitOptions>,
}
impl Default for ParsedCommand {
fn default() -> Self {
Self {
subcommand: Subcommand::Run,
#[cfg(unix)]
pid_file: PathBuf::from(DEFAULT_PID_FILE),
#[cfg(not(unix))]
pid_file: PathBuf::from("/var/run/telemt.pid"),
config_path: "config.toml".to_string(),
#[cfg(unix)]
daemon_opts: DaemonOptions::default(),
init_opts: None,
}
}
}
/// Parse CLI arguments into a command structure.
pub fn parse_command(args: &[String]) -> ParsedCommand {
let mut cmd = ParsedCommand::default();
// Check for --init first (legacy form)
if args.iter().any(|a| a == "--init") {
cmd.subcommand = Subcommand::Init;
cmd.init_opts = parse_init_args(args);
return cmd;
}
// Check for subcommand as first argument
if let Some(first) = args.first() {
match first.as_str() {
"start" => {
cmd.subcommand = Subcommand::Start;
#[cfg(unix)]
{
cmd.daemon_opts = parse_daemon_args(args);
// Force daemonize for start command
cmd.daemon_opts.daemonize = true;
}
}
"stop" => {
cmd.subcommand = Subcommand::Stop;
}
"reload" => {
cmd.subcommand = Subcommand::Reload;
}
"status" => {
cmd.subcommand = Subcommand::Status;
}
"run" => {
cmd.subcommand = Subcommand::Run;
#[cfg(unix)]
{
cmd.daemon_opts = parse_daemon_args(args);
}
}
_ => {
// No subcommand, default to Run
#[cfg(unix)]
{
cmd.daemon_opts = parse_daemon_args(args);
}
}
}
}
// Parse remaining options
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
// Skip subcommand names
"start" | "stop" | "reload" | "status" | "run" => {}
// PID file option (for stop/reload/status)
"--pid-file" => {
i += 1;
if i < args.len() {
cmd.pid_file = PathBuf::from(&args[i]);
#[cfg(unix)]
{
cmd.daemon_opts.pid_file = Some(cmd.pid_file.clone());
}
}
}
s if s.starts_with("--pid-file=") => {
cmd.pid_file = PathBuf::from(s.trim_start_matches("--pid-file="));
#[cfg(unix)]
{
cmd.daemon_opts.pid_file = Some(cmd.pid_file.clone());
}
}
// Config path (positional, non-flag argument)
s if !s.starts_with('-') => {
cmd.config_path = s.to_string();
}
_ => {}
}
i += 1;
}
cmd
}
/// Execute a subcommand that doesn't require starting the server.
/// Returns `Some(exit_code)` if the command was handled, `None` if server should start.
#[cfg(unix)]
pub fn execute_subcommand(cmd: &ParsedCommand) -> Option<i32> {
match cmd.subcommand {
Subcommand::Stop => Some(cmd_stop(&cmd.pid_file)),
Subcommand::Reload => Some(cmd_reload(&cmd.pid_file)),
Subcommand::Status => Some(cmd_status(&cmd.pid_file)),
Subcommand::Init => {
if let Some(opts) = cmd.init_opts.clone() {
match run_init(opts) {
Ok(()) => Some(0),
Err(e) => {
eprintln!("[telemt] Init failed: {}", e);
Some(1)
}
}
} else {
Some(1)
}
}
// Run and Start need the server
Subcommand::Run | Subcommand::Start => None,
}
}
#[cfg(not(unix))]
pub fn execute_subcommand(cmd: &ParsedCommand) -> Option<i32> {
match cmd.subcommand {
Subcommand::Stop | Subcommand::Reload | Subcommand::Status => {
eprintln!("[telemt] Subcommand not supported on this platform");
Some(1)
}
Subcommand::Init => {
if let Some(opts) = cmd.init_opts.clone() {
match run_init(opts) {
Ok(()) => Some(0),
Err(e) => {
eprintln!("[telemt] Init failed: {}", e);
Some(1)
}
}
} else {
Some(1)
}
}
Subcommand::Run | Subcommand::Start => None,
}
}
/// Stop command: send SIGTERM to the running daemon.
#[cfg(unix)]
fn cmd_stop(pid_file: &Path) -> i32 {
use nix::sys::signal::Signal;
println!("Stopping telemt daemon...");
match daemon::signal_pid_file(pid_file, Signal::SIGTERM) {
Ok(()) => {
println!("Stop signal sent successfully");
// Wait for process to exit (up to 10 seconds)
for _ in 0..20 {
std::thread::sleep(std::time::Duration::from_millis(500));
if let daemon::DaemonStatus::NotRunning = daemon::check_status(pid_file) {
println!("Daemon stopped");
return 0;
}
}
println!("Daemon may still be shutting down");
0
}
Err(e) => {
eprintln!("Failed to stop daemon: {}", e);
1
}
}
}
/// Reload command: send SIGHUP to trigger config reload.
#[cfg(unix)]
fn cmd_reload(pid_file: &Path) -> i32 {
use nix::sys::signal::Signal;
println!("Reloading telemt configuration...");
match daemon::signal_pid_file(pid_file, Signal::SIGHUP) {
Ok(()) => {
println!("Reload signal sent successfully");
0
}
Err(e) => {
eprintln!("Failed to reload daemon: {}", e);
1
}
}
}
/// Status command: check if daemon is running.
#[cfg(unix)]
fn cmd_status(pid_file: &Path) -> i32 {
match daemon::check_status(pid_file) {
daemon::DaemonStatus::Running(pid) => {
println!("telemt is running (pid {})", pid);
0
}
daemon::DaemonStatus::Stale(pid) => {
println!("telemt is not running (stale pid file, was pid {})", pid);
// Clean up stale PID file
let _ = std::fs::remove_file(pid_file);
1
}
daemon::DaemonStatus::NotRunning => {
println!("telemt is not running");
1
}
}
}
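`cmd_status` above leans on `daemon::check_status` to tell a live PID apart from a stale PID file. That module is not in this diff; a self-contained sketch of the same three-way check, probing `/proc/<pid>` in the Linux style (the real implementation likely sends `kill(pid, 0)` via `nix` instead — this is an assumption, not the crate's code):

```rust
use std::fs;
use std::path::Path;

enum DaemonStatus {
    Running(i32),
    Stale(i32),
    NotRunning,
}

// Reads the PID file and checks whether that process still exists.
// /proc probing is a Linux-only illustration; kill(pid, 0) is the
// portable Unix idiom.
fn check_status(pid_file: &Path) -> DaemonStatus {
    let pid = match fs::read_to_string(pid_file)
        .ok()
        .and_then(|s| s.trim().parse::<i32>().ok())
    {
        Some(p) if p > 0 => p,
        _ => return DaemonStatus::NotRunning, // no file or unparsable content
    };
    if Path::new(&format!("/proc/{}", pid)).exists() {
        DaemonStatus::Running(pid)
    } else {
        DaemonStatus::Stale(pid) // file exists but process is gone
    }
}

fn main() {
    // Our own PID is always alive, so a PID file pointing at it reads Running.
    let tmp = std::env::temp_dir().join("telemt-status-sketch.pid");
    fs::write(&tmp, format!("{}\n", std::process::id())).unwrap();
    assert!(matches!(check_status(&tmp), DaemonStatus::Running(_)));
    let _ = fs::remove_file(&tmp);
    assert!(matches!(check_status(&tmp), DaemonStatus::NotRunning));
}
```

The `Stale` branch is what lets `cmd_status` clean up a leftover PID file after a crash instead of reporting a phantom daemon.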
/// Options for the init command
#[derive(Debug, Clone)]
pub struct InitOptions {
pub port: u16,
pub domain: String,
@@ -15,6 +274,64 @@ pub struct InitOptions {
pub no_start: bool,
}
/// Parse daemon-related options from CLI args.
#[cfg(unix)]
pub fn parse_daemon_args(args: &[String]) -> DaemonOptions {
let mut opts = DaemonOptions::default();
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--daemon" | "-d" => {
opts.daemonize = true;
}
"--foreground" | "-f" => {
opts.foreground = true;
}
"--pid-file" => {
i += 1;
if i < args.len() {
opts.pid_file = Some(PathBuf::from(&args[i]));
}
}
s if s.starts_with("--pid-file=") => {
opts.pid_file = Some(PathBuf::from(s.trim_start_matches("--pid-file=")));
}
"--run-as-user" => {
i += 1;
if i < args.len() {
opts.user = Some(args[i].clone());
}
}
s if s.starts_with("--run-as-user=") => {
opts.user = Some(s.trim_start_matches("--run-as-user=").to_string());
}
"--run-as-group" => {
i += 1;
if i < args.len() {
opts.group = Some(args[i].clone());
}
}
s if s.starts_with("--run-as-group=") => {
opts.group = Some(s.trim_start_matches("--run-as-group=").to_string());
}
"--working-dir" => {
i += 1;
if i < args.len() {
opts.working_dir = Some(PathBuf::from(&args[i]));
}
}
s if s.starts_with("--working-dir=") => {
opts.working_dir = Some(PathBuf::from(s.trim_start_matches("--working-dir=")));
}
_ => {}
}
i += 1;
}
opts
}
impl Default for InitOptions {
fn default() -> Self {
Self {
@@ -84,10 +401,16 @@ pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
/// Run the fire-and-forget setup.
pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
use crate::service::{self, InitSystem, ServiceOptions};
eprintln!("[telemt] Fire-and-forget setup");
eprintln!();
// 1. Generate or validate secret
// 1. Detect init system
let init_system = service::detect_init_system();
eprintln!("[+] Detected init system: {}", init_system);
// 2. Generate or validate secret
let secret = match opts.secret {
Some(s) => {
if s.len() != 32 || !s.chars().all(|c| c.is_ascii_hexdigit()) {
@@ -104,50 +427,74 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[+] Port: {}", opts.port);
eprintln!("[+] Domain: {}", opts.domain);
// 2. Create config directory
// 3. Create config directory
fs::create_dir_all(&opts.config_dir)?;
let config_path = opts.config_dir.join("config.toml");
// 3. Write config
// 4. Write config
let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain);
fs::write(&config_path, &config_content)?;
eprintln!("[+] Config written to {}", config_path.display());
// 4. Write systemd unit
// 5. Generate and write service file
let exe_path = std::env::current_exe()
.unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
let unit_path = Path::new("/etc/systemd/system/telemt.service");
let unit_content = generate_systemd_unit(&exe_path, &config_path);
let service_opts = ServiceOptions {
exe_path: &exe_path,
config_path: &config_path,
user: None, // Let systemd/init handle user
group: None,
pid_file: "/var/run/telemt.pid",
working_dir: Some("/var/lib/telemt"),
description: "Telemt MTProxy - Telegram MTProto Proxy",
};
match fs::write(unit_path, &unit_content) {
let service_path = service::service_file_path(init_system);
let service_content = service::generate_service_file(init_system, &service_opts);
// Ensure parent directory exists
if let Some(parent) = Path::new(service_path).parent() {
let _ = fs::create_dir_all(parent);
}
match fs::write(service_path, &service_content) {
Ok(()) => {
eprintln!("[+] Systemd unit written to {}", unit_path.display());
eprintln!("[+] Service file written to {}", service_path);
// Make script executable for OpenRC/FreeBSD
#[cfg(unix)]
if init_system == InitSystem::OpenRC || init_system == InitSystem::FreeBSDRc {
use std::os::unix::fs::PermissionsExt;
let mut perms = fs::metadata(service_path)?.permissions();
perms.set_mode(0o755);
fs::set_permissions(service_path, perms)?;
}
}
Err(e) => {
eprintln!("[!] Cannot write systemd unit (run as root?): {}", e);
eprintln!("[!] Manual unit file content:");
eprintln!("{}", unit_content);
eprintln!("[!] Cannot write service file (run as root?): {}", e);
eprintln!("[!] Manual service file content:");
eprintln!("{}", service_content);
// Still print links and config
// Still print links and installation instructions
eprintln!();
eprintln!("{}", service::installation_instructions(init_system));
print_links(&opts.username, &secret, opts.port, &opts.domain);
return Ok(());
}
}
// 5. Reload systemd
// 6. Install and enable service based on init system
match init_system {
InitSystem::Systemd => {
run_cmd("systemctl", &["daemon-reload"]);
// 6. Enable service
run_cmd("systemctl", &["enable", "telemt.service"]);
eprintln!("[+] Service enabled");
// 7. Start service (unless --no-start)
if !opts.no_start {
run_cmd("systemctl", &["start", "telemt.service"]);
eprintln!("[+] Service started");
// Brief delay then check status
std::thread::sleep(std::time::Duration::from_secs(1));
let status = Command::new("systemctl")
.args(["is-active", "telemt.service"])
@@ -166,10 +513,40 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[+] Service not started (--no-start)");
eprintln!("[+] Start manually: systemctl start telemt.service");
}
}
InitSystem::OpenRC => {
run_cmd("rc-update", &["add", "telemt", "default"]);
eprintln!("[+] Service enabled");
if !opts.no_start {
run_cmd("rc-service", &["telemt", "start"]);
eprintln!("[+] Service started");
} else {
eprintln!("[+] Service not started (--no-start)");
eprintln!("[+] Start manually: rc-service telemt start");
}
}
InitSystem::FreeBSDRc => {
run_cmd("sysrc", &["telemt_enable=YES"]);
eprintln!("[+] Service enabled");
if !opts.no_start {
run_cmd("service", &["telemt", "start"]);
eprintln!("[+] Service started");
} else {
eprintln!("[+] Service not started (--no-start)");
eprintln!("[+] Start manually: service telemt start");
}
}
InitSystem::Unknown => {
eprintln!("[!] Unknown init system - service file written but not installed");
eprintln!("[!] You may need to install it manually");
}
}
eprintln!();
// 8. Print links
// 7. Print links
print_links(&opts.username, &secret, opts.port, &opts.domain);
Ok(())
@@ -246,7 +623,7 @@ tls_full_cert_ttl_secs = 90
[access]
replay_check_len = 65536
replay_window_secs = 1800
replay_window_secs = 120
ignore_time_skew = false
[access.users]
@@ -264,35 +641,6 @@ weight = 10
)
}
fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String {
format!(
r#"[Unit]
Description=Telemt MTProxy
Documentation=https://github.com/telemt/telemt
After=network-online.target
Wants=network-online.target
[Service]
Type=simple
ExecStart={exe} {config}
Restart=always
RestartSec=5
LimitNOFILE=65535
# Security hardening
NoNewPrivileges=true
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=/etc/telemt
PrivateTmp=true
[Install]
WantedBy=multi-user.target
"#,
exe = exe_path.display(),
config = config_path.display(),
)
}
fn run_cmd(cmd: &str, args: &[&str]) {
match Command::new(cmd).args(args).output() {
Ok(output) => {
@ -312,8 +660,10 @@ fn print_links(username: &str, secret: &str, port: u16, domain: &str) {
println!("=== Proxy Links ===");
println!("[{}]", username);
println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}",
port, secret, domain_hex);
println!(
" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}",
port, secret, domain_hex
);
println!();
println!("Replace YOUR_SERVER_IP with your server's public IP.");
println!("The proxy will auto-detect and display the correct link on startup.");



@ -1,6 +1,6 @@
use std::collections::HashMap;
use ipnetwork::IpNetwork;
use serde::Deserialize;
use std::collections::HashMap;
// Helper defaults kept private to the config module.
const DEFAULT_NETWORK_IPV6: Option<bool> = Some(false);
@ -86,13 +86,31 @@ pub(crate) fn default_replay_check_len() -> usize {
}
pub(crate) fn default_replay_window_secs() -> u64 {
1800
// Keep replay cache TTL tight by default to reduce replay surface.
// Deployments with higher RTT or longer reconnect jitter can override this in config.
120
}
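The default replay window above drops from 1800 s to 120 s, as the new comment explains. A minimal sketch of how such a window bounds both cache memory and the replay surface — `ReplayCache` is a hypothetical type for illustration, not the crate's actual implementation:

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Illustrative replay cache: a handshake fingerprint seen again within
/// `window` is rejected as a replay; older entries are evicted.
pub struct ReplayCache {
    window: Duration,
    seen: HashMap<u64, Instant>,
}

impl ReplayCache {
    pub fn new(window_secs: u64) -> Self {
        Self {
            window: Duration::from_secs(window_secs),
            seen: HashMap::new(),
        }
    }

    /// Returns true if the fingerprint is fresh; false if it is a replay.
    pub fn check_and_insert(&mut self, fp: u64, now: Instant) -> bool {
        // Evict expired entries so memory tracks only the active window.
        self.seen.retain(|_, t| now.duration_since(*t) < self.window);
        if self.seen.contains_key(&fp) {
            false
        } else {
            self.seen.insert(fp, now);
            true
        }
    }
}
```

With a 120 s window the cache only has to remember ~2 minutes of fingerprints, which is why a tighter default shrinks both state and the accepted-replay horizon.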
pub(crate) fn default_handshake_timeout() -> u64 {
30
}
pub(crate) fn default_relay_idle_policy_v2_enabled() -> bool {
true
}
pub(crate) fn default_relay_client_idle_soft_secs() -> u64 {
120
}
pub(crate) fn default_relay_client_idle_hard_secs() -> u64 {
360
}
pub(crate) fn default_relay_idle_grace_after_downstream_activity_secs() -> u64 {
30
}
pub(crate) fn default_connect_timeout() -> u64 {
10
}
@ -125,10 +143,7 @@ pub(crate) fn default_weight() -> u16 {
}
pub(crate) fn default_metrics_whitelist() -> Vec<IpNetwork> {
vec![
"127.0.0.1/32".parse().unwrap(),
"::1/128".parse().unwrap(),
]
vec!["127.0.0.1/32".parse().unwrap(), "::1/128".parse().unwrap()]
}
pub(crate) fn default_api_listen() -> String {
@ -151,10 +166,18 @@ pub(crate) fn default_api_minimal_runtime_cache_ttl_ms() -> u64 {
1000
}
pub(crate) fn default_api_runtime_edge_enabled() -> bool { false }
pub(crate) fn default_api_runtime_edge_cache_ttl_ms() -> u64 { 1000 }
pub(crate) fn default_api_runtime_edge_top_n() -> usize { 10 }
pub(crate) fn default_api_runtime_edge_events_capacity() -> usize { 256 }
pub(crate) fn default_api_runtime_edge_enabled() -> bool {
false
}
pub(crate) fn default_api_runtime_edge_cache_ttl_ms() -> u64 {
1000
}
pub(crate) fn default_api_runtime_edge_top_n() -> usize {
10
}
pub(crate) fn default_api_runtime_edge_events_capacity() -> usize {
256
}
pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 {
500
@ -485,17 +508,53 @@ pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 {
}
pub(crate) fn default_server_hello_delay_min_ms() -> u64 {
0
8
}
pub(crate) fn default_server_hello_delay_max_ms() -> u64 {
0
24
}
pub(crate) fn default_alpn_enforce() -> bool {
true
}
pub(crate) fn default_mask_shape_hardening() -> bool {
true
}
pub(crate) fn default_mask_shape_hardening_aggressive_mode() -> bool {
false
}
pub(crate) fn default_mask_shape_bucket_floor_bytes() -> usize {
512
}
pub(crate) fn default_mask_shape_bucket_cap_bytes() -> usize {
4096
}
pub(crate) fn default_mask_shape_above_cap_blur() -> bool {
false
}
pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize {
512
}
pub(crate) fn default_mask_timing_normalization_enabled() -> bool {
false
}
pub(crate) fn default_mask_timing_normalization_floor_ms() -> u64 {
0
}
pub(crate) fn default_mask_timing_normalization_ceiling_ms() -> u64 {
0
}
pub(crate) fn default_stun_servers() -> Vec<String> {
vec![
"stun.l.google.com:5349".to_string(),


@ -31,15 +31,12 @@ use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher};
use tokio::sync::{mpsc, watch};
use tracing::{error, info, warn};
use crate::config::{
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel,
MeWriterPickMode,
};
use super::load::{LoadedConfig, ProxyConfig};
use crate::config::{
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel, MeWriterPickMode,
};
const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2;
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
const HOT_RELOAD_STABLE_RECHECK: Duration = Duration::from_millis(75);
// ── Hot fields ────────────────────────────────────────────────────────────────
@ -58,11 +55,6 @@ pub struct HotFields {
pub me_pool_drain_ttl_secs: u64,
pub me_instadrain: bool,
pub me_pool_drain_threshold: u64,
pub me_pool_drain_soft_evict_enabled: bool,
pub me_pool_drain_soft_evict_grace_secs: u64,
pub me_pool_drain_soft_evict_per_writer: u8,
pub me_pool_drain_soft_evict_budget_per_core: u16,
pub me_pool_drain_soft_evict_cooldown_ms: u64,
pub me_pool_min_fresh_ratio: f32,
pub me_reinit_drain_timeout_secs: u64,
pub me_hardswap_warmup_delay_min_ms: u64,
@ -146,15 +138,6 @@ impl HotFields {
me_pool_drain_ttl_secs: cfg.general.me_pool_drain_ttl_secs,
me_instadrain: cfg.general.me_instadrain,
me_pool_drain_threshold: cfg.general.me_pool_drain_threshold,
me_pool_drain_soft_evict_enabled: cfg.general.me_pool_drain_soft_evict_enabled,
me_pool_drain_soft_evict_grace_secs: cfg.general.me_pool_drain_soft_evict_grace_secs,
me_pool_drain_soft_evict_per_writer: cfg.general.me_pool_drain_soft_evict_per_writer,
me_pool_drain_soft_evict_budget_per_core: cfg
.general
.me_pool_drain_soft_evict_budget_per_core,
me_pool_drain_soft_evict_cooldown_ms: cfg
.general
.me_pool_drain_soft_evict_cooldown_ms,
me_pool_min_fresh_ratio: cfg.general.me_pool_min_fresh_ratio,
me_reinit_drain_timeout_secs: cfg.general.me_reinit_drain_timeout_secs,
me_hardswap_warmup_delay_min_ms: cfg.general.me_hardswap_warmup_delay_min_ms,
@ -205,15 +188,11 @@ impl HotFields {
me_adaptive_floor_min_writers_multi_endpoint: cfg
.general
.me_adaptive_floor_min_writers_multi_endpoint,
me_adaptive_floor_recover_grace_secs: cfg
.general
.me_adaptive_floor_recover_grace_secs,
me_adaptive_floor_recover_grace_secs: cfg.general.me_adaptive_floor_recover_grace_secs,
me_adaptive_floor_writers_per_core_total: cfg
.general
.me_adaptive_floor_writers_per_core_total,
me_adaptive_floor_cpu_cores_override: cfg
.general
.me_adaptive_floor_cpu_cores_override,
me_adaptive_floor_cpu_cores_override: cfg.general.me_adaptive_floor_cpu_cores_override,
me_adaptive_floor_max_extra_writers_single_per_core: cfg
.general
.me_adaptive_floor_max_extra_writers_single_per_core,
@ -232,9 +211,15 @@ impl HotFields {
me_adaptive_floor_max_warm_writers_global: cfg
.general
.me_adaptive_floor_max_warm_writers_global,
me_route_backpressure_base_timeout_ms: cfg.general.me_route_backpressure_base_timeout_ms,
me_route_backpressure_high_timeout_ms: cfg.general.me_route_backpressure_high_timeout_ms,
me_route_backpressure_high_watermark_pct: cfg.general.me_route_backpressure_high_watermark_pct,
me_route_backpressure_base_timeout_ms: cfg
.general
.me_route_backpressure_base_timeout_ms,
me_route_backpressure_high_timeout_ms: cfg
.general
.me_route_backpressure_high_timeout_ms,
me_route_backpressure_high_watermark_pct: cfg
.general
.me_route_backpressure_high_watermark_pct,
me_reader_route_data_wait_ms: cfg.general.me_reader_route_data_wait_ms,
me_d2c_flush_batch_max_frames: cfg.general.me_d2c_flush_batch_max_frames,
me_d2c_flush_batch_max_bytes: cfg.general.me_d2c_flush_batch_max_bytes,
@ -346,16 +331,12 @@ impl WatchManifest {
#[derive(Debug, Default)]
struct ReloadState {
applied_snapshot_hash: Option<u64>,
candidate_snapshot_hash: Option<u64>,
candidate_hits: u8,
}
impl ReloadState {
fn new(applied_snapshot_hash: Option<u64>) -> Self {
Self {
applied_snapshot_hash,
candidate_snapshot_hash: None,
candidate_hits: 0,
}
}
@ -363,32 +344,8 @@ impl ReloadState {
self.applied_snapshot_hash == Some(hash)
}
fn observe_candidate(&mut self, hash: u64) -> u8 {
if self.candidate_snapshot_hash == Some(hash) {
self.candidate_hits = self.candidate_hits.saturating_add(1);
} else {
self.candidate_snapshot_hash = Some(hash);
self.candidate_hits = 1;
}
self.candidate_hits
}
fn reset_candidate(&mut self) {
self.candidate_snapshot_hash = None;
self.candidate_hits = 0;
}
fn mark_applied(&mut self, hash: u64) {
self.applied_snapshot_hash = Some(hash);
self.reset_candidate();
}
fn pending_candidate(&self) -> Option<(u64, u8)> {
let hash = self.candidate_snapshot_hash?;
if self.candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS {
return Some((hash, self.candidate_hits));
}
None
}
}
@ -481,15 +438,6 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
cfg.general.me_pool_drain_ttl_secs = new.general.me_pool_drain_ttl_secs;
cfg.general.me_instadrain = new.general.me_instadrain;
cfg.general.me_pool_drain_threshold = new.general.me_pool_drain_threshold;
cfg.general.me_pool_drain_soft_evict_enabled = new.general.me_pool_drain_soft_evict_enabled;
cfg.general.me_pool_drain_soft_evict_grace_secs =
new.general.me_pool_drain_soft_evict_grace_secs;
cfg.general.me_pool_drain_soft_evict_per_writer =
new.general.me_pool_drain_soft_evict_per_writer;
cfg.general.me_pool_drain_soft_evict_budget_per_core =
new.general.me_pool_drain_soft_evict_budget_per_core;
cfg.general.me_pool_drain_soft_evict_cooldown_ms =
new.general.me_pool_drain_soft_evict_cooldown_ms;
cfg.general.me_pool_min_fresh_ratio = new.general.me_pool_min_fresh_ratio;
cfg.general.me_reinit_drain_timeout_secs = new.general.me_reinit_drain_timeout_secs;
cfg.general.me_hardswap_warmup_delay_min_ms = new.general.me_hardswap_warmup_delay_min_ms;
@ -536,10 +484,14 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
new.general.me_adaptive_floor_writers_per_core_total;
cfg.general.me_adaptive_floor_cpu_cores_override =
new.general.me_adaptive_floor_cpu_cores_override;
cfg.general.me_adaptive_floor_max_extra_writers_single_per_core =
new.general.me_adaptive_floor_max_extra_writers_single_per_core;
cfg.general.me_adaptive_floor_max_extra_writers_multi_per_core =
new.general.me_adaptive_floor_max_extra_writers_multi_per_core;
cfg.general
.me_adaptive_floor_max_extra_writers_single_per_core = new
.general
.me_adaptive_floor_max_extra_writers_single_per_core;
cfg.general
.me_adaptive_floor_max_extra_writers_multi_per_core = new
.general
.me_adaptive_floor_max_extra_writers_multi_per_core;
cfg.general.me_adaptive_floor_max_active_writers_per_core =
new.general.me_adaptive_floor_max_active_writers_per_core;
cfg.general.me_adaptive_floor_max_warm_writers_per_core =
@ -598,8 +550,7 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.server.api.minimal_runtime_cache_ttl_ms
!= new.server.api.minimal_runtime_cache_ttl_ms
|| old.server.api.runtime_edge_enabled != new.server.api.runtime_edge_enabled
|| old.server.api.runtime_edge_cache_ttl_ms
!= new.server.api.runtime_edge_cache_ttl_ms
|| old.server.api.runtime_edge_cache_ttl_ms != new.server.api.runtime_edge_cache_ttl_ms
|| old.server.api.runtime_edge_top_n != new.server.api.runtime_edge_top_n
|| old.server.api.runtime_edge_events_capacity
!= new.server.api.runtime_edge_events_capacity
@ -615,8 +566,6 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.server.listen_tcp != new.server.listen_tcp
|| old.server.listen_unix_sock != new.server.listen_unix_sock
|| old.server.listen_unix_sock_perm != new.server.listen_unix_sock_perm
|| old.server.max_connections != new.server.max_connections
|| old.server.accept_permit_timeout_ms != new.server.accept_permit_timeout_ms
{
warned = true;
warn!("config reload: server listener settings changed; restart required");
@ -637,6 +586,19 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.censorship.tls_full_cert_ttl_secs != new.censorship.tls_full_cert_ttl_secs
|| old.censorship.alpn_enforce != new.censorship.alpn_enforce
|| old.censorship.mask_proxy_protocol != new.censorship.mask_proxy_protocol
|| old.censorship.mask_shape_hardening != new.censorship.mask_shape_hardening
|| old.censorship.mask_shape_bucket_floor_bytes
!= new.censorship.mask_shape_bucket_floor_bytes
|| old.censorship.mask_shape_bucket_cap_bytes != new.censorship.mask_shape_bucket_cap_bytes
|| old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur
|| old.censorship.mask_shape_above_cap_blur_max_bytes
!= new.censorship.mask_shape_above_cap_blur_max_bytes
|| old.censorship.mask_timing_normalization_enabled
!= new.censorship.mask_timing_normalization_enabled
|| old.censorship.mask_timing_normalization_floor_ms
!= new.censorship.mask_timing_normalization_floor_ms
|| old.censorship.mask_timing_normalization_ceiling_ms
!= new.censorship.mask_timing_normalization_ceiling_ms
{
warned = true;
warn!("config reload: censorship settings changed; restart required");
@ -677,9 +639,6 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
}
if old.general.me_route_no_writer_mode != new.general.me_route_no_writer_mode
|| old.general.me_route_no_writer_wait_ms != new.general.me_route_no_writer_wait_ms
|| old.general.me_route_hybrid_max_wait_ms != new.general.me_route_hybrid_max_wait_ms
|| old.general.me_route_blocking_send_timeout_ms
!= new.general.me_route_blocking_send_timeout_ms
|| old.general.me_route_inline_recovery_attempts
!= new.general.me_route_inline_recovery_attempts
|| old.general.me_route_inline_recovery_wait_ms
@ -688,10 +647,6 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
warned = true;
warn!("config reload: general.me_route_no_writer_* changed; restart required");
}
if old.general.me_c2me_send_timeout_ms != new.general.me_c2me_send_timeout_ms {
warned = true;
warn!("config reload: general.me_c2me_send_timeout_ms changed; restart required");
}
if old.general.unknown_dc_log_path != new.general.unknown_dc_log_path
|| old.general.unknown_dc_file_log_enabled != new.general.unknown_dc_file_log_enabled
{
@ -886,25 +841,6 @@ fn log_changes(
old_hot.me_pool_drain_threshold, new_hot.me_pool_drain_threshold,
);
}
if old_hot.me_pool_drain_soft_evict_enabled != new_hot.me_pool_drain_soft_evict_enabled
|| old_hot.me_pool_drain_soft_evict_grace_secs
!= new_hot.me_pool_drain_soft_evict_grace_secs
|| old_hot.me_pool_drain_soft_evict_per_writer
!= new_hot.me_pool_drain_soft_evict_per_writer
|| old_hot.me_pool_drain_soft_evict_budget_per_core
!= new_hot.me_pool_drain_soft_evict_budget_per_core
|| old_hot.me_pool_drain_soft_evict_cooldown_ms
!= new_hot.me_pool_drain_soft_evict_cooldown_ms
{
info!(
"config reload: me_pool_drain_soft_evict: enabled={} grace={}s per_writer={} budget_per_core={} cooldown={}ms",
new_hot.me_pool_drain_soft_evict_enabled,
new_hot.me_pool_drain_soft_evict_grace_secs,
new_hot.me_pool_drain_soft_evict_per_writer,
new_hot.me_pool_drain_soft_evict_budget_per_core,
new_hot.me_pool_drain_soft_evict_cooldown_ms
);
}
if (old_hot.me_pool_min_fresh_ratio - new_hot.me_pool_min_fresh_ratio).abs() > f32::EPSILON {
info!(
@ -938,8 +874,7 @@ fn log_changes(
{
info!(
"config reload: me_bind_stale: mode={:?} ttl={}s",
new_hot.me_bind_stale_mode,
new_hot.me_bind_stale_ttl_secs
new_hot.me_bind_stale_mode, new_hot.me_bind_stale_ttl_secs
);
}
if old_hot.me_secret_atomic_snapshot != new_hot.me_secret_atomic_snapshot
@ -1019,8 +954,7 @@ fn log_changes(
if old_hot.me_socks_kdf_policy != new_hot.me_socks_kdf_policy {
info!(
"config reload: me_socks_kdf_policy: {:?} → {:?}",
old_hot.me_socks_kdf_policy,
new_hot.me_socks_kdf_policy,
old_hot.me_socks_kdf_policy, new_hot.me_socks_kdf_policy,
);
}
@ -1074,8 +1008,7 @@ fn log_changes(
|| old_hot.me_route_backpressure_high_watermark_pct
!= new_hot.me_route_backpressure_high_watermark_pct
|| old_hot.me_reader_route_data_wait_ms != new_hot.me_reader_route_data_wait_ms
|| old_hot.me_health_interval_ms_unhealthy
!= new_hot.me_health_interval_ms_unhealthy
|| old_hot.me_health_interval_ms_unhealthy != new_hot.me_health_interval_ms_unhealthy
|| old_hot.me_health_interval_ms_healthy != new_hot.me_health_interval_ms_healthy
|| old_hot.me_admission_poll_ms != new_hot.me_admission_poll_ms
|| old_hot.me_warn_rate_limit_ms != new_hot.me_warn_rate_limit_ms
@ -1112,19 +1045,27 @@ fn log_changes(
}
if old_hot.users != new_hot.users {
let mut added: Vec<&String> = new_hot.users.keys()
let mut added: Vec<&String> = new_hot
.users
.keys()
.filter(|u| !old_hot.users.contains_key(*u))
.collect();
added.sort();
let mut removed: Vec<&String> = old_hot.users.keys()
let mut removed: Vec<&String> = old_hot
.users
.keys()
.filter(|u| !new_hot.users.contains_key(*u))
.collect();
removed.sort();
let mut changed: Vec<&String> = new_hot.users.keys()
let mut changed: Vec<&String> = new_hot
.users
.keys()
.filter(|u| {
old_hot.users.get(*u)
old_hot
.users
.get(*u)
.map(|s| s != &new_hot.users[*u])
.unwrap_or(false)
})
@ -1134,10 +1075,18 @@ fn log_changes(
if !added.is_empty() {
info!(
"config reload: users added: [{}]",
added.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ")
added
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
);
let host = resolve_link_host(new_cfg, detected_ip_v4, detected_ip_v6);
let port = new_cfg.general.links.public_port.unwrap_or(new_cfg.server.port);
let port = new_cfg
.general
.links
.public_port
.unwrap_or(new_cfg.server.port);
for user in &added {
if let Some(secret) = new_hot.users.get(*user) {
print_user_links(user, secret, &host, port, new_cfg);
@ -1147,13 +1096,21 @@ fn log_changes(
if !removed.is_empty() {
info!(
"config reload: users removed: [{}]",
removed.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ")
removed
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
);
}
if !changed.is_empty() {
info!(
"config reload: users secret changed: [{}]",
changed.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ")
changed
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
);
}
}
@ -1184,8 +1141,7 @@ fn log_changes(
}
if old_hot.user_max_unique_ips_global_each != new_hot.user_max_unique_ips_global_each
|| old_hot.user_max_unique_ips_mode != new_hot.user_max_unique_ips_mode
|| old_hot.user_max_unique_ips_window_secs
!= new_hot.user_max_unique_ips_window_secs
|| old_hot.user_max_unique_ips_window_secs != new_hot.user_max_unique_ips_window_secs
{
info!(
"config reload: user_max_unique_ips policy global_each={} mode={:?} window={}s",
@ -1208,7 +1164,6 @@ fn reload_config(
let loaded = match ProxyConfig::load_with_metadata(config_path) {
Ok(loaded) => loaded,
Err(e) => {
reload_state.reset_candidate();
error!("config reload: failed to parse {:?}: {}", config_path, e);
return None;
}
@ -1221,8 +1176,10 @@ fn reload_config(
let next_manifest = WatchManifest::from_source_files(&source_files);
if let Err(e) = new_cfg.validate() {
reload_state.reset_candidate();
error!("config reload: validation failed: {}; keeping old config", e);
error!(
"config reload: validation failed: {}; keeping old config",
e
);
return Some(next_manifest);
}
@ -1230,17 +1187,6 @@ fn reload_config(
return Some(next_manifest);
}
let candidate_hits = reload_state.observe_candidate(rendered_hash);
if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS {
info!(
snapshot_hash = rendered_hash,
candidate_hits,
required_hits = HOT_RELOAD_STABLE_SNAPSHOTS,
"config reload: candidate snapshot observed but not stable yet"
);
return Some(next_manifest);
}
let old_cfg = config_tx.borrow().clone();
let applied_cfg = overlay_hot_fields(&old_cfg, &new_cfg);
let old_hot = HotFields::from_config(&old_cfg);
@ -1260,7 +1206,6 @@ fn reload_config(
if old_hot.dns_overrides != applied_hot.dns_overrides
&& let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides)
{
reload_state.reset_candidate();
error!(
"config reload: invalid network.dns_overrides: {}; keeping old config",
e
@ -1281,73 +1226,6 @@ fn reload_config(
Some(next_manifest)
}
async fn reload_with_internal_stable_rechecks(
config_path: &PathBuf,
config_tx: &watch::Sender<Arc<ProxyConfig>>,
log_tx: &watch::Sender<LogLevel>,
detected_ip_v4: Option<IpAddr>,
detected_ip_v6: Option<IpAddr>,
reload_state: &mut ReloadState,
) -> Option<WatchManifest> {
let mut next_manifest = reload_config(
config_path,
config_tx,
log_tx,
detected_ip_v4,
detected_ip_v6,
reload_state,
);
let mut rechecks_left = HOT_RELOAD_STABLE_SNAPSHOTS.saturating_sub(1);
while rechecks_left > 0 {
let Some((snapshot_hash, candidate_hits)) = reload_state.pending_candidate() else {
break;
};
info!(
snapshot_hash,
candidate_hits,
required_hits = HOT_RELOAD_STABLE_SNAPSHOTS,
rechecks_left,
recheck_delay_ms = HOT_RELOAD_STABLE_RECHECK.as_millis(),
"config reload: scheduling internal stable recheck"
);
tokio::time::sleep(HOT_RELOAD_STABLE_RECHECK).await;
let recheck_manifest = reload_config(
config_path,
config_tx,
log_tx,
detected_ip_v4,
detected_ip_v6,
reload_state,
);
if recheck_manifest.is_some() {
next_manifest = recheck_manifest;
}
if reload_state.is_applied(snapshot_hash) {
info!(
snapshot_hash,
"config reload: applied after internal stable recheck"
);
break;
}
if reload_state.pending_candidate().is_none() {
info!(
snapshot_hash,
"config reload: internal stable recheck aborted"
);
break;
}
rechecks_left = rechecks_left.saturating_sub(1);
}
next_manifest
}
// ── Public API ────────────────────────────────────────────────────────────────
/// Spawn the hot-reload watcher task.
@ -1383,9 +1261,13 @@ pub fn spawn_config_watcher(
let tx_inotify = notify_tx.clone();
let manifest_for_inotify = manifest_state.clone();
let mut inotify_watcher = match recommended_watcher(move |res: notify::Result<notify::Event>| {
let mut inotify_watcher =
match recommended_watcher(move |res: notify::Result<notify::Event>| {
let Ok(event) = res else { return };
if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
if !matches!(
event.kind,
EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
) {
return;
}
let is_our_file = manifest_for_inotify
@ -1417,7 +1299,10 @@ pub fn spawn_config_watcher(
let mut poll_watcher = match notify::poll::PollWatcher::new(
move |res: notify::Result<notify::Event>| {
let Ok(event) = res else { return };
if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
if !matches!(
event.kind,
EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
) {
return;
}
let is_our_file = manifest_for_poll
@ -1465,22 +1350,36 @@ pub fn spawn_config_watcher(
}
}
#[cfg(not(unix))]
if notify_rx.recv().await.is_none() { break; }
if notify_rx.recv().await.is_none() {
break;
}
// Debounce: drain extra events that arrive within a short quiet window.
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
while notify_rx.try_recv().is_ok() {}
if let Some(next_manifest) = reload_with_internal_stable_rechecks(
let mut next_manifest = reload_config(
&config_path,
&config_tx,
&log_tx,
detected_ip_v4,
detected_ip_v6,
&mut reload_state,
)
.await
{
);
if next_manifest.is_none() {
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
while notify_rx.try_recv().is_ok() {}
next_manifest = reload_config(
&config_path,
&config_tx,
&log_tx,
detected_ip_v4,
detected_ip_v6,
&mut reload_state,
);
}
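The watcher loop above replaces the old multi-snapshot stability protocol with a single retry: if the first reload fails (for example on a half-written file), it waits one more debounce window and tries again. The shape of that pattern, simplified — `try_reload` stands in for the real `reload_config` and the closure-based signature is illustrative:

```rust
use std::time::Duration;

/// Debounce-then-retry sketch of the watcher loop's new shape.
/// `try_reload` returns Some(manifest_hash) on success, None on failure.
pub fn reload_with_single_retry<F>(mut try_reload: F, debounce: Duration) -> Option<u64>
where
    F: FnMut() -> Option<u64>,
{
    let mut result = try_reload();
    if result.is_none() {
        // A partially written file can fail to parse or validate;
        // wait one quiet window and retry exactly once.
        std::thread::sleep(debounce);
        result = try_reload();
    }
    result
}
```

One bounded retry keeps the common editor save-in-place case working without reintroducing the removed candidate-hash bookkeeping.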
if let Some(next_manifest) = next_manifest {
apply_watch_manifest(
inotify_watcher.as_mut(),
poll_watcher.as_mut(),
@ -1555,7 +1454,10 @@ mod tests {
new.server.port = old.server.port.saturating_add(1);
let applied = overlay_hot_fields(&old, &new);
assert_eq!(HotFields::from_config(&old), HotFields::from_config(&applied));
assert_eq!(
HotFields::from_config(&old),
HotFields::from_config(&applied)
);
assert_eq!(applied.server.port, old.server.port);
}
@ -1574,7 +1476,10 @@ mod tests {
applied.general.me_bind_stale_mode,
new.general.me_bind_stale_mode
);
assert_ne!(HotFields::from_config(&old), HotFields::from_config(&applied));
assert_ne!(
HotFields::from_config(&old),
HotFields::from_config(&applied)
);
}
#[test]
@ -1588,7 +1493,10 @@ mod tests {
applied.general.me_keepalive_interval_secs,
old.general.me_keepalive_interval_secs
);
assert_eq!(HotFields::from_config(&old), HotFields::from_config(&applied));
assert_eq!(
HotFields::from_config(&old),
HotFields::from_config(&applied)
);
}
#[test]
@ -1600,69 +1508,35 @@ mod tests {
let applied = overlay_hot_fields(&old, &new);
assert_eq!(applied.general.hardswap, new.general.hardswap);
assert_eq!(applied.general.use_middle_proxy, old.general.use_middle_proxy);
assert_eq!(
applied.general.use_middle_proxy,
old.general.use_middle_proxy
);
assert!(!config_equal(&applied, &new));
}
#[test]
fn reload_requires_stable_snapshot_before_hot_apply() {
fn reload_applies_hot_change_on_first_observed_snapshot() {
let initial_tag = "11111111111111111111111111111111";
let final_tag = "22222222222222222222222222222222";
let path = temp_config_path("telemt_hot_reload_stable");
write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash));
write_reload_config(&path, None, None);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(
config_tx.borrow().general.ad_tag.as_deref(),
Some(initial_tag)
);
write_reload_config(&path, Some(final_tag), None);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(
config_tx.borrow().general.ad_tag.as_deref(),
Some(initial_tag)
);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
let _ = std::fs::remove_file(path);
}
#[tokio::test]
async fn reload_cycle_applies_after_single_external_event() {
let initial_tag = "10101010101010101010101010101010";
let final_tag = "20202020202020202020202020202020";
let path = temp_config_path("telemt_hot_reload_single_event");
write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
let initial_hash = ProxyConfig::load_with_metadata(&path)
.unwrap()
.rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash));
write_reload_config(&path, Some(final_tag), None);
reload_with_internal_stable_rechecks(
&path,
&config_tx,
&log_tx,
None,
None,
&mut reload_state,
)
.await
.unwrap();
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(
config_tx.borrow().general.ad_tag.as_deref(),
Some(final_tag)
);
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
let _ = std::fs::remove_file(path);
}
@ -1674,14 +1548,15 @@ mod tests {
write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
let initial_hash = ProxyConfig::load_with_metadata(&path)
.unwrap()
.rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash));
write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1));
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
let applied = config_tx.borrow().clone();
assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag));
@ -1689,4 +1564,36 @@ mod tests {
let _ = std::fs::remove_file(path);
}
#[test]
fn reload_recovers_after_parse_error_on_next_attempt() {
let initial_tag = "cccccccccccccccccccccccccccccccc";
let final_tag = "dddddddddddddddddddddddddddddddd";
let path = temp_config_path("telemt_hot_reload_parse_recovery");
write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path)
.unwrap()
.rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash));
std::fs::write(&path, "[access.users\nuser = \"broken\"\n").unwrap();
assert!(reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).is_none());
assert_eq!(
config_tx.borrow().general.ad_tag.as_deref(),
Some(initial_tag)
);
write_reload_config(&path, Some(final_tag), None);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(
config_tx.borrow().general.ad_tag.as_deref(),
Some(final_tag)
);
let _ = std::fs::remove_file(path);
}
}


@ -5,7 +5,7 @@ use std::hash::{DefaultHasher, Hash, Hasher};
use std::net::{IpAddr, SocketAddr};
use std::path::{Path, PathBuf};
use rand::Rng;
use rand::RngExt;
use serde::{Deserialize, Serialize};
use shadowsocks::config::ServerConfig as ShadowsocksServerConfig;
use tracing::warn;
@ -360,6 +360,131 @@ impl ProxyConfig {
));
}
if config.timeouts.client_handshake == 0 {
return Err(ProxyError::Config(
"timeouts.client_handshake must be > 0".to_string(),
));
}
let handshake_timeout_ms = config
.timeouts
.client_handshake
.checked_mul(1000)
.ok_or_else(|| {
ProxyError::Config(
"timeouts.client_handshake is too large to validate milliseconds budget"
.to_string(),
)
})?;
if config.censorship.server_hello_delay_max_ms >= handshake_timeout_ms {
return Err(ProxyError::Config(
"censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"
.to_string(),
));
}
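The validation above converts `timeouts.client_handshake` from seconds to milliseconds with `checked_mul` before comparing it against `server_hello_delay_max_ms`, so an absurdly large timeout errors out instead of silently wrapping. A minimal sketch of the same overflow-safe pattern (the function name is illustrative):

```rust
/// Overflow-safe seconds→milliseconds comparison, as in the validation above.
/// Returns Ok(true) when the hello delay fits inside the handshake budget.
pub fn hello_delay_fits(handshake_secs: u64, hello_delay_max_ms: u64) -> Result<bool, String> {
    let budget_ms = handshake_secs
        .checked_mul(1000)
        .ok_or_else(|| "client_handshake too large to express in milliseconds".to_string())?;
    Ok(hello_delay_max_ms < budget_ms)
}
```

Without the `checked_mul`, `u64::MAX` seconds would wrap and could make an invalid delay pass the `<` check.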
if config.censorship.mask_shape_bucket_floor_bytes == 0 {
return Err(ProxyError::Config(
"censorship.mask_shape_bucket_floor_bytes must be > 0".to_string(),
));
}
if config.censorship.mask_shape_bucket_cap_bytes
< config.censorship.mask_shape_bucket_floor_bytes
{
return Err(ProxyError::Config(
"censorship.mask_shape_bucket_cap_bytes must be >= censorship.mask_shape_bucket_floor_bytes"
.to_string(),
));
}
if config.censorship.mask_shape_above_cap_blur && !config.censorship.mask_shape_hardening {
return Err(ProxyError::Config(
"censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"
.to_string(),
));
}
if config.censorship.mask_shape_hardening_aggressive_mode
&& !config.censorship.mask_shape_hardening
{
return Err(ProxyError::Config(
"censorship.mask_shape_hardening_aggressive_mode requires censorship.mask_shape_hardening = true"
.to_string(),
));
}
if config.censorship.mask_shape_above_cap_blur
&& config.censorship.mask_shape_above_cap_blur_max_bytes == 0
{
return Err(ProxyError::Config(
"censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled"
.to_string(),
));
}
if config.censorship.mask_shape_above_cap_blur_max_bytes > 1_048_576 {
return Err(ProxyError::Config(
"censorship.mask_shape_above_cap_blur_max_bytes must be <= 1048576".to_string(),
));
}
if config.censorship.mask_timing_normalization_ceiling_ms
< config.censorship.mask_timing_normalization_floor_ms
{
return Err(ProxyError::Config(
"censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms"
.to_string(),
));
}
if config.censorship.mask_timing_normalization_enabled
&& config.censorship.mask_timing_normalization_floor_ms == 0
{
return Err(ProxyError::Config(
"censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true"
.to_string(),
));
}
if config.censorship.mask_timing_normalization_ceiling_ms > 60_000 {
return Err(ProxyError::Config(
"censorship.mask_timing_normalization_ceiling_ms must be <= 60000".to_string(),
));
}
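The mask-shape options validated above (floor, cap, above-cap blur) suggest that record sizes are rounded up into buckets between a floor and a cap. A hedged sketch of one such bucketing scheme — power-of-two steps are an assumption here; the crate's actual scheme may differ:

```rust
/// Illustrative size bucketing: round `len` up to the next power-of-two
/// bucket starting at `floor`, clamped to `cap`. Sizes at or above the cap
/// pass through unchanged (the above-cap blur option would perturb those).
pub fn bucket_len(len: usize, floor: usize, cap: usize) -> usize {
    debug_assert!(floor > 0 && cap >= floor);
    if len >= cap {
        return len;
    }
    let mut bucket = floor;
    while bucket < len {
        bucket *= 2;
    }
    bucket.min(cap)
}
```

Bucketing collapses many distinct payload lengths into a few observable sizes, which is the point of shape hardening: the validation's `floor > 0` and `cap >= floor` checks keep this loop well-defined.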
if config.timeouts.relay_client_idle_soft_secs == 0 {
return Err(ProxyError::Config(
"timeouts.relay_client_idle_soft_secs must be > 0".to_string(),
));
}
if config.timeouts.relay_client_idle_hard_secs == 0 {
return Err(ProxyError::Config(
"timeouts.relay_client_idle_hard_secs must be > 0".to_string(),
));
}
if config.timeouts.relay_client_idle_hard_secs < config.timeouts.relay_client_idle_soft_secs
{
return Err(ProxyError::Config(
"timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs"
.to_string(),
));
}
if config
.timeouts
.relay_idle_grace_after_downstream_activity_secs
> config.timeouts.relay_client_idle_hard_secs
{
return Err(ProxyError::Config(
"timeouts.relay_idle_grace_after_downstream_activity_secs must be <= timeouts.relay_client_idle_hard_secs"
.to_string(),
));
}
if config.general.me_writer_cmd_channel_capacity == 0 {
return Err(ProxyError::Config(
"general.me_writer_cmd_channel_capacity must be > 0".to_string(),
@ -648,7 +773,8 @@ impl ProxyConfig {
}
if config.general.me_route_backpressure_base_timeout_ms > 5000 {
return Err(ProxyError::Config(
"general.me_route_backpressure_base_timeout_ms must be within [1, 5000]".to_string(),
"general.me_route_backpressure_base_timeout_ms must be within [1, 5000]"
.to_string(),
));
}
@ -661,7 +787,8 @@ impl ProxyConfig {
}
if config.general.me_route_backpressure_high_timeout_ms > 5000 {
return Err(ProxyError::Config(
"general.me_route_backpressure_high_timeout_ms must be within [1, 5000]".to_string(),
"general.me_route_backpressure_high_timeout_ms must be within [1, 5000]"
.to_string(),
));
}
@ -860,7 +987,7 @@ impl ProxyConfig {
if !config.censorship.tls_emulation
&& config.censorship.fake_cert_len == default_fake_cert_len()
{
config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
config.censorship.fake_cert_len = rand::rng().random_range(1024..4096);
}
// Resolve listen_tcp: explicit value wins, otherwise auto-detect.
@ -982,6 +1109,18 @@ impl ProxyConfig {
}
}
#[cfg(test)]
#[path = "tests/load_idle_policy_tests.rs"]
mod load_idle_policy_tests;
#[cfg(test)]
#[path = "tests/load_security_tests.rs"]
mod load_security_tests;
#[cfg(test)]
#[path = "tests/load_mask_shape_security_tests.rs"]
mod load_mask_shape_security_tests;
#[cfg(test)]
mod tests {
use super::*;
@ -1697,7 +1836,9 @@ mod tests {
let path = dir.join("telemt_me_route_backpressure_base_timeout_ms_out_of_range_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]"));
assert!(
err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]")
);
let _ = std::fs::remove_file(path);
}
@ -1718,7 +1859,9 @@ mod tests {
let path = dir.join("telemt_me_route_backpressure_high_timeout_ms_out_of_range_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]"));
assert!(
err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]")
);
let _ = std::fs::remove_file(path);
}
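The timing-normalization checks above guarantee `floor_ms <= ceiling_ms` and a strictly positive floor whenever the envelope is enabled. As a rough illustration of what such an envelope does (the helper name is hypothetical, not telemt's actual code), an observed outcome delay can be clamped into the configured bounds:

```rust
use std::time::Duration;

/// Clamp an observed masking-outcome delay into a [floor, ceiling] envelope
/// (milliseconds). Hypothetical sketch illustrating the invariant the config
/// loader enforces; not the proxy's implementation.
fn normalize_outcome_delay(observed_ms: u64, floor_ms: u64, ceiling_ms: u64) -> Duration {
    // The loader rejects configs where this does not hold.
    assert!(floor_ms <= ceiling_ms, "floor must not exceed ceiling");
    Duration::from_millis(observed_ms.clamp(floor_ms, ceiling_ms))
}

fn main() {
    // Too-fast outcomes are delayed to the floor, slow ones capped at the ceiling.
    assert_eq!(normalize_outcome_delay(10, 150, 240).as_millis(), 150);
    assert_eq!(normalize_outcome_delay(500, 150, 240).as_millis(), 240);
    assert_eq!(normalize_outcome_delay(200, 150, 240).as_millis(), 200);
    println!("ok");
}
```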

View File

@ -1,9 +1,9 @@
//! Configuration.
pub(crate) mod defaults;
mod types;
mod load;
pub mod hot_reload;
mod load;
mod types;
pub use load::ProxyConfig;
pub use types::*;

View File

@ -0,0 +1,80 @@
use super::*;
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
fn write_temp_config(contents: &str) -> PathBuf {
let nonce = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time must be after unix epoch")
.as_nanos();
let path = std::env::temp_dir().join(format!("telemt-idle-policy-{nonce}.toml"));
fs::write(&path, contents).expect("temp config write must succeed");
path
}
fn remove_temp_config(path: &PathBuf) {
let _ = fs::remove_file(path);
}
#[test]
fn load_rejects_relay_hard_idle_smaller_than_soft_idle_with_clear_error() {
let path = write_temp_config(
r#"
[timeouts]
relay_client_idle_soft_secs = 120
relay_client_idle_hard_secs = 60
"#,
);
let err = ProxyConfig::load(&path).expect_err("config with hard<soft must fail");
let msg = err.to_string();
assert!(
msg.contains(
"timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs"
),
"error must explain the violated hard>=soft invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_relay_grace_larger_than_hard_idle_with_clear_error() {
let path = write_temp_config(
r#"
[timeouts]
relay_client_idle_soft_secs = 60
relay_client_idle_hard_secs = 120
relay_idle_grace_after_downstream_activity_secs = 121
"#,
);
let err = ProxyConfig::load(&path).expect_err("config with grace>hard must fail");
let msg = err.to_string();
assert!(
msg.contains("timeouts.relay_idle_grace_after_downstream_activity_secs must be <= timeouts.relay_client_idle_hard_secs"),
"error must explain the violated grace<=hard invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_zero_handshake_timeout_with_clear_error() {
let path = write_temp_config(
r#"
[timeouts]
client_handshake = 0
"#,
);
let err = ProxyConfig::load(&path).expect_err("config with zero handshake timeout must fail");
let msg = err.to_string();
assert!(
msg.contains("timeouts.client_handshake must be > 0"),
"error must explain that handshake timeout must be positive, got: {msg}"
);
remove_temp_config(&path);
}

View File

@ -0,0 +1,238 @@
use super::*;
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
fn write_temp_config(contents: &str) -> PathBuf {
let nonce = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time must be after unix epoch")
.as_nanos();
let path = std::env::temp_dir().join(format!("telemt-load-mask-shape-security-{nonce}.toml"));
fs::write(&path, contents).expect("temp config write must succeed");
path
}
fn remove_temp_config(path: &PathBuf) {
let _ = fs::remove_file(path);
}
#[test]
fn load_rejects_zero_mask_shape_bucket_floor_bytes() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_bucket_floor_bytes = 0
mask_shape_bucket_cap_bytes = 4096
"#,
);
let err =
ProxyConfig::load(&path).expect_err("zero mask_shape_bucket_floor_bytes must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_shape_bucket_floor_bytes must be > 0"),
"error must explain floor>0 invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_mask_shape_bucket_cap_less_than_floor() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_bucket_floor_bytes = 1024
mask_shape_bucket_cap_bytes = 512
"#,
);
let err =
ProxyConfig::load(&path).expect_err("mask_shape_bucket_cap_bytes < floor must be rejected");
let msg = err.to_string();
assert!(
msg.contains(
"censorship.mask_shape_bucket_cap_bytes must be >= censorship.mask_shape_bucket_floor_bytes"
),
"error must explain cap>=floor invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_mask_shape_bucket_cap_equal_to_floor() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_hardening = true
mask_shape_bucket_floor_bytes = 1024
mask_shape_bucket_cap_bytes = 1024
"#,
);
let cfg = ProxyConfig::load(&path).expect("equal cap and floor must be accepted");
assert!(cfg.censorship.mask_shape_hardening);
assert_eq!(cfg.censorship.mask_shape_bucket_floor_bytes, 1024);
assert_eq!(cfg.censorship.mask_shape_bucket_cap_bytes, 1024);
remove_temp_config(&path);
}
#[test]
fn load_rejects_above_cap_blur_when_shape_hardening_disabled() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_hardening = false
mask_shape_above_cap_blur = true
mask_shape_above_cap_blur_max_bytes = 64
"#,
);
let err =
ProxyConfig::load(&path).expect_err("above-cap blur must require shape hardening enabled");
let msg = err.to_string();
assert!(
msg.contains(
"censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"
),
"error must explain blur prerequisite, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_above_cap_blur_with_zero_max_bytes() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_hardening = true
mask_shape_above_cap_blur = true
mask_shape_above_cap_blur_max_bytes = 0
"#,
);
let err =
ProxyConfig::load(&path).expect_err("above-cap blur max bytes must be > 0 when enabled");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled"),
"error must explain blur max bytes invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_timing_normalization_floor_zero_when_enabled() {
let path = write_temp_config(
r#"
[censorship]
mask_timing_normalization_enabled = true
mask_timing_normalization_floor_ms = 0
mask_timing_normalization_ceiling_ms = 200
"#,
);
let err =
ProxyConfig::load(&path).expect_err("timing normalization floor must be > 0 when enabled");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true"),
"error must explain timing floor invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_timing_normalization_ceiling_below_floor() {
let path = write_temp_config(
r#"
[censorship]
mask_timing_normalization_enabled = true
mask_timing_normalization_floor_ms = 220
mask_timing_normalization_ceiling_ms = 200
"#,
);
let err = ProxyConfig::load(&path).expect_err("timing normalization ceiling must be >= floor");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms"),
"error must explain timing ceiling/floor invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_valid_timing_normalization_and_above_cap_blur_config() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_hardening = true
mask_shape_above_cap_blur = true
mask_shape_above_cap_blur_max_bytes = 128
mask_timing_normalization_enabled = true
mask_timing_normalization_floor_ms = 150
mask_timing_normalization_ceiling_ms = 240
"#,
);
let cfg = ProxyConfig::load(&path)
.expect("valid blur and timing normalization settings must be accepted");
assert!(cfg.censorship.mask_shape_hardening);
assert!(cfg.censorship.mask_shape_above_cap_blur);
assert_eq!(cfg.censorship.mask_shape_above_cap_blur_max_bytes, 128);
assert!(cfg.censorship.mask_timing_normalization_enabled);
assert_eq!(cfg.censorship.mask_timing_normalization_floor_ms, 150);
assert_eq!(cfg.censorship.mask_timing_normalization_ceiling_ms, 240);
remove_temp_config(&path);
}
#[test]
fn load_rejects_aggressive_shape_mode_when_shape_hardening_disabled() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_hardening = false
mask_shape_hardening_aggressive_mode = true
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("aggressive shape hardening mode must require shape hardening enabled");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_shape_hardening_aggressive_mode requires censorship.mask_shape_hardening = true"),
"error must explain aggressive-mode prerequisite, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_aggressive_shape_mode_when_shape_hardening_enabled() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_hardening = true
mask_shape_hardening_aggressive_mode = true
mask_shape_above_cap_blur = true
mask_shape_above_cap_blur_max_bytes = 8
"#,
);
let cfg = ProxyConfig::load(&path)
.expect("aggressive shape hardening mode should be accepted when prerequisites are met");
assert!(cfg.censorship.mask_shape_hardening);
assert!(cfg.censorship.mask_shape_hardening_aggressive_mode);
assert!(cfg.censorship.mask_shape_above_cap_blur);
remove_temp_config(&path);
}
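The tests above pin down the shape-hardening invariants: a positive bucket floor, cap >= floor, and above-cap blur only when hardening is on. As a hypothetical sketch of what bucket padding means (one plausible scheme, not telemt's actual algorithm), a stream length can be rounded up to the next doubling bucket between floor and cap, with lengths above the cap left for the blur path:

```rust
/// Round a stream length up to the next doubling bucket in [floor, cap].
/// Lengths above the cap are returned unchanged (above-cap blur, when
/// enabled, would add bounded random tail bytes instead). Hypothetical
/// illustration only.
fn pad_to_bucket(len: usize, floor: usize, cap: usize) -> usize {
    assert!(floor > 0 && cap >= floor, "loader-enforced invariants");
    if len > cap {
        return len;
    }
    let mut bucket = floor;
    while bucket < len {
        bucket *= 2;
    }
    bucket.min(cap)
}

fn main() {
    assert_eq!(pad_to_bucket(100, 1024, 4096), 1024); // below floor: pad to floor
    assert_eq!(pad_to_bucket(1500, 1024, 4096), 2048); // next bucket up
    assert_eq!(pad_to_bucket(3000, 1024, 4096), 4096); // capped at cap
    assert_eq!(pad_to_bucket(5000, 1024, 4096), 5000); // above cap: unchanged
    println!("ok");
}
```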

View File

@ -0,0 +1,88 @@
use super::*;
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
fn write_temp_config(contents: &str) -> PathBuf {
let nonce = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time must be after unix epoch")
.as_nanos();
let path = std::env::temp_dir().join(format!("telemt-load-security-{nonce}.toml"));
fs::write(&path, contents).expect("temp config write must succeed");
path
}
fn remove_temp_config(path: &PathBuf) {
let _ = fs::remove_file(path);
}
#[test]
fn load_rejects_server_hello_delay_equal_to_handshake_timeout_budget() {
let path = write_temp_config(
r#"
[timeouts]
client_handshake = 1
[censorship]
server_hello_delay_max_ms = 1000
"#,
);
let err =
ProxyConfig::load(&path).expect_err("delay equal to handshake timeout must be rejected");
let msg = err.to_string();
assert!(
msg.contains(
"censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"
),
"error must explain delay<timeout invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_server_hello_delay_larger_than_handshake_timeout_budget() {
let path = write_temp_config(
r#"
[timeouts]
client_handshake = 1
[censorship]
server_hello_delay_max_ms = 1500
"#,
);
let err =
ProxyConfig::load(&path).expect_err("delay larger than handshake timeout must be rejected");
let msg = err.to_string();
assert!(
msg.contains(
"censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"
),
"error must explain delay<timeout invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_server_hello_delay_strictly_below_handshake_timeout_budget() {
let path = write_temp_config(
r#"
[timeouts]
client_handshake = 1
[censorship]
server_hello_delay_max_ms = 999
"#,
);
let cfg =
ProxyConfig::load(&path).expect("delay below handshake timeout budget must be accepted");
assert_eq!(cfg.timeouts.client_handshake, 1);
assert_eq!(cfg.censorship.server_hello_delay_max_ms, 999);
remove_temp_config(&path);
}
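These tests encode the budget invariant: the maximum ServerHello delay must fit strictly inside the client handshake timeout, compared in milliseconds. A one-line sketch of the comparison (hypothetical helper name, mirroring the strict `<` the tests expect):

```rust
/// True when the configured ServerHello delay fits strictly inside the
/// handshake budget. `client_handshake` is in seconds, the delay in ms,
/// hence the *1000. Hypothetical illustration, not the loader's code.
fn server_hello_delay_fits_budget(delay_max_ms: u64, client_handshake_secs: u64) -> bool {
    delay_max_ms < client_handshake_secs.saturating_mul(1000)
}

fn main() {
    assert!(!server_hello_delay_fits_budget(1000, 1)); // equal to budget: rejected
    assert!(!server_hello_delay_fits_budget(1500, 1)); // above budget: rejected
    assert!(server_hello_delay_fits_budget(999, 1)); // strictly below: accepted
    println!("ok");
}
```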

View File

@ -1047,8 +1047,7 @@ impl Default for GeneralConfig {
me_pool_drain_soft_evict_per_writer: default_me_pool_drain_soft_evict_per_writer(),
me_pool_drain_soft_evict_budget_per_core:
default_me_pool_drain_soft_evict_budget_per_core(),
me_pool_drain_soft_evict_cooldown_ms:
default_me_pool_drain_soft_evict_cooldown_ms(),
me_pool_drain_soft_evict_cooldown_ms: default_me_pool_drain_soft_evict_cooldown_ms(),
me_bind_stale_mode: MeBindStaleMode::default(),
me_bind_stale_ttl_secs: default_me_bind_stale_ttl_secs(),
me_pool_min_fresh_ratio: default_me_pool_min_fresh_ratio(),
@ -1228,6 +1227,13 @@ pub struct ServerConfig {
#[serde(default = "default_proxy_protocol_header_timeout_ms")]
pub proxy_protocol_header_timeout_ms: u64,
/// Trusted source CIDRs allowed to send incoming PROXY protocol headers.
///
/// When non-empty, connections from addresses outside this allowlist are
/// rejected before `src_addr` is applied.
#[serde(default)]
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
/// Port for the Prometheus-compatible metrics endpoint.
/// Enables metrics when set; binds on all interfaces (dual-stack) by default.
#[serde(default)]
@ -1270,6 +1276,7 @@ impl Default for ServerConfig {
listen_tcp: None,
proxy_protocol: false,
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
proxy_protocol_trusted_cidrs: Vec::new(),
metrics_port: None,
metrics_listen: None,
metrics_whitelist: default_metrics_whitelist(),
@ -1286,6 +1293,24 @@ pub struct TimeoutsConfig {
#[serde(default = "default_handshake_timeout")]
pub client_handshake: u64,
/// Enables soft/hard relay client idle policy for middle-relay sessions.
#[serde(default = "default_relay_idle_policy_v2_enabled")]
pub relay_idle_policy_v2_enabled: bool,
/// Soft idle threshold for middle-relay client uplink activity in seconds.
/// Hitting this threshold marks the session as idle-candidate, but does not close it.
#[serde(default = "default_relay_client_idle_soft_secs")]
pub relay_client_idle_soft_secs: u64,
/// Hard idle threshold for middle-relay client uplink activity in seconds.
/// Hitting this threshold closes the session.
#[serde(default = "default_relay_client_idle_hard_secs")]
pub relay_client_idle_hard_secs: u64,
/// Additional grace in seconds added to hard idle window after recent downstream activity.
#[serde(default = "default_relay_idle_grace_after_downstream_activity_secs")]
pub relay_idle_grace_after_downstream_activity_secs: u64,
#[serde(default = "default_connect_timeout")]
pub tg_connect: u64,
@ -1308,6 +1333,11 @@ impl Default for TimeoutsConfig {
fn default() -> Self {
Self {
client_handshake: default_handshake_timeout(),
relay_idle_policy_v2_enabled: default_relay_idle_policy_v2_enabled(),
relay_client_idle_soft_secs: default_relay_client_idle_soft_secs(),
relay_client_idle_hard_secs: default_relay_client_idle_hard_secs(),
relay_idle_grace_after_downstream_activity_secs:
default_relay_idle_grace_after_downstream_activity_secs(),
tg_connect: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack: default_ack_timeout(),
@ -1381,6 +1411,46 @@ pub struct AntiCensorshipConfig {
/// Allows the backend to see the real client IP.
#[serde(default)]
pub mask_proxy_protocol: u8,
/// Enable shape-channel hardening on mask backend path by padding
/// client->mask stream tail to configured buckets on stream end.
#[serde(default = "default_mask_shape_hardening")]
pub mask_shape_hardening: bool,
/// Opt-in aggressive shape hardening mode.
/// When enabled, masking may shape some backend-silent timeout paths and
/// enforces strictly positive above-cap blur when blur is enabled.
#[serde(default = "default_mask_shape_hardening_aggressive_mode")]
pub mask_shape_hardening_aggressive_mode: bool,
/// Minimum bucket size for mask shape hardening padding.
#[serde(default = "default_mask_shape_bucket_floor_bytes")]
pub mask_shape_bucket_floor_bytes: usize,
/// Maximum bucket size for mask shape hardening padding.
#[serde(default = "default_mask_shape_bucket_cap_bytes")]
pub mask_shape_bucket_cap_bytes: usize,
/// Add bounded random tail bytes even when total bytes already exceed
/// mask_shape_bucket_cap_bytes.
#[serde(default = "default_mask_shape_above_cap_blur")]
pub mask_shape_above_cap_blur: bool,
/// Maximum random bytes appended above cap when above-cap blur is enabled.
#[serde(default = "default_mask_shape_above_cap_blur_max_bytes")]
pub mask_shape_above_cap_blur_max_bytes: usize,
/// Enable outcome-time normalization envelope for masking fallback.
#[serde(default = "default_mask_timing_normalization_enabled")]
pub mask_timing_normalization_enabled: bool,
/// Lower bound (ms) for masking outcome timing envelope.
#[serde(default = "default_mask_timing_normalization_floor_ms")]
pub mask_timing_normalization_floor_ms: u64,
/// Upper bound (ms) for masking outcome timing envelope.
#[serde(default = "default_mask_timing_normalization_ceiling_ms")]
pub mask_timing_normalization_ceiling_ms: u64,
}
impl Default for AntiCensorshipConfig {
@ -1402,6 +1472,15 @@ impl Default for AntiCensorshipConfig {
tls_full_cert_ttl_secs: default_tls_full_cert_ttl_secs(),
alpn_enforce: default_alpn_enforce(),
mask_proxy_protocol: 0,
mask_shape_hardening: default_mask_shape_hardening(),
mask_shape_hardening_aggressive_mode: default_mask_shape_hardening_aggressive_mode(),
mask_shape_bucket_floor_bytes: default_mask_shape_bucket_floor_bytes(),
mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(),
mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(),
mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(),
mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(),
mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(),
mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(),
}
}
}

View File

@ -13,10 +13,13 @@
#![allow(dead_code)]
use aes::Aes256;
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
use zeroize::Zeroize;
use crate::error::{ProxyError, Result};
use aes::Aes256;
use ctr::{
Ctr128BE,
cipher::{KeyIvInit, StreamCipher},
};
use zeroize::Zeroize;
type Aes256Ctr = Ctr128BE<Aes256>;
@ -46,10 +49,16 @@ impl AesCtr {
/// Create from key and IV slices
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
return Err(ProxyError::InvalidKeyLength {
expected: 32,
got: key.len(),
});
}
if iv.len() != 16 {
return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() });
return Err(ProxyError::InvalidKeyLength {
expected: 16,
got: iv.len(),
});
}
let key: [u8; 32] = key.try_into().unwrap();
@ -108,10 +117,16 @@ impl AesCbc {
/// Create from slices
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
return Err(ProxyError::InvalidKeyLength {
expected: 32,
got: key.len(),
});
}
if iv.len() != 16 {
return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() });
return Err(ProxyError::InvalidKeyLength {
expected: 16,
got: iv.len(),
});
}
Ok(Self {
@ -150,9 +165,10 @@ impl AesCbc {
/// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
return Err(ProxyError::Crypto(format!(
"CBC data must be aligned to 16 bytes, got {}",
data.len()
)));
}
if data.is_empty() {
@ -181,9 +197,10 @@ impl AesCbc {
/// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
return Err(ProxyError::Crypto(format!(
"CBC data must be aligned to 16 bytes, got {}",
data.len()
)));
}
if data.is_empty() {
@ -210,9 +227,10 @@ impl AesCbc {
/// Encrypt data in-place
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
return Err(ProxyError::Crypto(format!(
"CBC data must be aligned to 16 bytes, got {}",
data.len()
)));
}
if data.is_empty() {
@ -243,9 +261,10 @@ impl AesCbc {
/// Decrypt data in-place
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto(
format!("CBC data must be aligned to 16 bytes, got {}", data.len())
));
return Err(ProxyError::Crypto(format!(
"CBC data must be aligned to 16 bytes, got {}",
data.len()
)));
}
if data.is_empty() {
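The doc comments above give the CBC recurrences: C[i] = E(P[i] XOR C[i-1]) with C[-1] = IV, and P[i] = D(C[i]) XOR C[i-1] for decryption. A toy sketch of the chaining using a one-byte "cipher" (add a key byte modulo 256) — purely illustrative; the module itself uses AES-256 over 16-byte blocks:

```rust
// Toy 1-byte "block cipher": add/subtract a key byte. Illustrative only.
fn enc(b: u8, k: u8) -> u8 {
    b.wrapping_add(k)
}
fn dec(b: u8, k: u8) -> u8 {
    b.wrapping_sub(k)
}

/// CBC chaining: C[i] = E(P[i] XOR C[i-1]), C[-1] = IV.
fn cbc_encrypt(plain: &[u8], key: u8, iv: u8) -> Vec<u8> {
    let mut prev = iv;
    plain
        .iter()
        .map(|&p| {
            let c = enc(p ^ prev, key);
            prev = c;
            c
        })
        .collect()
}

/// CBC decryption: P[i] = D(C[i]) XOR C[i-1].
fn cbc_decrypt(cipher: &[u8], key: u8, iv: u8) -> Vec<u8> {
    let mut prev = iv;
    cipher
        .iter()
        .map(|&c| {
            let p = dec(c, key) ^ prev;
            prev = c;
            p
        })
        .collect()
}

fn main() {
    let pt = b"hello cbc".to_vec();
    let ct = cbc_encrypt(&pt, 0x5a, 0x11);
    assert_eq!(cbc_decrypt(&ct, 0x5a, 0x11), pt);
    // Identical plaintext blocks encrypt differently thanks to the chaining.
    let ct2 = cbc_encrypt(&[7, 7, 7], 0x5a, 0x11);
    assert_ne!(ct2[0], ct2[1]);
    println!("ok");
}
```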

View File

@ -12,10 +12,10 @@
//! usages are intentional and protocol-mandated.
use hmac::{Hmac, Mac};
use sha2::Sha256;
use md5::Md5;
use sha1::Sha1;
use sha2::Digest;
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>;
@ -28,8 +28,7 @@ pub fn sha256(data: &[u8]) -> [u8; 32] {
/// SHA-256 HMAC
pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
let mut mac = HmacSha256::new_from_slice(key)
.expect("HMAC accepts any key length");
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
mac.update(data);
mac.finalize().into_bytes().into()
}
@ -124,17 +123,8 @@ pub fn derive_middleproxy_keys(
srv_ipv6: Option<&[u8; 16]>,
) -> ([u8; 32], [u8; 16]) {
let s = build_middleproxy_prekey(
nonce_srv,
nonce_clt,
clt_ts,
srv_ip,
clt_port,
purpose,
clt_ip,
srv_port,
secret,
clt_ipv6,
srv_ipv6,
nonce_srv, nonce_clt, clt_ts, srv_ip, clt_port, purpose, clt_ip, srv_port, secret,
clt_ipv6, srv_ipv6,
);
let md5_1 = md5(&s[1..]);
@ -164,17 +154,8 @@ mod tests {
let secret = vec![0x55u8; 128];
let prekey = build_middleproxy_prekey(
&nonce_srv,
&nonce_clt,
&clt_ts,
srv_ip,
&clt_port,
b"CLIENT",
clt_ip,
&srv_port,
&secret,
None,
None,
&nonce_srv, &nonce_clt, &clt_ts, srv_ip, &clt_port, b"CLIENT", clt_ip, &srv_port,
&secret, None, None,
);
let digest = sha256(&prekey);
assert_eq!(

View File

@ -4,7 +4,7 @@ pub mod aes;
pub mod hash;
pub mod random;
pub use aes::{AesCtr, AesCbc};
pub use aes::{AesCbc, AesCtr};
pub use hash::{
build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, sha256, sha256_hmac,
};

View File

@ -3,11 +3,11 @@
#![allow(deprecated)]
#![allow(dead_code)]
use rand::{Rng, RngCore, SeedableRng};
use rand::rngs::StdRng;
use parking_lot::Mutex;
use zeroize::Zeroize;
use crate::crypto::AesCtr;
use parking_lot::Mutex;
use rand::rngs::StdRng;
use rand::{Rng, RngCore, SeedableRng};
use zeroize::Zeroize;
/// Cryptographically secure PRNG with AES-CTR
pub struct SecureRandom {
@ -101,7 +101,7 @@ impl SecureRandom {
return 0;
}
let mut inner = self.inner.lock();
inner.rng.gen_range(0..max)
inner.rng.random_range(0..max)
}
/// Generate random bits
@ -141,7 +141,7 @@ impl SecureRandom {
pub fn shuffle<T>(&self, slice: &mut [T]) {
let mut inner = self.inner.lock();
for i in (1..slice.len()).rev() {
let j = inner.rng.gen_range(0..=i);
let j = inner.rng.random_range(0..=i);
slice.swap(i, j);
}
}

533
src/daemon/mod.rs Normal file
View File

@ -0,0 +1,533 @@
//! Unix daemon support for telemt.
//!
//! Provides classic Unix daemonization (double-fork), PID file management,
//! and privilege dropping for running telemt as a background service.
use std::fs::{self, File, OpenOptions};
use std::io::{self, Read, Write};
use std::os::unix::fs::OpenOptionsExt;
use std::path::{Path, PathBuf};
use nix::fcntl::{Flock, FlockArg};
use nix::unistd::{self, ForkResult, Gid, Pid, Uid, chdir, close, fork, getpid, setsid};
use tracing::{debug, info, warn};
/// Default PID file location.
pub const DEFAULT_PID_FILE: &str = "/var/run/telemt.pid";
/// Daemon configuration options parsed from CLI.
#[derive(Debug, Clone, Default)]
pub struct DaemonOptions {
/// Run as daemon (fork to background).
pub daemonize: bool,
/// Path to PID file.
pub pid_file: Option<PathBuf>,
/// User to run as after binding sockets.
pub user: Option<String>,
/// Group to run as after binding sockets.
pub group: Option<String>,
/// Working directory for the daemon.
pub working_dir: Option<PathBuf>,
/// Explicit foreground mode (for systemd Type=simple).
pub foreground: bool,
}
impl DaemonOptions {
/// Returns the effective PID file path.
pub fn pid_file_path(&self) -> &Path {
self.pid_file
.as_deref()
.unwrap_or(Path::new(DEFAULT_PID_FILE))
}
/// Returns true if we should actually daemonize.
/// Foreground flag takes precedence.
pub fn should_daemonize(&self) -> bool {
self.daemonize && !self.foreground
}
}
/// Error types for daemon operations.
#[derive(Debug, thiserror::Error)]
pub enum DaemonError {
#[error("fork failed: {0}")]
ForkFailed(#[source] nix::Error),
#[error("setsid failed: {0}")]
SetsidFailed(#[source] nix::Error),
#[error("chdir failed: {0}")]
ChdirFailed(#[source] nix::Error),
#[error("failed to open /dev/null: {0}")]
DevNullFailed(#[source] io::Error),
#[error("failed to redirect stdio: {0}")]
RedirectFailed(#[source] nix::Error),
#[error("PID file error: {0}")]
PidFile(String),
#[error("another instance is already running (pid {0})")]
AlreadyRunning(i32),
#[error("user '{0}' not found")]
UserNotFound(String),
#[error("group '{0}' not found")]
GroupNotFound(String),
#[error("failed to set uid/gid: {0}")]
PrivilegeDrop(#[source] nix::Error),
#[error("io error: {0}")]
Io(#[from] io::Error),
}
/// Result of a successful daemonize() call.
#[derive(Debug)]
pub enum DaemonizeResult {
/// We are the parent process and should exit.
Parent,
/// We are the daemon child process and should continue.
Child,
}
/// Performs classic Unix double-fork daemonization.
///
/// This detaches the process from the controlling terminal:
/// 1. First fork - parent exits, child continues
/// 2. setsid() - become session leader
/// 3. Second fork - ensure we can never acquire a controlling terminal
/// 4. chdir("/") - don't hold any directory open
/// 5. Redirect stdin/stdout/stderr to /dev/null
///
/// Returns `DaemonizeResult::Parent` in the original parent (which should exit),
/// or `DaemonizeResult::Child` in the final daemon child.
pub fn daemonize(working_dir: Option<&Path>) -> Result<DaemonizeResult, DaemonError> {
// First fork
match unsafe { fork() } {
Ok(ForkResult::Parent { .. }) => {
// Parent exits
return Ok(DaemonizeResult::Parent);
}
Ok(ForkResult::Child) => {
// Child continues
}
Err(e) => return Err(DaemonError::ForkFailed(e)),
}
// Create new session, become session leader
setsid().map_err(DaemonError::SetsidFailed)?;
// Second fork to ensure we can never acquire a controlling terminal
match unsafe { fork() } {
Ok(ForkResult::Parent { .. }) => {
// Intermediate parent exits
std::process::exit(0);
}
Ok(ForkResult::Child) => {
// Final daemon child continues
}
Err(e) => return Err(DaemonError::ForkFailed(e)),
}
// Change working directory
let target_dir = working_dir.unwrap_or(Path::new("/"));
chdir(target_dir).map_err(DaemonError::ChdirFailed)?;
// Redirect stdin, stdout, stderr to /dev/null
redirect_stdio_to_devnull()?;
Ok(DaemonizeResult::Child)
}
/// Redirects stdin, stdout, and stderr to /dev/null.
fn redirect_stdio_to_devnull() -> Result<(), DaemonError> {
let devnull = File::options()
.read(true)
.write(true)
.open("/dev/null")
.map_err(DaemonError::DevNullFailed)?;
let devnull_fd = std::os::unix::io::AsRawFd::as_raw_fd(&devnull);
// Use libc::dup2 directly for redirecting standard file descriptors
// nix 0.31's dup2 requires OwnedFd which doesn't work well with stdio fds
unsafe {
// Redirect stdin (fd 0)
if libc::dup2(devnull_fd, 0) < 0 {
return Err(DaemonError::RedirectFailed(nix::errno::Errno::last()));
}
// Redirect stdout (fd 1)
if libc::dup2(devnull_fd, 1) < 0 {
return Err(DaemonError::RedirectFailed(nix::errno::Errno::last()));
}
// Redirect stderr (fd 2)
if libc::dup2(devnull_fd, 2) < 0 {
return Err(DaemonError::RedirectFailed(nix::errno::Errno::last()));
}
}
// `devnull` is dropped here, which closes the original fd exactly once.
// Do not close it manually as well: the `File` would close the same fd a
// second time on drop, potentially hitting an unrelated, reused fd.
drop(devnull);
Ok(())
}
/// PID file manager with flock-based locking.
pub struct PidFile {
path: PathBuf,
file: Option<File>,
locked: bool,
}
impl PidFile {
/// Creates a new PID file manager for the given path.
pub fn new<P: AsRef<Path>>(path: P) -> Self {
Self {
path: path.as_ref().to_path_buf(),
file: None,
locked: false,
}
}
/// Checks if another instance is already running.
///
/// Returns the PID of the running instance if one exists.
pub fn check_running(&self) -> Result<Option<i32>, DaemonError> {
if !self.path.exists() {
return Ok(None);
}
// Try to read existing PID
let mut contents = String::new();
File::open(&self.path)
.and_then(|mut f| f.read_to_string(&mut contents))
.map_err(|e| DaemonError::PidFile(format!("cannot read {}: {}", self.path.display(), e)))?;
let pid: i32 = contents
.trim()
.parse()
.map_err(|_| DaemonError::PidFile(format!("invalid PID in {}", self.path.display())))?;
// Check if process is still running
if is_process_running(pid) {
Ok(Some(pid))
} else {
// Stale PID file
debug!(pid, path = %self.path.display(), "Removing stale PID file");
let _ = fs::remove_file(&self.path);
Ok(None)
}
}
/// Acquires the PID file lock and writes the current PID.
///
/// Fails if another instance is already running.
pub fn acquire(&mut self) -> Result<(), DaemonError> {
// Check for running instance first
if let Some(pid) = self.check_running()? {
return Err(DaemonError::AlreadyRunning(pid));
}
// Ensure parent directory exists
if let Some(parent) = self.path.parent() {
if !parent.exists() {
fs::create_dir_all(parent).map_err(|e| {
DaemonError::PidFile(format!(
"cannot create directory {}: {}",
parent.display(),
e
))
})?;
}
}
// Open/create PID file with exclusive lock
let file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o644)
.open(&self.path)
.map_err(|e| {
DaemonError::PidFile(format!("cannot open {}: {}", self.path.display(), e))
})?;
// Try to acquire exclusive lock (non-blocking)
let flock = Flock::lock(file, FlockArg::LockExclusiveNonblock).map_err(|(_, errno)| {
// Check if another instance grabbed the lock
if let Some(pid) = self.check_running().ok().flatten() {
DaemonError::AlreadyRunning(pid)
} else {
DaemonError::PidFile(format!("cannot lock {}: {}", self.path.display(), errno))
}
})?;
// Write our PID
let pid = getpid();
let mut file = flock.unlock().map_err(|(_, errno)| {
DaemonError::PidFile(format!("unlock failed: {}", errno))
})?;
writeln!(file, "{}", pid).map_err(|e| {
DaemonError::PidFile(format!("cannot write PID to {}: {}", self.path.display(), e))
})?;
// Re-acquire the lock to confirm exclusivity, then store the plain handle.
// Note: the Flock wrapper is not kept — duplicate starts are still caught
// by check_running() against the written PID.
let flock = Flock::lock(file, FlockArg::LockExclusiveNonblock).map_err(|(_, errno)| {
DaemonError::PidFile(format!("cannot re-lock {}: {}", self.path.display(), errno))
})?;
self.file = Some(flock.unlock().map_err(|(_, errno)| {
DaemonError::PidFile(format!("unlock for storage failed: {}", errno))
})?);
self.locked = true;
info!(pid = pid.as_raw(), path = %self.path.display(), "PID file created");
Ok(())
}
/// Releases the PID file lock and removes the file.
pub fn release(&mut self) -> Result<(), DaemonError> {
if let Some(file) = self.file.take() {
drop(file);
}
self.locked = false;
if self.path.exists() {
fs::remove_file(&self.path).map_err(|e| {
DaemonError::PidFile(format!("cannot remove {}: {}", self.path.display(), e))
})?;
debug!(path = %self.path.display(), "PID file removed");
}
Ok(())
}
/// Returns the path to this PID file.
#[allow(dead_code)]
pub fn path(&self) -> &Path {
&self.path
}
}
impl Drop for PidFile {
fn drop(&mut self) {
if self.locked {
if let Err(e) = self.release() {
warn!(error = %e, "Failed to clean up PID file on drop");
}
}
}
}
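The `Drop` impl above guarantees the PID file is cleaned up even on early return or panic-unwind. That release-on-drop contract can be exercised in a std-only sketch; `PidGuard` is a hypothetical illustration, not part of this module:

```rust
use std::path::PathBuf;

// Hypothetical std-only guard mirroring PidFile's release-on-drop
// contract: the file is removed when the guard leaves scope.
struct PidGuard {
    path: PathBuf,
}

impl PidGuard {
    fn create(path: PathBuf, pid: u32) -> std::io::Result<Self> {
        std::fs::write(&path, format!("{}\n", pid))?;
        Ok(Self { path })
    }
}

impl Drop for PidGuard {
    fn drop(&mut self) {
        // Best-effort removal, matching PidFile::release's ignore-on-drop style.
        let _ = std::fs::remove_file(&self.path);
    }
}

fn main() {
    let path = std::env::temp_dir().join("telemt_guard_sketch.pid");
    {
        let _guard = PidGuard::create(path.clone(), std::process::id()).unwrap();
        assert!(path.exists());
    }
    // Guard dropped at end of scope: file is gone.
    assert!(!path.exists());
    println!("ok");
}
```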
/// Checks if a process with the given PID is running.
fn is_process_running(pid: i32) -> bool {
// kill(pid, 0) checks if process exists without sending a signal
nix::sys::signal::kill(Pid::from_raw(pid), None).is_ok()
}
/// Drops privileges to the specified user and group.
///
/// This should be called after binding privileged ports but before
/// entering the main event loop.
pub fn drop_privileges(user: Option<&str>, group: Option<&str>) -> Result<(), DaemonError> {
// Look up group first (need to do this while still root)
let target_gid = if let Some(group_name) = group {
Some(lookup_group(group_name)?)
} else if let Some(user_name) = user {
// If no group specified but user is, use user's primary group
Some(lookup_user_primary_gid(user_name)?)
} else {
None
};
// Look up user
let target_uid = if let Some(user_name) = user {
Some(lookup_user(user_name)?)
} else {
None
};
// Drop privileges: set GID first, then UID
// (Setting UID first would prevent us from setting GID)
if let Some(gid) = target_gid {
unistd::setgid(gid).map_err(DaemonError::PrivilegeDrop)?;
// Also set supplementary groups to just this one
unistd::setgroups(&[gid]).map_err(DaemonError::PrivilegeDrop)?;
info!(gid = gid.as_raw(), "Dropped group privileges");
}
if let Some(uid) = target_uid {
unistd::setuid(uid).map_err(DaemonError::PrivilegeDrop)?;
info!(uid = uid.as_raw(), "Dropped user privileges");
}
Ok(())
}
/// Looks up a user by name and returns their UID.
fn lookup_user(name: &str) -> Result<Uid, DaemonError> {
// Use libc getpwnam
let c_name = std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?;
unsafe {
let pwd = libc::getpwnam(c_name.as_ptr());
if pwd.is_null() {
Err(DaemonError::UserNotFound(name.to_string()))
} else {
Ok(Uid::from_raw((*pwd).pw_uid))
}
}
}
/// Looks up a user's primary GID by username.
fn lookup_user_primary_gid(name: &str) -> Result<Gid, DaemonError> {
let c_name = std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?;
unsafe {
let pwd = libc::getpwnam(c_name.as_ptr());
if pwd.is_null() {
Err(DaemonError::UserNotFound(name.to_string()))
} else {
Ok(Gid::from_raw((*pwd).pw_gid))
}
}
}
/// Looks up a group by name and returns its GID.
fn lookup_group(name: &str) -> Result<Gid, DaemonError> {
let c_name = std::ffi::CString::new(name).map_err(|_| DaemonError::GroupNotFound(name.to_string()))?;
unsafe {
let grp = libc::getgrnam(c_name.as_ptr());
if grp.is_null() {
Err(DaemonError::GroupNotFound(name.to_string()))
} else {
Ok(Gid::from_raw((*grp).gr_gid))
}
}
}
/// Reads PID from a PID file.
#[allow(dead_code)]
pub fn read_pid_file<P: AsRef<Path>>(path: P) -> Result<i32, DaemonError> {
let path = path.as_ref();
let mut contents = String::new();
File::open(path)
.and_then(|mut f| f.read_to_string(&mut contents))
.map_err(|e| DaemonError::PidFile(format!("cannot read {}: {}", path.display(), e)))?;
contents
.trim()
.parse()
.map_err(|_| DaemonError::PidFile(format!("invalid PID in {}", path.display())))
}
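The `read_pid_file` contract is: read the file, trim whitespace, parse as `i32`, and reject garbage with a `PidFile` error. A std-only sketch of that contract (`read_pid_sketch` is a hypothetical name, with `String` standing in for `DaemonError`):

```rust
use std::fs;
use std::io::Write;

// Std-only sketch of the read-PID-file contract: trim, parse as i32,
// reject non-numeric contents.
fn read_pid_sketch(path: &str) -> Result<i32, String> {
    let contents =
        fs::read_to_string(path).map_err(|e| format!("cannot read {}: {}", path, e))?;
    contents
        .trim()
        .parse()
        .map_err(|_| format!("invalid PID in {}", path))
}

fn main() {
    let path = std::env::temp_dir().join("telemt_pid_sketch.pid");
    let path = path.to_str().unwrap().to_string();
    // Surrounding whitespace is tolerated, as in the real parser.
    fs::File::create(&path)
        .and_then(|mut f| writeln!(f, "  1234  "))
        .unwrap();
    assert_eq!(read_pid_sketch(&path), Ok(1234));
    fs::write(&path, "not-a-pid").unwrap();
    assert!(read_pid_sketch(&path).is_err());
    let _ = fs::remove_file(&path);
    println!("ok");
}
```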
/// Sends a signal to the process specified in a PID file.
#[allow(dead_code)]
pub fn signal_pid_file<P: AsRef<Path>>(
path: P,
signal: nix::sys::signal::Signal,
) -> Result<(), DaemonError> {
let pid = read_pid_file(&path)?;
if !is_process_running(pid) {
return Err(DaemonError::PidFile(format!(
"process {} from {} is not running",
pid,
path.as_ref().display()
)));
}
nix::sys::signal::kill(Pid::from_raw(pid), signal).map_err(|e| {
DaemonError::PidFile(format!("cannot signal process {}: {}", pid, e))
})?;
Ok(())
}
/// Returns the status of the daemon based on PID file.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DaemonStatus {
/// Daemon is running with the given PID.
Running(i32),
/// PID file exists but process is not running.
Stale(i32),
/// No PID file exists.
NotRunning,
}
/// Checks the daemon status from a PID file.
#[allow(dead_code)]
pub fn check_status<P: AsRef<Path>>(path: P) -> DaemonStatus {
let path = path.as_ref();
if !path.exists() {
return DaemonStatus::NotRunning;
}
match read_pid_file(path) {
Ok(pid) => {
if is_process_running(pid) {
DaemonStatus::Running(pid)
} else {
DaemonStatus::Stale(pid)
}
}
Err(_) => DaemonStatus::NotRunning,
}
}
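The `check_status` decision table (no readable file → `NotRunning`, file with live PID → `Running`, file with dead PID → `Stale`) can be sketched std-only by injecting the liveness probe, since the real code needs `kill(pid, 0)` via nix. The injected predicate is an illustration, not this module's API:

```rust
use std::path::Path;

#[derive(Debug, PartialEq)]
enum Status {
    Running(i32),
    Stale(i32),
    NotRunning,
}

// Liveness is injected so the sketch stays std-only; the real code
// probes with kill(pid, 0).
fn status_sketch<F: Fn(i32) -> bool>(path: &Path, is_alive: F) -> Status {
    let Ok(contents) = std::fs::read_to_string(path) else {
        return Status::NotRunning;
    };
    match contents.trim().parse::<i32>() {
        Ok(pid) if is_alive(pid) => Status::Running(pid),
        Ok(pid) => Status::Stale(pid),
        // Unparseable PID file maps to NotRunning, as in check_status.
        Err(_) => Status::NotRunning,
    }
}

fn main() {
    let p = std::env::temp_dir().join("telemt_status_sketch.pid");
    std::fs::write(&p, "999\n").unwrap();
    assert_eq!(status_sketch(&p, |_| true), Status::Running(999));
    assert_eq!(status_sketch(&p, |_| false), Status::Stale(999));
    let _ = std::fs::remove_file(&p);
    assert_eq!(status_sketch(&p, |_| true), Status::NotRunning);
    println!("ok");
}
```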
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_daemon_options_default() {
let opts = DaemonOptions::default();
assert!(!opts.daemonize);
assert!(!opts.should_daemonize());
assert_eq!(opts.pid_file_path(), Path::new(DEFAULT_PID_FILE));
}
#[test]
fn test_daemon_options_foreground_overrides() {
let opts = DaemonOptions {
daemonize: true,
foreground: true,
..Default::default()
};
assert!(!opts.should_daemonize());
}
#[test]
fn test_check_status_not_running() {
let path = "/tmp/telemt_test_nonexistent.pid";
assert_eq!(check_status(path), DaemonStatus::NotRunning);
}
#[test]
fn test_pid_file_basic() {
let path = "/tmp/telemt_test_pidfile.pid";
let _ = fs::remove_file(path);
let mut pf = PidFile::new(path);
assert!(pf.check_running().unwrap().is_none());
pf.acquire().unwrap();
assert!(Path::new(path).exists());
// Read it back
let pid = read_pid_file(path).unwrap();
assert_eq!(pid, std::process::id() as i32);
pf.release().unwrap();
assert!(!Path::new(path).exists());
}
}


@ -12,28 +12,15 @@ use thiserror::Error;
#[derive(Debug)]
pub enum StreamError {
/// Partial read: got fewer bytes than expected
PartialRead {
expected: usize,
got: usize,
},
PartialRead { expected: usize, got: usize },
/// Partial write: wrote fewer bytes than expected
PartialWrite {
expected: usize,
written: usize,
},
PartialWrite { expected: usize, written: usize },
/// Stream is in poisoned state and cannot be used
Poisoned {
reason: String,
},
Poisoned { reason: String },
/// Buffer overflow: attempted to buffer more than allowed
BufferOverflow {
limit: usize,
attempted: usize,
},
BufferOverflow { limit: usize, attempted: usize },
/// Invalid frame format
InvalidFrame {
details: String,
},
InvalidFrame { details: String },
/// Unexpected end of stream
UnexpectedEof,
/// Underlying I/O error
@ -47,13 +34,21 @@ impl fmt::Display for StreamError {
write!(f, "partial read: expected {} bytes, got {}", expected, got)
}
Self::PartialWrite { expected, written } => {
write!(f, "partial write: expected {} bytes, wrote {}", expected, written)
write!(
f,
"partial write: expected {} bytes, wrote {}",
expected, written
)
}
Self::Poisoned { reason } => {
write!(f, "stream poisoned: {}", reason)
}
Self::BufferOverflow { limit, attempted } => {
write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted)
write!(
f,
"buffer overflow: limit {}, attempted {}",
limit, attempted
)
}
Self::InvalidFrame { details } => {
write!(f, "invalid frame: {}", details)
@ -90,9 +85,7 @@ impl From<StreamError> for std::io::Error {
StreamError::UnexpectedEof => {
std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err)
}
StreamError::Poisoned { .. } => {
std::io::Error::other(err)
}
StreamError::Poisoned { .. } => std::io::Error::other(err),
StreamError::BufferOverflow { .. } => {
std::io::Error::new(std::io::ErrorKind::OutOfMemory, err)
}
@ -135,7 +128,10 @@ impl Recoverable for StreamError {
}
fn can_continue(&self) -> bool {
!matches!(self, Self::Poisoned { .. } | Self::UnexpectedEof | Self::BufferOverflow { .. })
!matches!(
self,
Self::Poisoned { .. } | Self::UnexpectedEof | Self::BufferOverflow { .. }
)
}
}
@ -165,7 +161,6 @@ impl Recoverable for std::io::Error {
#[derive(Error, Debug)]
pub enum ProxyError {
// ============= Crypto Errors =============
#[error("Crypto error: {0}")]
Crypto(String),
@ -173,12 +168,10 @@ pub enum ProxyError {
InvalidKeyLength { expected: usize, got: usize },
// ============= Stream Errors =============
#[error("Stream error: {0}")]
Stream(#[from] StreamError),
// ============= Protocol Errors =============
#[error("Invalid handshake: {0}")]
InvalidHandshake(String),
@ -210,7 +203,6 @@ pub enum ProxyError {
TgHandshakeTimeout,
// ============= Network Errors =============
#[error("Connection timeout to {addr}")]
ConnectionTimeout { addr: String },
@ -221,7 +213,6 @@ pub enum ProxyError {
Io(#[from] std::io::Error),
// ============= Proxy Protocol Errors =============
#[error("Invalid proxy protocol header")]
InvalidProxyProtocol,
@ -229,7 +220,6 @@ pub enum ProxyError {
Proxy(String),
// ============= Config Errors =============
#[error("Config error: {0}")]
Config(String),
@ -237,7 +227,6 @@ pub enum ProxyError {
InvalidSecret { user: String, reason: String },
// ============= User Errors =============
#[error("User {user} expired")]
UserExpired { user: String },
@ -254,7 +243,6 @@ pub enum ProxyError {
RateLimited,
// ============= General Errors =============
#[error("Internal error: {0}")]
Internal(String),
}
@ -311,7 +299,9 @@ impl<T, R, W> HandshakeResult<T, R, W> {
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> {
match self {
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer },
HandshakeResult::BadClient { reader, writer } => {
HandshakeResult::BadClient { reader, writer }
}
HandshakeResult::Error(e) => HandshakeResult::Error(e),
}
}
@ -341,18 +331,35 @@ mod tests {
#[test]
fn test_stream_error_display() {
let err = StreamError::PartialRead { expected: 100, got: 50 };
let err = StreamError::PartialRead {
expected: 100,
got: 50,
};
assert!(err.to_string().contains("100"));
assert!(err.to_string().contains("50"));
let err = StreamError::Poisoned { reason: "test".into() };
let err = StreamError::Poisoned {
reason: "test".into(),
};
assert!(err.to_string().contains("test"));
}
#[test]
fn test_stream_error_recoverable() {
assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable());
assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable());
assert!(
StreamError::PartialRead {
expected: 10,
got: 5
}
.is_recoverable()
);
assert!(
StreamError::PartialWrite {
expected: 10,
written: 5
}
.is_recoverable()
);
assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable());
assert!(!StreamError::UnexpectedEof.is_recoverable());
}
@ -361,7 +368,13 @@ mod tests {
fn test_stream_error_can_continue() {
assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue());
assert!(!StreamError::UnexpectedEof.can_continue());
assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue());
assert!(
StreamError::PartialRead {
expected: 10,
got: 5
}
.can_continue()
);
}
#[test]
@ -377,7 +390,10 @@ mod tests {
assert!(success.is_success());
assert!(!success.is_bad_client());
let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () };
let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient {
reader: (),
writer: (),
};
assert!(!bad.is_success());
assert!(bad.is_bad_client());
}
@ -404,7 +420,9 @@ mod tests {
#[test]
fn test_error_display() {
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
let err = ProxyError::ConnectionTimeout {
addr: "1.2.3.4:443".into(),
};
assert!(err.to_string().contains("1.2.3.4:443"));
let err = ProxyError::InvalidProxyProtocol;


@ -5,10 +5,11 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::sync::{Mutex as AsyncMutex, RwLock};
use crate::config::UserMaxUniqueIpsMode;
@ -21,6 +22,8 @@ pub struct UserIpTracker {
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
limit_window: Arc<RwLock<Duration>>,
last_compact_epoch_secs: Arc<AtomicU64>,
cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
cleanup_drain_lock: Arc<AsyncMutex<()>>,
}
impl UserIpTracker {
@ -33,6 +36,79 @@ impl UserIpTracker {
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
cleanup_queue: Arc::new(Mutex::new(Vec::new())),
cleanup_drain_lock: Arc::new(AsyncMutex::new(())),
}
}
pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
match self.cleanup_queue.lock() {
Ok(mut queue) => queue.push((user, ip)),
Err(poisoned) => {
let mut queue = poisoned.into_inner();
queue.push((user.clone(), ip));
self.cleanup_queue.clear_poison();
tracing::warn!(
"UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})",
user,
ip
);
}
}
}
#[cfg(test)]
pub(crate) fn cleanup_queue_len_for_tests(&self) -> usize {
self.cleanup_queue
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.len()
}
#[cfg(test)]
pub(crate) fn cleanup_queue_mutex_for_tests(&self) -> Arc<Mutex<Vec<(String, IpAddr)>>> {
Arc::clone(&self.cleanup_queue)
}
pub(crate) async fn drain_cleanup_queue(&self) {
// Serialize queue draining and active-IP mutation so check-and-add cannot
// observe stale active entries that are already queued for removal.
let _drain_guard = self.cleanup_drain_lock.lock().await;
let to_remove = {
match self.cleanup_queue.lock() {
Ok(mut queue) => {
if queue.is_empty() {
return;
}
std::mem::take(&mut *queue)
}
Err(poisoned) => {
let mut queue = poisoned.into_inner();
if queue.is_empty() {
self.cleanup_queue.clear_poison();
return;
}
let drained = std::mem::take(&mut *queue);
self.cleanup_queue.clear_poison();
drained
}
}
};
let mut active_ips = self.active_ips.write().await;
for (user, ip) in to_remove {
if let Some(user_ips) = active_ips.get_mut(&user) {
if let Some(count) = user_ips.get_mut(&ip) {
if *count > 1 {
*count -= 1;
} else {
user_ips.remove(&ip);
}
}
if user_ips.is_empty() {
active_ips.remove(&user);
}
}
}
}
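The per-entry cleanup applied in `drain_cleanup_queue` is a refcount decrement: decrement while the count is above one, otherwise drop the IP, and drop the user entry once it holds no IPs. A std-only sketch of that invariant over plain `HashMap`s (string keys are stand-ins for `IpAddr`):

```rust
use std::collections::HashMap;

// Std-only sketch of the refcount decrement applied per queued
// (user, ip) entry: decrement above one, else remove the IP,
// and remove the user once its map is empty.
fn release_ip(active: &mut HashMap<String, HashMap<String, u32>>, user: &str, ip: &str) {
    if let Some(ips) = active.get_mut(user) {
        if let Some(count) = ips.get_mut(ip) {
            if *count > 1 {
                *count -= 1;
            } else {
                ips.remove(ip);
            }
        }
        if ips.is_empty() {
            active.remove(user);
        }
    }
}

fn main() {
    let mut active = HashMap::new();
    active.insert(
        "alice".to_string(),
        HashMap::from([("10.0.0.1".to_string(), 2u32)]),
    );
    release_ip(&mut active, "alice", "10.0.0.1");
    assert_eq!(active["alice"]["10.0.0.1"], 1);
    release_ip(&mut active, "alice", "10.0.0.1");
    // Last reference released: the user entry is compacted away.
    assert!(!active.contains_key("alice"));
    println!("ok");
}
```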
@ -65,7 +141,8 @@ impl UserIpTracker {
let mut active_ips = self.active_ips.write().await;
let mut recent_ips = self.recent_ips.write().await;
let mut users = Vec::<String>::with_capacity(active_ips.len().saturating_add(recent_ips.len()));
let mut users =
Vec::<String>::with_capacity(active_ips.len().saturating_add(recent_ips.len()));
users.extend(active_ips.keys().cloned());
for user in recent_ips.keys() {
if !active_ips.contains_key(user) {
@ -74,8 +151,14 @@ impl UserIpTracker {
}
for user in users {
let active_empty = active_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true);
let recent_empty = recent_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true);
let active_empty = active_ips
.get(&user)
.map(|ips| ips.is_empty())
.unwrap_or(true);
let recent_empty = recent_ips
.get(&user)
.map(|ips| ips.is_empty())
.unwrap_or(true);
if active_empty && recent_empty {
active_ips.remove(&user);
recent_ips.remove(&user);
@ -118,6 +201,7 @@ impl UserIpTracker {
}
pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> {
self.drain_cleanup_queue().await;
self.maybe_compact_empty_users().await;
let default_max_ips = *self.default_max_ips.read().await;
let limit = {
@ -194,6 +278,7 @@ impl UserIpTracker {
}
pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap<String, usize> {
self.drain_cleanup_queue().await;
let window = *self.limit_window.read().await;
let now = Instant::now();
let recent_ips = self.recent_ips.read().await;
@ -214,6 +299,7 @@ impl UserIpTracker {
}
pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
let mut out = HashMap::with_capacity(users.len());
for user in users {
@ -228,6 +314,7 @@ impl UserIpTracker {
}
pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
self.drain_cleanup_queue().await;
let window = *self.limit_window.read().await;
let now = Instant::now();
let recent_ips = self.recent_ips.read().await;
@ -250,11 +337,13 @@ impl UserIpTracker {
}
pub async fn get_active_ip_count(&self, username: &str) -> usize {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
active_ips.get(username).map(|ips| ips.len()).unwrap_or(0)
}
pub async fn get_active_ips(&self, username: &str) -> Vec<IpAddr> {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
active_ips
.get(username)
@ -263,6 +352,7 @@ impl UserIpTracker {
}
pub async fn get_stats(&self) -> Vec<(String, usize, usize)> {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
let max_ips = self.max_ips.read().await;
let default_max_ips = *self.default_max_ips.read().await;
@ -301,6 +391,7 @@ impl UserIpTracker {
}
pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool {
self.drain_cleanup_queue().await;
let active_ips = self.active_ips.read().await;
active_ips
.get(username)

src/logging.rs (new file)

@ -0,0 +1,291 @@
//! Logging configuration for telemt.
//!
//! Supports multiple log destinations:
//! - stderr (default, works with systemd journald)
//! - syslog (Unix only, for traditional init systems)
//! - file (with optional rotation)
#![allow(dead_code)] // Infrastructure module - used via CLI flags
use std::path::Path;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, fmt, reload};
/// Log destination configuration.
#[derive(Debug, Clone, Default)]
pub enum LogDestination {
/// Log to stderr (default, captured by systemd journald).
#[default]
Stderr,
/// Log to syslog (Unix only).
#[cfg(unix)]
Syslog,
/// Log to a file with optional rotation.
File {
path: String,
/// Rotate daily if true.
rotate_daily: bool,
},
}
/// Logging options parsed from CLI/config.
#[derive(Debug, Clone, Default)]
pub struct LoggingOptions {
/// Where to send logs.
pub destination: LogDestination,
/// Disable ANSI colors.
pub disable_colors: bool,
}
/// Guard that must be held to keep file logging active.
/// When dropped, flushes and closes log files.
pub struct LoggingGuard {
_guard: Option<tracing_appender::non_blocking::WorkerGuard>,
}
impl LoggingGuard {
fn new(guard: Option<tracing_appender::non_blocking::WorkerGuard>) -> Self {
Self { _guard: guard }
}
/// Creates a no-op guard for stderr/syslog logging.
pub fn noop() -> Self {
Self { _guard: None }
}
}
/// Initialize the tracing subscriber with the specified options.
///
/// Returns a reload handle for dynamic log level changes and a guard
/// that must be kept alive for file logging.
pub fn init_logging(
opts: &LoggingOptions,
initial_filter: &str,
) -> (reload::Handle<EnvFilter, impl tracing::Subscriber + Send + Sync>, LoggingGuard) {
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new(initial_filter));
match &opts.destination {
LogDestination::Stderr => {
let fmt_layer = fmt::Layer::default()
.with_ansi(!opts.disable_colors)
.with_target(true);
tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.init();
(filter_handle, LoggingGuard::noop())
}
#[cfg(unix)]
LogDestination::Syslog => {
// Use a custom fmt layer that writes to syslog
let fmt_layer = fmt::Layer::default()
.with_ansi(false)
.with_target(true)
.with_writer(SyslogWriter::new);
tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.init();
(filter_handle, LoggingGuard::noop())
}
LogDestination::File { path, rotate_daily } => {
let (non_blocking, guard) = if *rotate_daily {
// Extract directory and filename prefix
let path = Path::new(path);
let dir = path.parent().unwrap_or(Path::new("/var/log"));
let prefix = path.file_name()
.and_then(|s| s.to_str())
.unwrap_or("telemt");
let file_appender = tracing_appender::rolling::daily(dir, prefix);
tracing_appender::non_blocking(file_appender)
} else {
let file = std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(path)
.expect("Failed to open log file");
tracing_appender::non_blocking(file)
};
let fmt_layer = fmt::Layer::default()
.with_ansi(false)
.with_target(true)
.with_writer(non_blocking);
tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.init();
(filter_handle, LoggingGuard::new(Some(guard)))
}
}
}
/// Syslog writer for tracing.
#[cfg(unix)]
struct SyslogWriter {
_private: (),
}
#[cfg(unix)]
impl SyslogWriter {
fn new() -> Self {
// Open syslog connection on first use
static INIT: std::sync::Once = std::sync::Once::new();
INIT.call_once(|| {
unsafe {
// Open syslog with ident "telemt", LOG_PID, LOG_DAEMON facility
let ident = b"telemt\0".as_ptr() as *const libc::c_char;
libc::openlog(ident, libc::LOG_PID | libc::LOG_NDELAY, libc::LOG_DAEMON);
}
});
Self { _private: () }
}
}
#[cfg(unix)]
impl std::io::Write for SyslogWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
// Convert to C string, stripping newlines
let msg = String::from_utf8_lossy(buf);
let msg = msg.trim_end();
if msg.is_empty() {
return Ok(buf.len());
}
// Determine priority based on log level in the message
let priority = if msg.contains(" ERROR ") || msg.contains(" error ") {
libc::LOG_ERR
} else if msg.contains(" WARN ") || msg.contains(" warn ") {
libc::LOG_WARNING
} else if msg.contains(" INFO ") || msg.contains(" info ") {
libc::LOG_INFO
} else if msg.contains(" DEBUG ") || msg.contains(" debug ") {
libc::LOG_DEBUG
} else {
libc::LOG_INFO
};
// Write to syslog
let c_msg = std::ffi::CString::new(msg.as_bytes())
.unwrap_or_else(|_| std::ffi::CString::new("(invalid utf8)").unwrap());
unsafe {
libc::syslog(priority, b"%s\0".as_ptr() as *const libc::c_char, c_msg.as_ptr());
}
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
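`SyslogWriter` maps tracing's already-formatted line to a syslog priority by scanning for the level token. A std-only sketch of that heuristic, returning label strings in place of the `libc::LOG_*` constants (the function name is illustrative):

```rust
// Std-only sketch of the level-token heuristic used by SyslogWriter:
// scan the formatted line for a level marker and map it to a syslog
// priority label, defaulting to info.
fn priority_sketch(line: &str) -> &'static str {
    if line.contains(" ERROR ") || line.contains(" error ") {
        "err"
    } else if line.contains(" WARN ") || line.contains(" warn ") {
        "warning"
    } else if line.contains(" DEBUG ") || line.contains(" debug ") {
        "debug"
    } else {
        // INFO and anything unrecognized both land on LOG_INFO.
        "info"
    }
}

fn main() {
    assert_eq!(priority_sketch("2026-01-01 ERROR telemt: boom"), "err");
    assert_eq!(priority_sketch("2026-01-01 WARN telemt: odd"), "warning");
    assert_eq!(priority_sketch("plain message"), "info");
    println!("ok");
}
```

Matching on padded tokens (`" ERROR "`) keeps words like "errors" in the message body from triggering a higher priority.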
#[cfg(unix)]
impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for SyslogWriter {
type Writer = SyslogWriter;
fn make_writer(&'a self) -> Self::Writer {
SyslogWriter::new()
}
}
/// Parse log destination from CLI arguments.
pub fn parse_log_destination(args: &[String]) -> LogDestination {
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
#[cfg(unix)]
"--syslog" => {
return LogDestination::Syslog;
}
"--log-file" => {
i += 1;
if i < args.len() {
return LogDestination::File {
path: args[i].clone(),
rotate_daily: false,
};
}
}
s if s.starts_with("--log-file=") => {
return LogDestination::File {
path: s.trim_start_matches("--log-file=").to_string(),
rotate_daily: false,
};
}
"--log-file-daily" => {
i += 1;
if i < args.len() {
return LogDestination::File {
path: args[i].clone(),
rotate_daily: true,
};
}
}
s if s.starts_with("--log-file-daily=") => {
return LogDestination::File {
path: s.trim_start_matches("--log-file-daily=").to_string(),
rotate_daily: true,
};
}
_ => {}
}
i += 1;
}
LogDestination::Stderr
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_log_destination_default() {
let args: Vec<String> = vec![];
assert!(matches!(parse_log_destination(&args), LogDestination::Stderr));
}
#[test]
fn test_parse_log_destination_file() {
let args = vec!["--log-file".to_string(), "/var/log/telemt.log".to_string()];
match parse_log_destination(&args) {
LogDestination::File { path, rotate_daily } => {
assert_eq!(path, "/var/log/telemt.log");
assert!(!rotate_daily);
}
_ => panic!("Expected File destination"),
}
}
#[test]
fn test_parse_log_destination_file_daily() {
let args = vec!["--log-file-daily=/var/log/telemt".to_string()];
match parse_log_destination(&args) {
LogDestination::File { path, rotate_daily } => {
assert_eq!(path, "/var/log/telemt");
assert!(rotate_daily);
}
_ => panic!("Expected File destination"),
}
}
#[cfg(unix)]
#[test]
fn test_parse_log_destination_syslog() {
let args = vec!["--syslog".to_string()];
assert!(matches!(parse_log_destination(&args), LogDestination::Syslog));
}
}


@ -1,3 +1,5 @@
#![allow(clippy::too_many_arguments)]
use std::sync::Arc;
use std::time::Instant;
@ -11,10 +13,10 @@ use crate::startup::{
COMPONENT_DC_CONNECTIVITY_PING, COMPONENT_ME_CONNECTIVITY_PING, COMPONENT_RUNTIME_READY,
StartupTracker,
};
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::{
MePingFamily, MePingSample, MePool, format_me_route, format_sample_line, run_me_ping,
};
use crate::transport::UpstreamManager;
pub(crate) async fn run_startup_connectivity(
config: &Arc<ProxyConfig>,
@ -47,11 +49,15 @@ pub(crate) async fn run_startup_connectivity(
let v4_ok = me_results.iter().any(|r| {
matches!(r.family, MePingFamily::V4)
&& r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some())
&& r.samples
.iter()
.any(|s| s.error.is_none() && s.handshake_ms.is_some())
});
let v6_ok = me_results.iter().any(|r| {
matches!(r.family, MePingFamily::V6)
&& r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some())
&& r.samples
.iter()
.any(|s| s.error.is_none() && s.handshake_ms.is_some())
});
info!("================= Telegram ME Connectivity =================");
@ -131,8 +137,14 @@ pub(crate) async fn run_startup_connectivity(
.await;
for upstream_result in &ping_results {
let v6_works = upstream_result.v6_results.iter().any(|r| r.rtt_ms.is_some());
let v4_works = upstream_result.v4_results.iter().any(|r| r.rtt_ms.is_some());
let v6_works = upstream_result
.v6_results
.iter()
.any(|r| r.rtt_ms.is_some());
let v4_works = upstream_result
.v4_results
.iter()
.any(|r| r.rtt_ms.is_some());
if upstream_result.both_available {
if prefer_ipv6 {


@ -1,16 +1,41 @@
use std::time::Duration;
#![allow(clippy::items_after_test_module)]
use std::path::PathBuf;
use std::time::Duration;
use tokio::sync::watch;
use tracing::{debug, error, info, warn};
use crate::cli;
use crate::config::ProxyConfig;
use crate::logging::LogDestination;
use crate::transport::middle_proxy::{
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache,
};
pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
pub(crate) fn resolve_runtime_config_path(
config_path_cli: &str,
startup_cwd: &std::path::Path,
) -> PathBuf {
let raw = PathBuf::from(config_path_cli);
let absolute = if raw.is_absolute() {
raw
} else {
startup_cwd.join(raw)
};
absolute.canonicalize().unwrap_or(absolute)
}
/// Parsed CLI arguments.
pub(crate) struct CliArgs {
pub config_path: String,
pub data_path: Option<PathBuf>,
pub silent: bool,
pub log_level: Option<String>,
pub log_destination: LogDestination,
}
pub(crate) fn parse_cli() -> CliArgs {
let mut config_path = "config.toml".to_string();
let mut data_path: Option<PathBuf> = None;
let mut silent = false;
@ -18,6 +43,9 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
let args: Vec<String> = std::env::args().skip(1).collect();
// Parse log destination
let log_destination = crate::logging::parse_log_destination(&args);
// Check for --init first (handled before tokio)
if let Some(init_opts) = cli::parse_init_args(&args) {
if let Err(e) = cli::run_init(init_opts) {
@ -40,7 +68,9 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
}
}
s if s.starts_with("--data-path=") => {
data_path = Some(PathBuf::from(s.trim_start_matches("--data-path=").to_string()));
data_path = Some(PathBuf::from(
s.trim_start_matches("--data-path=").to_string(),
));
}
"--silent" | "-s" => {
silent = true;
@ -55,14 +85,91 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
log_level = Some(s.trim_start_matches("--log-level=").to_string());
}
"--help" | "-h" => {
eprintln!("Usage: telemt [config.toml] [OPTIONS]");
print_help();
std::process::exit(0);
}
"--version" | "-V" => {
println!("telemt {}", env!("CARGO_PKG_VERSION"));
std::process::exit(0);
}
// Skip daemon-related flags (already parsed)
"--daemon" | "-d" | "--foreground" | "-f" => {}
s if s.starts_with("--pid-file") => {
if !s.contains('=') {
i += 1; // skip value
}
}
s if s.starts_with("--run-as-user") => {
if !s.contains('=') {
i += 1;
}
}
s if s.starts_with("--run-as-group") => {
if !s.contains('=') {
i += 1;
}
}
s if s.starts_with("--working-dir") => {
if !s.contains('=') {
i += 1;
}
}
s if !s.starts_with('-') => {
config_path = s.to_string();
}
other => {
eprintln!("Unknown option: {}", other);
}
}
i += 1;
}
CliArgs {
config_path,
data_path,
silent,
log_level,
log_destination,
}
}
fn print_help() {
eprintln!("Usage: telemt [COMMAND] [OPTIONS] [config.toml]");
eprintln!();
eprintln!("Commands:");
eprintln!(" run Run in foreground (default if no command given)");
#[cfg(unix)]
{
eprintln!(" start Start as background daemon");
eprintln!(" stop Stop a running daemon");
eprintln!(" reload Reload configuration (send SIGHUP)");
eprintln!(" status Check if daemon is running");
}
eprintln!();
eprintln!("Options:");
eprintln!(" --data-path <DIR> Set data directory (absolute path; overrides config value)");
eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help");
eprintln!(" --version, -V Show version");
eprintln!();
eprintln!("Logging options:");
eprintln!(" --log-file <PATH> Log to file (default: stderr)");
eprintln!(" --log-file-daily <PATH> Log to file with daily rotation");
#[cfg(unix)]
eprintln!(" --syslog Log to syslog (Unix only)");
eprintln!();
#[cfg(unix)]
{
eprintln!("Daemon options (Unix only):");
eprintln!(" --daemon, -d Fork to background (daemonize)");
eprintln!(" --foreground, -f Explicit foreground mode (for systemd)");
eprintln!(" --pid-file <PATH> PID file path (default: /var/run/telemt.pid)");
eprintln!(" --run-as-user <USER> Drop privileges to this user after binding");
eprintln!(" --run-as-group <GROUP> Drop privileges to this group after binding");
eprintln!(" --working-dir <DIR> Working directory for daemon mode");
eprintln!();
}
eprintln!("Setup (fire-and-forget):");
eprintln!(
" --init Generate config, install systemd service, start"
@ -77,28 +184,65 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
eprintln!(" --user <NAME> Username (default: user)");
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
eprintln!(" --no-start Don't start the service after install");
std::process::exit(0);
#[cfg(unix)]
{
eprintln!();
eprintln!("Examples:");
eprintln!(" telemt config.toml Run in foreground");
eprintln!(" telemt start config.toml Start as daemon");
eprintln!(" telemt start --pid-file /tmp/t.pid Start with custom PID file");
eprintln!(" telemt stop Stop daemon");
eprintln!(" telemt reload Reload configuration");
eprintln!(" telemt status Check daemon status");
}
"--version" | "-V" => {
println!("telemt {}", env!("CARGO_PKG_VERSION"));
std::process::exit(0);
}
s if !s.starts_with('-') => {
config_path = s.to_string();
}
other => {
eprintln!("Unknown option: {}", other);
}
}
i += 1;
}
(config_path, data_path, silent, log_level)
#[cfg(test)]
mod tests {
use super::resolve_runtime_config_path;
#[test]
fn resolve_runtime_config_path_anchors_relative_to_startup_cwd() {
let nonce = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos();
let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_{nonce}"));
std::fs::create_dir_all(&startup_cwd).unwrap();
let target = startup_cwd.join("config.toml");
std::fs::write(&target, " ").unwrap();
let resolved = resolve_runtime_config_path("config.toml", &startup_cwd);
assert_eq!(resolved, target.canonicalize().unwrap());
let _ = std::fs::remove_file(&target);
let _ = std::fs::remove_dir(&startup_cwd);
}
#[test]
fn resolve_runtime_config_path_keeps_absolute_for_missing_file() {
let nonce = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos();
let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_missing_{nonce}"));
std::fs::create_dir_all(&startup_cwd).unwrap();
let resolved = resolve_runtime_config_path("missing.toml", &startup_cwd);
assert_eq!(resolved, startup_cwd.join("missing.toml"));
let _ = std::fs::remove_dir(&startup_cwd);
}
}
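The two tests above pin down the resolution rule: a relative path is anchored to the startup cwd and canonicalized when the file exists, while a missing file keeps the joined (non-canonicalized) path. A minimal sketch consistent with those tests (hypothetical — the real helper may differ) could look like:

```rust
use std::path::{Path, PathBuf};

// Hypothetical sketch inferred from the tests above: anchor relative inputs
// to the startup cwd, then canonicalize only when the file actually exists.
fn resolve_runtime_config_path(raw: &str, startup_cwd: &Path) -> PathBuf {
    let candidate = if Path::new(raw).is_absolute() {
        PathBuf::from(raw)
    } else {
        startup_cwd.join(raw)
    };
    // canonicalize() fails for missing files; fall back to the joined path,
    // matching the "missing file keeps the joined path" expectation.
    candidate.canonicalize().unwrap_or(candidate)
}
```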
pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
info!(target: "telemt::links", "--- Proxy Links ({}) ---", host);
for user_name in config.general.links.show.resolve_users(&config.access.users) {
for user_name in config
.general
.links
.show
.resolve_users(&config.access.users)
{
if let Some(secret) = config.access.users.get(user_name) {
info!(target: "telemt::links", "User: {}", user_name);
if config.general.modes.classic {
@@ -239,7 +383,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot(
return Some(cfg);
}
warn!(snapshot = label, url, "Startup proxy-config is empty; trying disk cache");
warn!(
snapshot = label,
url, "Startup proxy-config is empty; trying disk cache"
);
if let Some(path) = cache_path {
match load_proxy_config_cache(path).await {
Ok(cached) if !cached.map.is_empty() => {
@@ -254,8 +401,7 @@
Ok(_) => {
warn!(
snapshot = label,
path,
"Startup proxy-config cache is empty; ignoring cache file"
path, "Startup proxy-config cache is empty; ignoring cache file"
);
}
Err(cache_err) => {
@@ -299,8 +445,7 @@
Ok(_) => {
warn!(
snapshot = label,
path,
"Startup proxy-config cache is empty; ignoring cache file"
path, "Startup proxy-config cache is empty; ignoring cache file"
);
}
Err(cache_err) => {


@@ -12,17 +12,15 @@ use tracing::{debug, error, info, warn};
use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker;
use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController};
use crate::proxy::ClientHandler;
use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController};
use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker};
use crate::stats::beobachten::BeobachtenStore;
use crate::stats::{ReplayChecker, Stats};
use crate::stream::BufferPool;
use crate::tls_front::TlsFrontCache;
use crate::transport::middle_proxy::MePool;
use crate::transport::{
ListenOptions, UpstreamManager, create_listener, find_listener_processes,
};
use crate::transport::{ListenOptions, UpstreamManager, create_listener, find_listener_processes};
use super::helpers::{is_expected_handshake_eof, print_proxy_links};
@@ -81,8 +79,9 @@ pub(crate) async fn bind_listeners(
Ok(socket) => {
let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr);
let listener_proxy_protocol =
listener_conf.proxy_protocol.unwrap_or(config.server.proxy_protocol);
let listener_proxy_protocol = listener_conf
.proxy_protocol
.unwrap_or(config.server.proxy_protocol);
let public_host = if let Some(ref announce) = listener_conf.announce {
announce.clone()
@@ -100,8 +99,14 @@ pub(crate) async fn bind_listeners(
listener_conf.ip.to_string()
};
if config.general.links.public_host.is_none() && !config.general.links.show.is_empty() {
let link_port = config.general.links.public_port.unwrap_or(config.server.port);
if config.general.links.public_host.is_none()
&& !config.general.links.show.is_empty()
{
let link_port = config
.general
.links
.public_port
.unwrap_or(config.server.port);
print_proxy_links(&public_host, link_port, config);
}
@@ -145,12 +150,14 @@ pub(crate) async fn bind_listeners(
let (host, port) = if let Some(ref h) = config.general.links.public_host {
(
h.clone(),
config.general.links.public_port.unwrap_or(config.server.port),
config
.general
.links
.public_port
.unwrap_or(config.server.port),
)
} else {
let ip = detected_ip_v4
.or(detected_ip_v6)
.map(|ip| ip.to_string());
let ip = detected_ip_v4.or(detected_ip_v6).map(|ip| ip.to_string());
if ip.is_none() {
warn!(
"show_link is configured but public IP could not be detected. Set public_host in config."
@@ -158,7 +165,11 @@
}
(
ip.unwrap_or_else(|| "UNKNOWN".to_string()),
config.general.links.public_port.unwrap_or(config.server.port),
config
.general
.links
.public_port
.unwrap_or(config.server.port),
)
};
@@ -178,13 +189,19 @@
use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(mode);
if let Err(e) = std::fs::set_permissions(unix_path, perms) {
error!("Failed to set unix socket permissions to {}: {}", perm_str, e);
error!(
"Failed to set unix socket permissions to {}: {}",
perm_str, e
);
} else {
info!("Listening on unix:{} (mode {})", unix_path, perm_str);
}
}
Err(e) => {
warn!("Invalid listen_unix_sock_perm '{}': {}. Ignoring.", perm_str, e);
warn!(
"Invalid listen_unix_sock_perm '{}': {}. Ignoring.",
perm_str, e
);
info!("Listening on unix:{}", unix_path);
}
}
@@ -218,10 +235,8 @@ pub(crate) async fn bind_listeners(
drop(stream);
continue;
}
let accept_permit_timeout_ms = config_rx_unix
.borrow()
.server
.accept_permit_timeout_ms;
let accept_permit_timeout_ms =
config_rx_unix.borrow().server.accept_permit_timeout_ms;
let permit = if accept_permit_timeout_ms == 0 {
match max_connections_unix.clone().acquire_owned().await {
Ok(permit) => permit,
@@ -361,10 +376,8 @@ pub(crate) fn spawn_tcp_accept_loops(
drop(stream);
continue;
}
let accept_permit_timeout_ms = config_rx
.borrow()
.server
.accept_permit_timeout_ms;
let accept_permit_timeout_ms =
config_rx.borrow().server.accept_permit_timeout_ms;
let permit = if accept_permit_timeout_ms == 0 {
match max_connections_tcp.clone().acquire_owned().await {
Ok(permit) => permit,


@@ -1,3 +1,5 @@
#![allow(clippy::too_many_arguments)]
use std::sync::Arc;
use std::time::Duration;
@@ -12,8 +14,8 @@ use crate::startup::{
COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH, StartupMeStatus, StartupTracker,
};
use crate::stats::Stats;
use crate::transport::middle_proxy::MePool;
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::MePool;
use super::helpers::load_startup_proxy_config_snapshot;
@@ -229,8 +231,12 @@ pub(crate) async fn initialize_me_pool(
config.general.me_adaptive_floor_recover_grace_secs,
config.general.me_adaptive_floor_writers_per_core_total,
config.general.me_adaptive_floor_cpu_cores_override,
config.general.me_adaptive_floor_max_extra_writers_single_per_core,
config.general.me_adaptive_floor_max_extra_writers_multi_per_core,
config
.general
.me_adaptive_floor_max_extra_writers_single_per_core,
config
.general
.me_adaptive_floor_max_extra_writers_multi_per_core,
config.general.me_adaptive_floor_max_active_writers_per_core,
config.general.me_adaptive_floor_max_warm_writers_per_core,
config.general.me_adaptive_floor_max_active_writers_global,
@@ -268,8 +274,6 @@ pub(crate) async fn initialize_me_pool(
config.general.me_warn_rate_limit_ms,
config.general.me_route_no_writer_mode,
config.general.me_route_no_writer_wait_ms,
config.general.me_route_hybrid_max_wait_ms,
config.general.me_route_blocking_send_timeout_ms,
config.general.me_route_inline_recovery_attempts,
config.general.me_route_inline_recovery_wait_ms,
);
@@ -475,7 +479,9 @@ pub(crate) async fn initialize_me_pool(
})
.await;
match res {
Ok(()) => warn!("me_health_monitor exited unexpectedly, restarting"),
Ok(()) => warn!(
"me_health_monitor exited unexpectedly, restarting"
),
Err(e) => {
error!(error = %e, "me_health_monitor panicked, restarting in 1s");
tokio::time::sleep(Duration::from_secs(1)).await;
@@ -492,7 +498,9 @@ pub(crate) async fn initialize_me_pool(
})
.await;
match res {
Ok(()) => warn!("me_drain_timeout_enforcer exited unexpectedly, restarting"),
Ok(()) => warn!(
"me_drain_timeout_enforcer exited unexpectedly, restarting"
),
Err(e) => {
error!(error = %e, "me_drain_timeout_enforcer panicked, restarting in 1s");
tokio::time::sleep(Duration::from_secs(1)).await;
@@ -509,7 +517,9 @@ pub(crate) async fn initialize_me_pool(
})
.await;
match res {
Ok(()) => warn!("me_zombie_writer_watchdog exited unexpectedly, restarting"),
Ok(()) => warn!(
"me_zombie_writer_watchdog exited unexpectedly, restarting"
),
Err(e) => {
error!(error = %e, "me_zombie_writer_watchdog panicked, restarting in 1s");
tokio::time::sleep(Duration::from_secs(1)).await;


@@ -11,9 +11,9 @@
// - admission: conditional-cast gate and route mode switching.
// - listeners: TCP/Unix listener bind and accept-loop orchestration.
// - shutdown: graceful shutdown sequence and uptime logging.
mod helpers;
mod admission;
mod connectivity;
mod helpers;
mod listeners;
mod me_startup;
mod runtime_tasks;
@@ -33,22 +33,70 @@ use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker;
use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe};
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::startup::{
COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD, COMPONENT_ME_POOL_CONSTRUCT,
COMPONENT_ME_POOL_INIT_STAGE1, COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6,
COMPONENT_ME_SECRET_FETCH, COMPONENT_NETWORK_PROBE, COMPONENT_TRACING_INIT, StartupMeStatus,
StartupTracker,
};
use crate::stats::beobachten::BeobachtenStore;
use crate::stats::telemetry::TelemetryPolicy;
use crate::stats::{ReplayChecker, Stats};
use crate::startup::{
COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD,
COMPONENT_ME_POOL_CONSTRUCT, COMPONENT_ME_POOL_INIT_STAGE1,
COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH,
COMPONENT_NETWORK_PROBE, COMPONENT_TRACING_INIT, StartupMeStatus, StartupTracker,
};
use crate::stream::BufferPool;
use crate::transport::middle_proxy::MePool;
use crate::transport::UpstreamManager;
use helpers::parse_cli;
use crate::transport::middle_proxy::MePool;
use helpers::{parse_cli, resolve_runtime_config_path};
#[cfg(unix)]
use crate::daemon::{DaemonOptions, PidFile, drop_privileges};
/// Runs the full telemt runtime startup pipeline and blocks until shutdown.
///
/// On Unix, daemon options should be handled before calling this function
/// (daemonization must happen before tokio runtime starts).
#[cfg(unix)]
pub async fn run_with_daemon(
daemon_opts: DaemonOptions,
) -> std::result::Result<(), Box<dyn std::error::Error>> {
run_inner(daemon_opts).await
}
/// Runs the full telemt runtime startup pipeline and blocks until shutdown.
///
/// This is the main entry point for non-daemon mode or when called as a library.
#[allow(dead_code)]
pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
#[cfg(unix)]
{
// Parse CLI to get daemon options even in simple run() path
let args: Vec<String> = std::env::args().skip(1).collect();
let daemon_opts = crate::cli::parse_daemon_args(&args);
run_inner(daemon_opts).await
}
#[cfg(not(unix))]
{
run_inner().await
}
}
#[cfg(unix)]
async fn run_inner(
daemon_opts: DaemonOptions,
) -> std::result::Result<(), Box<dyn std::error::Error>> {
// Acquire PID file if daemonizing or if explicitly requested
// Keep it alive until shutdown (underscore prefix = intentionally kept for RAII cleanup)
let _pid_file = if daemon_opts.daemonize || daemon_opts.pid_file.is_some() {
let mut pf = PidFile::new(daemon_opts.pid_file_path());
if let Err(e) = pf.acquire() {
eprintln!("[telemt] {}", e);
std::process::exit(1);
}
Some(pf)
} else {
None
};
let process_started_at = Instant::now();
let process_started_at_epoch_secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
@@ -56,20 +104,39 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
.as_secs();
let startup_tracker = Arc::new(StartupTracker::new(process_started_at_epoch_secs));
startup_tracker
.start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string()))
.start_component(
COMPONENT_CONFIG_LOAD,
Some("load and validate config".to_string()),
)
.await;
let (config_path, data_path, cli_silent, cli_log_level) = parse_cli();
let cli_args = parse_cli();
let config_path_cli = cli_args.config_path;
let data_path = cli_args.data_path;
let cli_silent = cli_args.silent;
let cli_log_level = cli_args.log_level;
let log_destination = cli_args.log_destination;
let startup_cwd = match std::env::current_dir() {
Ok(cwd) => cwd,
Err(e) => {
eprintln!("[telemt] Can't read current_dir: {}", e);
std::process::exit(1);
}
};
let config_path = resolve_runtime_config_path(&config_path_cli, &startup_cwd);
let mut config = match ProxyConfig::load(&config_path) {
Ok(c) => c,
Err(e) => {
if std::path::Path::new(&config_path).exists() {
if config_path.exists() {
eprintln!("[telemt] Error: {}", e);
std::process::exit(1);
} else {
let default = ProxyConfig::default();
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
eprintln!("[telemt] Created default config at {}", config_path);
eprintln!(
"[telemt] Created default config at {}",
config_path.display()
);
default
}
}
@@ -86,24 +153,38 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
if let Some(ref data_path) = config.general.data_path {
if !data_path.is_absolute() {
eprintln!("[telemt] data_path must be absolute: {}", data_path.display());
eprintln!(
"[telemt] data_path must be absolute: {}",
data_path.display()
);
std::process::exit(1);
}
if data_path.exists() {
if !data_path.is_dir() {
eprintln!("[telemt] data_path exists but is not a directory: {}", data_path.display());
eprintln!(
"[telemt] data_path exists but is not a directory: {}",
data_path.display()
);
std::process::exit(1);
}
} else {
if let Err(e) = std::fs::create_dir_all(data_path) {
eprintln!("[telemt] Can't create data_path {}: {}", data_path.display(), e);
eprintln!(
"[telemt] Can't create data_path {}: {}",
data_path.display(),
e
);
std::process::exit(1);
}
}
if let Err(e) = std::env::set_current_dir(data_path) {
eprintln!("[telemt] Can't use data_path {}: {}", data_path.display(), e);
eprintln!(
"[telemt] Can't use data_path {}: {}",
data_path.display(),
e
);
std::process::exit(1);
}
}
@@ -127,22 +208,54 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
startup_tracker
.start_component(COMPONENT_TRACING_INIT, Some("initialize tracing subscriber".to_string()))
.start_component(
COMPONENT_TRACING_INIT,
Some("initialize tracing subscriber".to_string()),
)
.await;
// Configure color output based on config
// Initialize logging based on destination
let _logging_guard: Option<crate::logging::LoggingGuard>;
match log_destination {
crate::logging::LogDestination::Stderr => {
// Default: log to stderr (works with systemd journald)
let fmt_layer = if config.general.disable_colors {
fmt::Layer::default().with_ansi(false)
} else {
fmt::Layer::default().with_ansi(true)
};
tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.init();
_logging_guard = None;
}
#[cfg(unix)]
crate::logging::LogDestination::Syslog => {
// Syslog: for OpenRC/FreeBSD
let logging_opts = crate::logging::LoggingOptions {
destination: log_destination,
disable_colors: true,
};
let (_, guard) = crate::logging::init_logging(&logging_opts, "info");
_logging_guard = Some(guard);
}
crate::logging::LogDestination::File { .. } => {
// File logging with optional rotation
let logging_opts = crate::logging::LoggingOptions {
destination: log_destination,
disable_colors: true,
};
let (_, guard) = crate::logging::init_logging(&logging_opts, "info");
_logging_guard = Some(guard);
}
}
startup_tracker
.complete_component(COMPONENT_TRACING_INIT, Some("tracing initialized".to_string()))
.complete_component(
COMPONENT_TRACING_INIT,
Some("tracing initialized".to_string()),
)
.await;
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
@@ -208,7 +321,8 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
config.access.user_max_unique_ips_window_secs,
)
.await;
if config.access.user_max_unique_ips_global_each > 0 || !config.access.user_max_unique_ips.is_empty()
if config.access.user_max_unique_ips_global_each > 0
|| !config.access.user_max_unique_ips.is_empty()
{
info!(
global_each_limit = config.access.user_max_unique_ips_global_each,
@@ -235,7 +349,10 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
let route_runtime = Arc::new(RouteRuntimeController::new(initial_route_mode));
let api_me_pool = Arc::new(RwLock::new(None::<Arc<MePool>>));
startup_tracker
.start_component(COMPONENT_API_BOOTSTRAP, Some("spawn API listener task".to_string()))
.start_component(
COMPONENT_API_BOOTSTRAP,
Some("spawn API listener task".to_string()),
)
.await;
if config.server.api.enabled {
@@ -258,7 +375,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
let route_runtime_api = route_runtime.clone();
let config_rx_api = api_config_rx.clone();
let admission_rx_api = admission_rx.clone();
let config_path_api = std::path::PathBuf::from(&config_path);
let config_path_api = config_path.clone();
let startup_tracker_api = startup_tracker.clone();
let detected_ips_rx_api = detected_ips_rx.clone();
tokio::spawn(async move {
@@ -318,7 +435,10 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
.await;
startup_tracker
.start_component(COMPONENT_NETWORK_PROBE, Some("probe network capabilities".to_string()))
.start_component(
COMPONENT_NETWORK_PROBE,
Some("probe network capabilities".to_string()),
)
.await;
let probe = run_probe(
&config.network,
@@ -331,11 +451,8 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
probe.detected_ipv4.map(IpAddr::V4),
probe.detected_ipv6.map(IpAddr::V6),
));
let decision = decide_network_capabilities(
&config.network,
&probe,
config.general.middle_proxy_nat_ip,
);
let decision =
decide_network_capabilities(&config.network, &probe, config.general.middle_proxy_nat_ip);
log_probe_result(&probe, &decision);
startup_tracker
.complete_component(
@@ -438,24 +555,16 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
// If ME failed to initialize, force direct-only mode.
if me_pool.is_some() {
startup_tracker
.set_transport_mode("middle_proxy")
.await;
startup_tracker
.set_degraded(false)
.await;
startup_tracker.set_transport_mode("middle_proxy").await;
startup_tracker.set_degraded(false).await;
info!("Transport: Middle-End Proxy - all DC-over-RPC");
} else {
let _ = use_middle_proxy;
use_middle_proxy = false;
// Make runtime config reflect direct-only mode for handlers.
config.general.use_middle_proxy = false;
startup_tracker
.set_transport_mode("direct")
.await;
startup_tracker
.set_degraded(true)
.await;
startup_tracker.set_transport_mode("direct").await;
startup_tracker.set_degraded(true).await;
if me2dc_fallback {
startup_tracker
.set_me_status(StartupMeStatus::Failed, "fallback_to_direct")
@@ -555,6 +664,17 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
std::process::exit(1);
}
// Drop privileges after binding sockets (which may require root for port < 1024)
if daemon_opts.user.is_some() || daemon_opts.group.is_some() {
if let Err(e) = drop_privileges(
daemon_opts.user.as_deref(),
daemon_opts.group.as_deref(),
) {
error!(error = %e, "Failed to drop privileges");
std::process::exit(1);
}
}
runtime_tasks::apply_runtime_log_filter(
has_rust_log,
&effective_log_level,
@@ -575,6 +695,9 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
runtime_tasks::mark_runtime_ready(&startup_tracker).await;
// Spawn signal handlers for SIGUSR1/SIGUSR2 (non-shutdown signals)
shutdown::spawn_signal_handlers(stats.clone(), process_started_at);
listeners::spawn_tcp_accept_loops(
listeners,
config_rx.clone(),
@ -592,7 +715,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
max_connections.clone(),
);
shutdown::wait_for_shutdown(process_started_at, me_pool).await;
shutdown::wait_for_shutdown(process_started_at, me_pool, stats).await;
Ok(())
}


@@ -1,24 +1,27 @@
use std::net::IpAddr;
use std::path::PathBuf;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::{mpsc, watch};
use tracing::{debug, warn};
use tracing_subscriber::reload;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::reload;
use crate::config::{LogLevel, ProxyConfig};
use crate::config::hot_reload::spawn_config_watcher;
use crate::config::{LogLevel, ProxyConfig};
use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker;
use crate::metrics;
use crate::network::probe::NetworkProbe;
use crate::startup::{COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY, StartupTracker};
use crate::startup::{
COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY,
StartupTracker,
};
use crate::stats::beobachten::BeobachtenStore;
use crate::stats::telemetry::TelemetryPolicy;
use crate::stats::{ReplayChecker, Stats};
use crate::transport::middle_proxy::{MePool, MeReinitTrigger};
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::{MePool, MeReinitTrigger};
use super::helpers::write_beobachten_snapshot;
@@ -32,7 +35,7 @@ pub(crate) struct RuntimeWatches {
#[allow(clippy::too_many_arguments)]
pub(crate) async fn spawn_runtime_tasks(
config: &Arc<ProxyConfig>,
config_path: &str,
config_path: &Path,
probe: &NetworkProbe,
prefer_ipv6: bool,
decision_ipv4_dc: bool,
@@ -79,11 +82,9 @@ pub(crate) async fn spawn_runtime_tasks(
Some("spawn config hot-reload watcher".to_string()),
)
.await;
let (config_rx, log_level_rx): (
watch::Receiver<Arc<ProxyConfig>>,
watch::Receiver<LogLevel>,
) = spawn_config_watcher(
PathBuf::from(config_path),
let (config_rx, log_level_rx): (watch::Receiver<Arc<ProxyConfig>>, watch::Receiver<LogLevel>) =
spawn_config_watcher(
config_path.to_path_buf(),
config.clone(),
detected_ip_v4,
detected_ip_v6,
@@ -114,7 +115,8 @@ pub(crate) async fn spawn_runtime_tasks(
break;
}
let cfg = config_rx_policy.borrow_and_update().clone();
stats_policy.apply_telemetry_policy(TelemetryPolicy::from_config(&cfg.general.telemetry));
stats_policy
.apply_telemetry_policy(TelemetryPolicy::from_config(&cfg.general.telemetry));
if let Some(pool) = &me_pool_for_policy {
pool.update_runtime_transport_policy(
cfg.general.me_socks_kdf_policy,
@ -130,7 +132,11 @@ pub(crate) async fn spawn_runtime_tasks(
let ip_tracker_policy = ip_tracker.clone();
let mut config_rx_ip_limits = config_rx.clone();
tokio::spawn(async move {
let mut prev_limits = config_rx_ip_limits.borrow().access.user_max_unique_ips.clone();
let mut prev_limits = config_rx_ip_limits
.borrow()
.access
.user_max_unique_ips
.clone();
let mut prev_global_each = config_rx_ip_limits
.borrow()
.access
@@ -183,7 +189,9 @@ pub(crate) async fn spawn_runtime_tasks(
let sleep_secs = cfg.general.beobachten_flush_secs.max(1);
if cfg.general.beobachten {
let ttl = std::time::Duration::from_secs(cfg.general.beobachten_minutes.saturating_mul(60));
let ttl = std::time::Duration::from_secs(
cfg.general.beobachten_minutes.saturating_mul(60),
);
let path = cfg.general.beobachten_file.clone();
let snapshot = beobachten_writer.snapshot_text(ttl);
if let Err(e) = write_beobachten_snapshot(&path, &snapshot).await {
@@ -227,7 +235,10 @@ pub(crate) async fn spawn_runtime_tasks(
let config_rx_clone_rot = config_rx.clone();
let reinit_tx_rotation = reinit_tx.clone();
tokio::spawn(async move {
crate::transport::middle_proxy::me_rotation_task(config_rx_clone_rot, reinit_tx_rotation)
crate::transport::middle_proxy::me_rotation_task(
config_rx_clone_rot,
reinit_tx_rotation,
)
.await;
});
}


@@ -1,23 +1,100 @@
//! Shutdown and signal handling for telemt.
//!
//! Handles graceful shutdown on various signals:
//! - SIGINT (Ctrl+C) / SIGTERM: Graceful shutdown
//! - SIGQUIT: Graceful shutdown with stats dump
//! - SIGUSR1: Reserved for log rotation (logs an acknowledgment)
//! - SIGUSR2: Dump runtime status to log
//!
//! SIGHUP is handled separately in config/hot_reload.rs for config reload.
use std::sync::Arc;
use std::time::{Duration, Instant};
#[cfg(unix)]
use tokio::signal::unix::{SignalKind, signal};
#[cfg(not(unix))]
use tokio::signal;
use tracing::{error, info, warn};
use tracing::{info, warn};
use crate::stats::Stats;
use crate::transport::middle_proxy::MePool;
use super::helpers::{format_uptime, unit_label};
pub(crate) async fn wait_for_shutdown(process_started_at: Instant, me_pool: Option<Arc<MePool>>) {
match signal::ctrl_c().await {
Ok(()) => {
/// Signal that triggered shutdown.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownSignal {
/// SIGINT (Ctrl+C)
Interrupt,
/// SIGTERM
Terminate,
/// SIGQUIT (with stats dump)
Quit,
}
impl std::fmt::Display for ShutdownSignal {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ShutdownSignal::Interrupt => write!(f, "SIGINT"),
ShutdownSignal::Terminate => write!(f, "SIGTERM"),
ShutdownSignal::Quit => write!(f, "SIGQUIT"),
}
}
}
/// Waits for a shutdown signal and performs graceful shutdown.
pub(crate) async fn wait_for_shutdown(
process_started_at: Instant,
me_pool: Option<Arc<MePool>>,
stats: Arc<Stats>,
) {
let signal = wait_for_shutdown_signal().await;
perform_shutdown(signal, process_started_at, me_pool, &stats).await;
}
/// Waits for any shutdown signal (SIGINT, SIGTERM, SIGQUIT).
#[cfg(unix)]
async fn wait_for_shutdown_signal() -> ShutdownSignal {
let mut sigint = signal(SignalKind::interrupt()).expect("Failed to register SIGINT handler");
let mut sigterm = signal(SignalKind::terminate()).expect("Failed to register SIGTERM handler");
let mut sigquit = signal(SignalKind::quit()).expect("Failed to register SIGQUIT handler");
tokio::select! {
_ = sigint.recv() => ShutdownSignal::Interrupt,
_ = sigterm.recv() => ShutdownSignal::Terminate,
_ = sigquit.recv() => ShutdownSignal::Quit,
}
}
#[cfg(not(unix))]
async fn wait_for_shutdown_signal() -> ShutdownSignal {
signal::ctrl_c().await.expect("Failed to listen for Ctrl+C");
ShutdownSignal::Interrupt
}
/// Performs graceful shutdown sequence.
async fn perform_shutdown(
signal: ShutdownSignal,
process_started_at: Instant,
me_pool: Option<Arc<MePool>>,
stats: &Stats,
) {
let shutdown_started_at = Instant::now();
info!(signal = %signal, "Received shutdown signal");
// Dump stats if SIGQUIT
if signal == ShutdownSignal::Quit {
dump_stats(stats, process_started_at);
}
info!("Shutting down...");
let uptime_secs = process_started_at.elapsed().as_secs();
info!("Uptime: {}", format_uptime(uptime_secs));
// Graceful ME pool shutdown
if let Some(pool) = &me_pool {
match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all())
.await
match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all()).await
{
Ok(total) => {
info!(
@@ -30,6 +107,7 @@ pub(crate) async fn wait_for_shutdown(process_started_at: Instant, me_pool: Opti
}
}
}
let shutdown_secs = shutdown_started_at.elapsed().as_secs();
info!(
"Shutdown completed successfully in {} {}.",
@@ -37,6 +115,97 @@ pub(crate) async fn wait_for_shutdown(process_started_at: Instant, me_pool: Opti
unit_label(shutdown_secs, "second", "seconds")
);
}
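The shutdown log lines above lean on two small helpers, `format_uptime` and `unit_label`, imported from `helpers`. Their bodies are not shown in this diff; a plausible pure-std sketch (hypothetical — the real versions in helpers.rs may format differently) is:

```rust
// Hypothetical sketch of unit_label: pick singular vs plural form by count.
fn unit_label<'a>(n: u64, singular: &'a str, plural: &'a str) -> &'a str {
    if n == 1 { singular } else { plural }
}

// Hypothetical sketch of format_uptime: render seconds as d/h/m/s,
// omitting leading zero units.
fn format_uptime(secs: u64) -> String {
    let (d, h) = (secs / 86_400, (secs % 86_400) / 3_600);
    let (m, s) = ((secs % 3_600) / 60, secs % 60);
    if d > 0 {
        format!("{d}d {h}h {m}m {s}s")
    } else if h > 0 {
        format!("{h}h {m}m {s}s")
    } else if m > 0 {
        format!("{m}m {s}s")
    } else {
        format!("{s}s")
    }
}
```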
Err(e) => error!("Signal error: {}", e),
/// Dumps runtime statistics to the log.
fn dump_stats(stats: &Stats, process_started_at: Instant) {
let uptime_secs = process_started_at.elapsed().as_secs();
info!("=== Runtime Statistics Dump ===");
info!("Uptime: {}", format_uptime(uptime_secs));
// Connection stats
info!(
"Connections: total={}, current={} (direct={}, me={}), bad={}",
stats.get_connects_all(),
stats.get_current_connections_total(),
stats.get_current_connections_direct(),
stats.get_current_connections_me(),
stats.get_connects_bad(),
);
// ME pool stats
info!(
"ME keepalive: sent={}, pong={}, failed={}, timeout={}",
stats.get_me_keepalive_sent(),
stats.get_me_keepalive_pong(),
stats.get_me_keepalive_failed(),
stats.get_me_keepalive_timeout(),
);
// Relay stats
info!(
"Relay idle: soft_mark={}, hard_close={}, pressure_evict={}",
stats.get_relay_idle_soft_mark_total(),
stats.get_relay_idle_hard_close_total(),
stats.get_relay_pressure_evict_total(),
);
info!("=== End Statistics Dump ===");
}
/// Spawns a background task to handle operational signals (SIGUSR1, SIGUSR2).
///
/// These signals don't trigger shutdown but perform specific actions:
/// - SIGUSR1: Log rotation acknowledgment (for external log rotation tools)
/// - SIGUSR2: Dump runtime status to log
#[cfg(unix)]
pub(crate) fn spawn_signal_handlers(
stats: Arc<Stats>,
process_started_at: Instant,
) {
tokio::spawn(async move {
let mut sigusr1 = signal(SignalKind::user_defined1())
.expect("Failed to register SIGUSR1 handler");
let mut sigusr2 = signal(SignalKind::user_defined2())
.expect("Failed to register SIGUSR2 handler");
loop {
tokio::select! {
_ = sigusr1.recv() => {
handle_sigusr1();
}
_ = sigusr2.recv() => {
handle_sigusr2(&stats, process_started_at);
}
}
}
});
}
/// No-op on non-Unix platforms.
#[cfg(not(unix))]
pub(crate) fn spawn_signal_handlers(
_stats: Arc<Stats>,
_process_started_at: Instant,
) {
// No SIGUSR1/SIGUSR2 on non-Unix
}
/// Handles SIGUSR1 - log rotation signal.
///
/// This signal is typically sent by logrotate or similar tools after
/// rotating log files. Since tracing-subscriber doesn't natively support
/// reopening files, we just acknowledge the signal. If file logging is
/// added in the future, this would reopen log file handles.
#[cfg(unix)]
fn handle_sigusr1() {
info!("SIGUSR1 received - log rotation acknowledged");
// Future: If using file-based logging, reopen file handles here
}
/// Handles SIGUSR2 - dump runtime status.
#[cfg(unix)]
fn handle_sigusr2(stats: &Stats, process_started_at: Instant) {
info!("SIGUSR2 received - dumping runtime status");
dump_stats(stats, process_started_at);
}


@@ -1,7 +1,7 @@
use std::sync::Arc;
use std::time::Duration;
use rand::Rng;
use rand::RngExt;
use tracing::warn;
use crate::config::ProxyConfig;


@@ -4,9 +4,20 @@ mod api;
mod cli;
mod config;
mod crypto;
#[cfg(unix)]
mod daemon;
mod error;
mod ip_tracker;
mod logging;
mod service;
#[cfg(test)]
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
mod ip_tracker_hotpath_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"]
mod ip_tracker_encapsulation_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_regression_tests.rs"]
mod ip_tracker_regression_tests;
mod maestro;
mod metrics;
@@ -20,7 +31,49 @@ mod tls_front;
mod transport;
mod util;
#[tokio::main]
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
maestro::run().await
fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let args: Vec<String> = std::env::args().skip(1).collect();
let cmd = cli::parse_command(&args);
// Handle subcommands that don't need the server (stop, reload, status, init)
if let Some(exit_code) = cli::execute_subcommand(&cmd) {
std::process::exit(exit_code);
}
// On Unix, handle daemonization before starting tokio runtime
#[cfg(unix)]
{
let daemon_opts = cmd.daemon_opts;
// Daemonize if requested (must happen before tokio runtime starts)
if daemon_opts.should_daemonize() {
match daemon::daemonize(daemon_opts.working_dir.as_deref()) {
Ok(daemon::DaemonizeResult::Parent) => {
// Parent process exits successfully
std::process::exit(0);
}
Ok(daemon::DaemonizeResult::Child) => {
// Continue as daemon child
}
Err(e) => {
eprintln!("[telemt] Daemonization failed: {}", e);
std::process::exit(1);
}
}
}
// Now start tokio runtime and run the server
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(maestro::run_with_daemon(daemon_opts))
}
#[cfg(not(unix))]
{
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(maestro::run())
}
}

File diff suppressed because it is too large


@ -26,9 +26,7 @@ fn parse_ip_spec(ip_spec: &str) -> Result<IpAddr> {
}
let ip = ip_spec.parse::<IpAddr>().map_err(|_| {
ProxyError::Config(format!(
"network.dns_overrides IP is invalid: '{ip_spec}'"
))
ProxyError::Config(format!("network.dns_overrides IP is invalid: '{ip_spec}'"))
})?;
if matches!(ip, IpAddr::V6(_)) {
return Err(ProxyError::Config(format!(
@ -103,9 +101,9 @@ pub fn validate_entries(entries: &[String]) -> Result<()> {
/// Replace runtime DNS overrides with a new validated snapshot.
pub fn install_entries(entries: &[String]) -> Result<()> {
let parsed = parse_entries(entries)?;
let mut guard = overrides_store()
.write()
.map_err(|_| ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string()))?;
let mut guard = overrides_store().write().map_err(|_| {
ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string())
})?;
*guard = parsed;
Ok(())
}


@ -1,4 +1,5 @@
#![allow(dead_code)]
#![allow(clippy::items_after_test_module)]
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket};
@ -10,7 +11,9 @@ use tracing::{debug, info, warn};
use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType};
use crate::error::Result;
use crate::network::stun::{stun_probe_family_with_bind, DualStunResult, IpFamily, StunProbeResult};
use crate::network::stun::{
DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind,
};
use crate::transport::UpstreamManager;
#[derive(Debug, Clone, Default)]
@ -78,12 +81,7 @@ pub async fn run_probe(
warn!("STUN probe is enabled but network.stun_servers is empty");
DualStunResult::default()
} else {
probe_stun_servers_parallel(
&servers,
stun_nat_probe_concurrency.max(1),
None,
None,
)
probe_stun_servers_parallel(&servers, stun_nat_probe_concurrency.max(1), None, None)
.await
}
} else if nat_probe {
@ -99,7 +97,8 @@ pub async fn run_probe(
let UpstreamType::Direct {
interface,
bind_addresses,
} = &upstream.upstream_type else {
} = &upstream.upstream_type
else {
continue;
};
if let Some(addrs) = bind_addresses.as_ref().filter(|v| !v.is_empty()) {
@ -199,12 +198,11 @@ pub async fn run_probe(
if nat_probe
&& probe.reflected_ipv4.is_none()
&& probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false)
&& let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await
{
if let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await {
probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0));
info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback");
}
}
probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) {
(Some(det), Some(reflected)) => det != reflected.ip(),
@ -217,12 +215,20 @@ pub async fn run_probe(
probe.ipv4_usable = config.ipv4
&& probe.detected_ipv4.is_some()
&& (!probe.ipv4_is_bogon || probe.reflected_ipv4.map(|r| !is_bogon(r.ip())).unwrap_or(false));
&& (!probe.ipv4_is_bogon
|| probe
.reflected_ipv4
.map(|r| !is_bogon(r.ip()))
.unwrap_or(false));
let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some());
probe.ipv6_usable = ipv6_enabled
&& probe.detected_ipv6.is_some()
&& (!probe.ipv6_is_bogon || probe.reflected_ipv6.map(|r| !is_bogon(r.ip())).unwrap_or(false));
&& (!probe.ipv6_is_bogon
|| probe
.reflected_ipv6
.map(|r| !is_bogon(r.ip()))
.unwrap_or(false));
Ok(probe)
}
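The probe above flags NAT when the locally detected address differs from the address a STUN server reflected back, and leaves the question open when either side is missing. A self-contained sketch of that comparison (function name hypothetical):

```rust
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

// NAT is inferred when the locally bound address differs from the address
// a STUN server saw; with incomplete information we decline to decide.
fn nat_detected(detected: Option<IpAddr>, reflected: Option<SocketAddr>) -> Option<bool> {
    match (detected, reflected) {
        (Some(det), Some(refl)) => Some(det != refl.ip()),
        _ => None, // not enough information to decide
    }
}

fn main() {
    let private = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10));
    let public = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 7)), 0);
    assert_eq!(nat_detected(Some(private), Some(public)), Some(true));
    let same = SocketAddr::new(private, 0);
    assert_eq!(nat_detected(Some(private), Some(same)), Some(false));
    assert_eq!(nat_detected(None, Some(public)), None);
}
```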
@ -280,8 +286,6 @@ async fn probe_stun_servers_parallel(
while next_idx < servers.len() && join_set.len() < concurrency {
let stun_addr = servers[next_idx].clone();
next_idx += 1;
let bind_v4 = bind_v4;
let bind_v6 = bind_v6;
join_set.spawn(async move {
let res = timeout(STUN_BATCH_TIMEOUT, async {
let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?;
@ -300,11 +304,15 @@ async fn probe_stun_servers_parallel(
match task {
Ok((stun_addr, Ok(Ok(result)))) => {
if let Some(v4) = result.v4 {
let entry = best_v4_by_ip.entry(v4.reflected_addr.ip()).or_insert((0, v4));
let entry = best_v4_by_ip
.entry(v4.reflected_addr.ip())
.or_insert((0, v4));
entry.0 += 1;
}
if let Some(v6) = result.v6 {
let entry = best_v6_by_ip.entry(v6.reflected_addr.ip()).or_insert((0, v6));
let entry = best_v6_by_ip
.entry(v6.reflected_addr.ip())
.or_insert((0, v6));
entry.0 += 1;
}
if result.v4.is_some() || result.v6.is_some() {
@ -324,17 +332,11 @@ async fn probe_stun_servers_parallel(
}
let mut out = DualStunResult::default();
if let Some((_, best)) = best_v4_by_ip
.into_values()
.max_by_key(|(count, _)| *count)
{
if let Some((_, best)) = best_v4_by_ip.into_values().max_by_key(|(count, _)| *count) {
info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip());
out.v4 = Some(best);
}
if let Some((_, best)) = best_v6_by_ip
.into_values()
.max_by_key(|(count, _)| *count)
{
if let Some((_, best)) = best_v6_by_ip.into_values().max_by_key(|(count, _)| *count) {
info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip());
out.v6 = Some(best);
}
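The quorum selection above counts, per reflected IP, how many STUN servers agreed on it, then keeps the address with the most votes. The core of that vote-counting can be sketched with plain `HashMap` bookkeeping (helper name hypothetical):

```rust
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};

// Each successful probe votes for the reflected IP it observed; the
// address with the most votes wins (ties resolved arbitrarily, as with
// max_by_key over the counts).
fn quorum_winner(votes: &[IpAddr]) -> Option<IpAddr> {
    let mut counts: HashMap<IpAddr, usize> = HashMap::new();
    for ip in votes {
        *counts.entry(*ip).or_insert(0) += 1;
    }
    counts.into_iter().max_by_key(|(_, n)| *n).map(|(ip, _)| ip)
}

fn main() {
    let a = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 1));
    let b = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 2));
    assert_eq!(quorum_winner(&[a, b, a]), Some(a)); // a has 2 votes
    assert_eq!(quorum_winner(&[]), None);
}
```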
@ -347,7 +349,8 @@ pub fn decide_network_capabilities(
middle_proxy_nat_ip: Option<IpAddr>,
) -> NetworkDecision {
let ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some();
let ipv6_dc = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some();
let ipv6_dc =
config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some();
let nat_ip_v4 = matches!(middle_proxy_nat_ip, Some(IpAddr::V4(_)));
let nat_ip_v6 = matches!(middle_proxy_nat_ip, Some(IpAddr::V6(_)));
@ -534,10 +537,26 @@ pub fn is_bogon_v6(ip: Ipv6Addr) -> bool {
pub fn log_probe_result(probe: &NetworkProbe, decision: &NetworkDecision) {
info!(
ipv4 = probe.detected_ipv4.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()),
ipv6 = probe.detected_ipv6.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()),
reflected_v4 = probe.reflected_ipv4.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()),
reflected_v6 = probe.reflected_ipv6.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()),
ipv4 = probe
.detected_ipv4
.as_ref()
.map(|v| v.to_string())
.unwrap_or_else(|| "-".into()),
ipv6 = probe
.detected_ipv6
.as_ref()
.map(|v| v.to_string())
.unwrap_or_else(|| "-".into()),
reflected_v4 = probe
.reflected_ipv4
.as_ref()
.map(|v| v.ip().to_string())
.unwrap_or_else(|| "-".into()),
reflected_v6 = probe
.reflected_ipv6
.as_ref()
.map(|v| v.ip().to_string())
.unwrap_or_else(|| "-".into()),
ipv4_bogon = probe.ipv4_is_bogon,
ipv6_bogon = probe.ipv6_is_bogon,
ipv4_me = decision.ipv4_me,


@ -2,13 +2,20 @@
#![allow(dead_code)]
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::OnceLock;
use tokio::net::{lookup_host, UdpSocket};
use tokio::time::{timeout, Duration, sleep};
use tokio::net::{UdpSocket, lookup_host};
use tokio::time::{Duration, sleep, timeout};
use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result};
use crate::network::dns_overrides::{resolve, split_host_port};
fn stun_rng() -> &'static SecureRandom {
static STUN_RNG: OnceLock<SecureRandom> = OnceLock::new();
STUN_RNG.get_or_init(SecureRandom::new)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpFamily {
V4,
@ -34,13 +41,13 @@ pub async fn stun_probe_dual(stun_addr: &str) -> Result<DualStunResult> {
stun_probe_family(stun_addr, IpFamily::V6),
);
Ok(DualStunResult {
v4: v4?,
v6: v6?,
})
Ok(DualStunResult { v4: v4?, v6: v6? })
}
pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result<Option<StunProbeResult>> {
pub async fn stun_probe_family(
stun_addr: &str,
family: IpFamily,
) -> Result<Option<StunProbeResult>> {
stun_probe_family_with_bind(stun_addr, family, None).await
}
@ -49,8 +56,6 @@ pub async fn stun_probe_family_with_bind(
family: IpFamily,
bind_ip: Option<IpAddr>,
) -> Result<Option<StunProbeResult>> {
use rand::RngCore;
let bind_addr = match (family, bind_ip) {
(IpFamily::V4, Some(IpAddr::V4(ip))) => SocketAddr::new(IpAddr::V4(ip), 0),
(IpFamily::V6, Some(IpAddr::V6(ip))) => SocketAddr::new(IpAddr::V6(ip), 0),
@ -71,13 +76,18 @@ pub async fn stun_probe_family_with_bind(
if let Some(addr) = target_addr {
match socket.connect(addr).await {
Ok(()) => {}
Err(e) if family == IpFamily::V6 && matches!(
Err(e)
if family == IpFamily::V6
&& matches!(
e.kind(),
std::io::ErrorKind::NetworkUnreachable
| std::io::ErrorKind::HostUnreachable
| std::io::ErrorKind::Unsupported
| std::io::ErrorKind::NetworkDown
) => return Ok(None),
) =>
{
return Ok(None);
}
Err(e) => return Err(ProxyError::Proxy(format!("STUN connect failed: {e}"))),
}
} else {
@ -88,7 +98,7 @@ pub async fn stun_probe_family_with_bind(
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request
req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie
rand::rng().fill_bytes(&mut req[8..20]); // transaction ID
stun_rng().fill(&mut req[8..20]); // transaction ID
let mut buf = [0u8; 256];
let mut attempt = 0;
@ -200,7 +210,6 @@ pub async fn stun_probe_family_with_bind(
idx += (alen + 3) & !3;
}
}
Ok(None)
@ -228,7 +237,11 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<S
.await
.map_err(|e| ProxyError::Proxy(format!("STUN resolve failed: {e}")))?;
let target = addrs
.find(|a| matches!((a.is_ipv4(), family), (true, IpFamily::V4) | (false, IpFamily::V6)));
let target = addrs.find(|a| {
matches!(
(a.is_ipv4(), family),
(true, IpFamily::V4) | (false, IpFamily::V6)
)
});
Ok(target)
}


@ -36,32 +36,86 @@ pub static TG_DATACENTERS_V6: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
LazyLock::new(|| {
let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
m.insert(2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]);
m.insert(-2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]);
m.insert(3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]);
m.insert(-3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]);
m.insert(
1,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)],
);
m.insert(
-1,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)],
);
m.insert(
2,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)],
);
m.insert(
-2,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)],
);
m.insert(
3,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)],
);
m.insert(
-3,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)],
);
m.insert(4, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888)]);
m.insert(-4, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)]);
m.insert(
-4,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)],
);
m.insert(5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]);
m.insert(-5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]);
m.insert(
-5,
vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)],
);
m
});
pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
LazyLock::new(|| {
let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
m.insert(2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]);
m.insert(-2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]);
m.insert(3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]);
m.insert(-3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]);
m.insert(4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]);
m.insert(-4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]);
m.insert(5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]);
m.insert(-5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]);
m.insert(
1,
vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)],
);
m.insert(
-1,
vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)],
);
m.insert(
2,
vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)],
);
m.insert(
-2,
vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)],
);
m.insert(
3,
vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)],
);
m.insert(
-3,
vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)],
);
m.insert(
4,
vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)],
);
m.insert(
-4,
vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)],
);
m.insert(
5,
vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)],
);
m.insert(
-5,
vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)],
);
m
});
@ -152,11 +206,29 @@ pub const TLS_RECORD_CHANGE_CIPHER: u8 = 0x14;
pub const TLS_RECORD_APPLICATION: u8 = 0x17;
/// TLS record type: Alert
pub const TLS_RECORD_ALERT: u8 = 0x15;
/// Maximum TLS record size
pub const MAX_TLS_RECORD_SIZE: usize = 16384;
/// Maximum TLS chunk size (with overhead)
/// RFC 8446 §5.2 allows up to 16384 + 256 bytes of ciphertext
pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256;
/// Maximum TLS plaintext record payload size.
/// RFC 8446 §5.1: "The length MUST NOT exceed 2^14 bytes."
/// Use this for validating incoming unencrypted records
/// (ClientHello, ChangeCipherSpec, unprotected Handshake messages).
pub const MAX_TLS_PLAINTEXT_SIZE: usize = 16_384;
/// Structural minimum for a valid TLS 1.3 ClientHello with SNI.
/// Derived from RFC 8446 §4.1.2 field layout + Appendix D.4 compat mode.
/// Deliberately conservative (below any real client) to avoid false
/// positives on legitimate connections with compact extension sets.
pub const MIN_TLS_CLIENT_HELLO_SIZE: usize = 100;
/// Maximum TLS ciphertext record payload size.
/// RFC 8446 §5.2: "The length MUST NOT exceed 2^14 + 256 bytes."
/// The +256 accounts for maximum AEAD expansion overhead.
/// Use this for validating or sizing buffers for encrypted records.
pub const MAX_TLS_CIPHERTEXT_SIZE: usize = 16_384 + 256;
#[deprecated(note = "use MAX_TLS_PLAINTEXT_SIZE")]
pub const MAX_TLS_RECORD_SIZE: usize = MAX_TLS_PLAINTEXT_SIZE;
#[deprecated(note = "use MAX_TLS_CIPHERTEXT_SIZE")]
pub const MAX_TLS_CHUNK_SIZE: usize = MAX_TLS_CIPHERTEXT_SIZE;
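The split into plaintext and ciphertext limits matters for validation: unprotected records get the strict 2^14 bound, while protected records may carry up to 256 extra bytes of AEAD expansion. A minimal sketch of a length check using the two RFC 8446 limits (the validator function itself is hypothetical):

```rust
// RFC 8446 record-size limits, mirroring the constants above.
const MAX_TLS_PLAINTEXT_SIZE: usize = 16_384; // §5.1
const MAX_TLS_CIPHERTEXT_SIZE: usize = 16_384 + 256; // §5.2, + AEAD expansion

// Hypothetical validator: plaintext records (ClientHello etc.) get the
// stricter bound; protected records get the ciphertext bound.
fn record_len_ok(len: usize, encrypted: bool) -> bool {
    if encrypted {
        len <= MAX_TLS_CIPHERTEXT_SIZE
    } else {
        len <= MAX_TLS_PLAINTEXT_SIZE
    }
}

fn main() {
    assert!(record_len_ok(16_384, false));
    assert!(!record_len_ok(16_385, false)); // over the plaintext limit
    assert!(record_len_ok(16_385, true)); // but fine as ciphertext
    assert!(!record_len_ok(16_641, true)); // over 2^14 + 256
}
```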
/// Secure Intermediate payload is expected to be 4-byte aligned.
pub fn is_valid_secure_payload_len(data_len: usize) -> bool {
@ -204,9 +276,7 @@ pub const SMALL_BUFFER_SIZE: usize = 8192;
// ============= Statistics =============
/// Duration buckets for histogram metrics
pub static DURATION_BUCKETS: &[f64] = &[
0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0,
];
pub static DURATION_BUCKETS: &[f64] = &[0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0];
// ============= Reserved Nonce Patterns =============
@ -224,9 +294,7 @@ pub static RESERVED_NONCE_BEGINNINGS: &[[u8; 4]] = &[
];
/// Reserved continuation bytes (bytes 4-7)
pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[
[0x00, 0x00, 0x00, 0x00],
];
pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[[0x00, 0x00, 0x00, 0x00]];
// ============= RPC Constants (for Middle Proxy) =============
@ -267,7 +335,6 @@ pub mod rpc_flags {
pub const FLAG_QUICKACK: u32 = 0x80000000;
}
// ============= Middle-End Proxy Servers =============
pub const ME_PROXY_PORT: u16 = 8888;
@ -320,6 +387,10 @@ pub mod rpc_flags {
pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5;
pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
#[cfg(test)]
#[path = "tests/tls_size_constants_security_tests.rs"]
mod tls_size_constants_security_tests;
#[cfg(test)]
mod tests {
use super::*;


@ -83,7 +83,7 @@ impl FrameMode {
/// Validate message length for MTProto
pub fn validate_message_length(len: usize) -> bool {
use super::constants::{MIN_MSG_LEN, MAX_MSG_LEN, PADDING_FILLER};
use super::constants::{MAX_MSG_LEN, MIN_MSG_LEN, PADDING_FILLER};
(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len) && len.is_multiple_of(PADDING_FILLER.len())
}
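`validate_message_length` combines a range check with an alignment check against the padding filler width. A self-contained sketch of the same shape, with stand-in constant values (the real `MIN_MSG_LEN`/`MAX_MSG_LEN`/`PADDING_FILLER` values are assumptions here):

```rust
// Hypothetical stand-ins for the crate's constants (values assumed).
const MIN_MSG_LEN: usize = 12;
const MAX_MSG_LEN: usize = 1 << 24;
const PADDING_FILLER_LEN: usize = 4; // MTProto pads to 4-byte multiples

// Same shape as validate_message_length: in range AND filler-aligned.
fn validate_message_length(len: usize) -> bool {
    (MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len) && len % PADDING_FILLER_LEN == 0
}

fn main() {
    assert!(validate_message_length(16));
    assert!(!validate_message_length(18)); // in range, but not 4-byte aligned
    assert!(!validate_message_length(8)); // below the assumed minimum
}
```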


@ -2,9 +2,9 @@
#![allow(dead_code)]
use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr};
use super::constants::*;
use crate::crypto::{AesCtr, sha256};
use zeroize::Zeroize;
/// Obfuscation parameters from handshake
///
@ -69,9 +69,8 @@ impl ObfuscationParams {
None => continue,
};
let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
);
let dc_idx =
i16::from_le_bytes(decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap());
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey);


@ -0,0 +1,358 @@
use super::*;
use crate::crypto::sha256_hmac;
use std::time::Instant;
/// Helper to create a byte vector of specific length.
fn make_garbage(len: usize) -> Vec<u8> {
vec![0x42u8; len]
}
/// Helper to create a valid-looking HMAC digest for test.
fn make_digest(secret: &[u8], msg: &[u8], ts: u32) -> [u8; 32] {
let mut hmac = sha256_hmac(secret, msg);
let ts_bytes = ts.to_le_bytes();
for i in 0..4 {
hmac[28 + i] ^= ts_bytes[i];
}
hmac
}
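`make_digest` hides the timestamp by XOR-ing its little-endian bytes into the last four bytes of the HMAC; because XOR is its own inverse, a verifier that recomputes the raw HMAC can recover the timestamp from the received digest. A round-trip sketch with a dummy HMAC value (helper names hypothetical):

```rust
// The faketls digest XORs a little-endian timestamp into the last four
// bytes of the HMAC, as in make_digest above.
fn embed_ts(mut hmac: [u8; 32], ts: u32) -> [u8; 32] {
    let ts_bytes = ts.to_le_bytes();
    for i in 0..4 {
        hmac[28 + i] ^= ts_bytes[i];
    }
    hmac
}

// XOR-ing the recomputed raw HMAC back out recovers the timestamp.
fn recover_ts(digest: &[u8; 32], raw_hmac: &[u8; 32]) -> u32 {
    let mut ts_bytes = [0u8; 4];
    for i in 0..4 {
        ts_bytes[i] = digest[28 + i] ^ raw_hmac[28 + i];
    }
    u32::from_le_bytes(ts_bytes)
}

fn main() {
    let raw = [0xABu8; 32]; // stand-in for a real sha256_hmac output
    let digest = embed_ts(raw, 1_700_000_000);
    assert_eq!(recover_ts(&digest, &raw), 1_700_000_000);
    assert_eq!(digest[..28], raw[..28]); // only the 4-byte tail is altered
}
```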
fn make_valid_tls_handshake_with_session_id(
secret: &[u8],
timestamp: u32,
session_id: &[u8],
) -> Vec<u8> {
let session_id_len = session_id.len();
let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1;
handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id);
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
let digest = make_digest(secret, &handshake, timestamp);
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
handshake
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
make_valid_tls_handshake_with_session_id(secret, timestamp, &[0x42; 32])
}
// ------------------------------------------------------------------
// Truncated Packet Tests (OWASP ASVS 5.1.4, 5.1.5)
// ------------------------------------------------------------------
#[test]
fn validate_tls_handshake_truncated_10_bytes_rejected() {
let secrets = vec![("user".to_string(), b"secret".to_vec())];
let truncated = make_garbage(10);
assert!(validate_tls_handshake(&truncated, &secrets, true).is_none());
}
#[test]
fn validate_tls_handshake_truncated_at_digest_start_rejected() {
let secrets = vec![("user".to_string(), b"secret".to_vec())];
// TLS_DIGEST_POS = 11. 11 bytes should be rejected.
let truncated = make_garbage(TLS_DIGEST_POS);
assert!(validate_tls_handshake(&truncated, &secrets, true).is_none());
}
#[test]
fn validate_tls_handshake_truncated_inside_digest_rejected() {
let secrets = vec![("user".to_string(), b"secret".to_vec())];
// TLS_DIGEST_POS + 16 (half digest)
let truncated = make_garbage(TLS_DIGEST_POS + 16);
assert!(validate_tls_handshake(&truncated, &secrets, true).is_none());
}
#[test]
fn extract_sni_truncated_at_record_header_rejected() {
let truncated = make_garbage(3);
assert!(extract_sni_from_client_hello(&truncated).is_none());
}
#[test]
fn extract_sni_truncated_at_handshake_header_rejected() {
let mut truncated = vec![TLS_RECORD_HANDSHAKE, 0x03, 0x03, 0x00, 0x05];
truncated.extend_from_slice(&[0x01, 0x00]); // ClientHello type but truncated length
assert!(extract_sni_from_client_hello(&truncated).is_none());
}
// ------------------------------------------------------------------
// Malformed Extension Parsing Tests
// ------------------------------------------------------------------
#[test]
fn extract_sni_with_overlapping_extension_lengths_rejected() {
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header
h.push(0x01); // Handshake type: ClientHello
h.extend_from_slice(&[0x00, 0x00, 0x5C]); // Length: 92
h.extend_from_slice(&[0x03, 0x03]); // Version
h.extend_from_slice(&[0u8; 32]); // Random
h.push(0); // Session ID length: 0
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites
h.extend_from_slice(&[0x01, 0x00]); // Compression
// Extensions start
h.extend_from_slice(&[0x00, 0x20]); // Total Extensions length: 32
// Extension 1: SNI (type 0)
h.extend_from_slice(&[0x00, 0x00]);
h.extend_from_slice(&[0x00, 0x40]); // Claimed len: 64 (OVERFLOWS total extensions len 32)
h.extend_from_slice(&[0u8; 64]);
assert!(extract_sni_from_client_hello(&h).is_none());
}
#[test]
fn extract_sni_with_infinite_loop_potential_extension_rejected() {
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header
h.push(0x01); // Handshake type: ClientHello
h.extend_from_slice(&[0x00, 0x00, 0x5C]); // Length: 92
h.extend_from_slice(&[0x03, 0x03]); // Version
h.extend_from_slice(&[0u8; 32]); // Random
h.push(0); // Session ID length: 0
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites
h.extend_from_slice(&[0x01, 0x00]); // Compression
// Extensions start
h.extend_from_slice(&[0x00, 0x10]); // Total Extensions length: 16
// Extension: zero length but claims more?
// If our parser didn't advance, it might loop.
// Telemt uses `pos += 4 + elen;` so it always advances.
h.extend_from_slice(&[0x12, 0x34]); // Unknown type
h.extend_from_slice(&[0x00, 0x00]); // Length 0
// Fill the rest with garbage
h.extend_from_slice(&[0x42; 12]);
// We expect it to finish without SNI found
assert!(extract_sni_from_client_hello(&h).is_none());
}
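The infinite-loop test above relies on the parser always advancing by `4 + elen` and failing closed when a claimed extension length overruns the buffer. Both properties can be shown in a minimal TLV walker over the extensions block (a simplified sketch, not the crate's parser):

```rust
// Minimal fail-closed walker over TLS extension TLVs: 2-byte type,
// 2-byte length, payload. It always advances by 4 + elen, and any
// length that overruns the buffer aborts the walk.
fn find_extension(exts: &[u8], wanted: u16) -> Option<&[u8]> {
    let mut pos = 0usize;
    while pos + 4 <= exts.len() {
        let etype = u16::from_be_bytes([exts[pos], exts[pos + 1]]);
        let elen = u16::from_be_bytes([exts[pos + 2], exts[pos + 3]]) as usize;
        let end = pos + 4 + elen;
        if end > exts.len() {
            return None; // claimed length overflows the extensions block
        }
        if etype == wanted {
            return Some(&exts[pos + 4..end]);
        }
        pos = end; // always advances, so elen == 0 cannot loop forever
    }
    None
}

fn main() {
    // Unknown extension (type 0x1234, len 0) followed by an SNI-typed
    // extension (type 0x0000, len 2).
    let exts = [0x12, 0x34, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xAA, 0xBB];
    assert_eq!(find_extension(&exts, 0x0000), Some(&[0xAA, 0xBB][..]));
    // Overlapping length: claims 64 payload bytes but only 2 remain.
    let bad = [0x00, 0x00, 0x00, 0x40, 0xAA, 0xBB];
    assert_eq!(find_extension(&bad, 0x0000), None);
}
```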
#[test]
fn extract_sni_with_invalid_hostname_rejected() {
let host = b"invalid_host!%^";
let mut sni = Vec::new();
sni.extend_from_slice(&((host.len() + 3) as u16).to_be_bytes());
sni.push(0);
sni.extend_from_slice(&(host.len() as u16).to_be_bytes());
sni.extend_from_slice(host);
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header
h.push(0x01); // ClientHello
h.extend_from_slice(&[0x00, 0x00, 0x5C]);
h.extend_from_slice(&[0x03, 0x03]);
h.extend_from_slice(&[0u8; 32]);
h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]);
h.extend_from_slice(&[0x01, 0x00]);
let mut ext = Vec::new();
ext.extend_from_slice(&0x0000u16.to_be_bytes());
ext.extend_from_slice(&(sni.len() as u16).to_be_bytes());
ext.extend_from_slice(&sni);
h.extend_from_slice(&(ext.len() as u16).to_be_bytes());
h.extend_from_slice(&ext);
assert!(
extract_sni_from_client_hello(&h).is_none(),
"Invalid SNI hostname must be rejected"
);
}
// ------------------------------------------------------------------
// Timing Neutrality Tests (OWASP ASVS 5.1.7)
// ------------------------------------------------------------------
#[test]
fn validate_tls_handshake_timing_neutrality() {
let secret = b"timing_test_secret_32_bytes_long_";
let secrets = vec![("u".to_string(), secret.to_vec())];
let mut base = vec![0x42u8; 100];
base[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32;
const ITER: usize = 600;
const ROUNDS: usize = 7;
let mut per_round_avg_diff_ns = Vec::with_capacity(ROUNDS);
for round in 0..ROUNDS {
let mut success_h = base.clone();
let mut fail_h = base.clone();
let start_success = Instant::now();
for _ in 0..ITER {
let digest = make_digest(secret, &success_h, 0);
success_h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
let _ = validate_tls_handshake_at_time(&success_h, &secrets, true, 0);
}
let success_elapsed = start_success.elapsed();
let start_fail = Instant::now();
for i in 0..ITER {
let mut digest = make_digest(secret, &fail_h, 0);
let flip_idx = (i + round) % (TLS_DIGEST_LEN - 4);
digest[flip_idx] ^= 0xFF;
fail_h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
let _ = validate_tls_handshake_at_time(&fail_h, &secrets, true, 0);
}
let fail_elapsed = start_fail.elapsed();
let diff = if success_elapsed > fail_elapsed {
success_elapsed - fail_elapsed
} else {
fail_elapsed - success_elapsed
};
per_round_avg_diff_ns.push(diff.as_nanos() as f64 / ITER as f64);
}
per_round_avg_diff_ns.sort_by(|a, b| a.partial_cmp(b).unwrap());
let median_avg_diff_ns = per_round_avg_diff_ns[ROUNDS / 2];
// Keep this as a coarse side-channel guard only; noisy shared CI hosts can
// introduce microsecond-level jitter that should not fail deterministic suites.
assert!(
median_avg_diff_ns < 50_000.0,
"Median timing delta too large: {} ns/iter",
median_avg_diff_ns
);
}
// ------------------------------------------------------------------
// Adversarial Fingerprinting / Active Probing Tests
// ------------------------------------------------------------------
#[test]
fn is_tls_handshake_robustness_against_probing() {
// Valid TLS 1.0 ClientHello
assert!(is_tls_handshake(&[0x16, 0x03, 0x01]));
// Valid TLS 1.2/1.3 ClientHello (Legacy Record Layer)
assert!(is_tls_handshake(&[0x16, 0x03, 0x03]));
// Invalid record type but matching version
assert!(!is_tls_handshake(&[0x17, 0x03, 0x03]));
// Plaintext HTTP request
assert!(!is_tls_handshake(b"GET / HTTP/1.1"));
// Short garbage
assert!(!is_tls_handshake(&[0x16, 0x03]));
}
#[test]
fn validate_tls_handshake_at_time_strict_boundary() {
let secret = b"strict_boundary_secret_32_bytes_";
let secrets = vec![("u".to_string(), secret.to_vec())];
let now: i64 = 1_000_000_000;
// Boundary: exactly TIME_SKEW_MAX (120s past)
let ts_past = (now - TIME_SKEW_MAX) as u32;
let h = make_valid_tls_handshake_with_session_id(secret, ts_past, &[0x42; 32]);
assert!(validate_tls_handshake_at_time(&h, &secrets, false, now).is_some());
// Boundary + 1s: should be rejected
let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32;
let h2 = make_valid_tls_handshake_with_session_id(secret, ts_too_past, &[0x42; 32]);
assert!(validate_tls_handshake_at_time(&h2, &secrets, false, now).is_none());
}
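The boundary test above exercises the exact edge of the acceptance window: a timestamp exactly `TIME_SKEW_MAX` seconds in the past passes, one second more is rejected. Assuming a symmetric window around server time (the test only probes the past edge), the check reduces to:

```rust
// Symmetric skew window around server time (window shape and value are
// assumptions; the tests above only probe the past-edge boundary).
const TIME_SKEW_MAX: i64 = 120;

fn timestamp_acceptable(ts: u32, now: i64) -> bool {
    let delta = now - ts as i64;
    (-TIME_SKEW_MAX..=TIME_SKEW_MAX).contains(&delta)
}

fn main() {
    let now: i64 = 1_000_000_000;
    assert!(timestamp_acceptable((now - TIME_SKEW_MAX) as u32, now)); // at boundary
    assert!(!timestamp_acceptable((now - TIME_SKEW_MAX - 1) as u32, now)); // 1s past
    assert!(timestamp_acceptable(now as u32, now));
}
```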
#[test]
fn extract_sni_with_duplicate_extensions_rejected() {
// Construct a ClientHello with TWO SNI extensions
let host1 = b"first.com";
let mut sni1 = Vec::new();
sni1.extend_from_slice(&((host1.len() + 3) as u16).to_be_bytes());
sni1.push(0);
sni1.extend_from_slice(&(host1.len() as u16).to_be_bytes());
sni1.extend_from_slice(host1);
let host2 = b"second.com";
let mut sni2 = Vec::new();
sni2.extend_from_slice(&((host2.len() + 3) as u16).to_be_bytes());
sni2.push(0);
sni2.extend_from_slice(&(host2.len() as u16).to_be_bytes());
sni2.extend_from_slice(host2);
let mut ext = Vec::new();
// Ext 1: SNI
ext.extend_from_slice(&0x0000u16.to_be_bytes());
ext.extend_from_slice(&(sni1.len() as u16).to_be_bytes());
ext.extend_from_slice(&sni1);
// Ext 2: SNI again
ext.extend_from_slice(&0x0000u16.to_be_bytes());
ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes());
ext.extend_from_slice(&sni2);
let mut body = Vec::new();
body.extend_from_slice(&[0x03, 0x03]);
body.extend_from_slice(&[0u8; 32]);
body.push(0);
body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]);
body.extend_from_slice(&[0x01, 0x00]);
body.extend_from_slice(&(ext.len() as u16).to_be_bytes());
body.extend_from_slice(&ext);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut h = Vec::new();
h.push(0x16);
h.extend_from_slice(&[0x03, 0x03]);
h.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
h.extend_from_slice(&handshake);
// Duplicate SNI extensions are ambiguous and must fail closed.
assert!(extract_sni_from_client_hello(&h).is_none());
}
#[test]
fn extract_alpn_with_malformed_list_rejected() {
let mut alpn_payload = Vec::new();
alpn_payload.extend_from_slice(&0x0005u16.to_be_bytes()); // Total len 5
alpn_payload.push(10); // Labeled len 10 (OVERFLOWS total 5)
alpn_payload.extend_from_slice(b"h2");
let mut ext = Vec::new();
ext.extend_from_slice(&0x0010u16.to_be_bytes()); // Type: ALPN (16)
ext.extend_from_slice(&(alpn_payload.len() as u16).to_be_bytes());
ext.extend_from_slice(&alpn_payload);
let mut h = vec![
0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03,
];
h.extend_from_slice(&[0u8; 32]);
h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]);
h.extend_from_slice(&(ext.len() as u16).to_be_bytes());
h.extend_from_slice(&ext);
let res = extract_alpn_from_client_hello(&h);
assert!(
res.is_empty(),
"Malformed ALPN list must return empty or fail"
);
}
#[test]
fn extract_sni_with_huge_extension_header_rejected() {
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x00]; // Record header
h.push(0x01); // ClientHello
h.extend_from_slice(&[0x00, 0xFF, 0xFF]); // Huge length (65535) - overflows record
h.extend_from_slice(&[0x03, 0x03]);
h.extend_from_slice(&[0u8; 32]);
h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]);
// Extensions start
h.extend_from_slice(&[0xFF, 0xFF]); // Total extensions: 65535 (OVERFLOWS everything)
assert!(extract_sni_from_client_hello(&h).is_none());
}


@ -0,0 +1,210 @@
use super::*;
use crate::crypto::sha256_hmac;
use std::panic::catch_unwind;
fn make_valid_tls_handshake_with_session_id(
secret: &[u8],
timestamp: u32,
session_id: &[u8],
) -> Vec<u8> {
let session_id_len = session_id.len();
assert!(session_id_len <= u8::MAX as usize);
let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1;
handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id);
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
let mut digest = sha256_hmac(secret, &handshake);
let ts = timestamp.to_le_bytes();
for idx in 0..4 {
digest[28 + idx] ^= ts[idx];
}
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
handshake
}
fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(0);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
let host_bytes = host.as_bytes();
let mut sni_payload = Vec::new();
sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
sni_payload.push(0);
sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
sni_payload.extend_from_slice(host_bytes);
ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record
}
#[test]
fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() {
let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]);
assert_eq!(
extract_sni_from_client_hello(&valid).as_deref(),
Some("example.com")
);
assert_eq!(
extract_alpn_from_client_hello(&valid),
vec![b"h2".to_vec(), b"http/1.1".to_vec()]
);
assert!(
extract_sni_from_client_hello(&make_valid_client_hello_record("127.0.0.1", &[])).is_none(),
"literal IP hostnames must be rejected"
);
let mut corpus = vec![
Vec::new(),
vec![0x16, 0x03, 0x03],
valid[..9].to_vec(),
valid[..valid.len() - 1].to_vec(),
];
let mut wrong_type = valid.clone();
wrong_type[0] = 0x15;
corpus.push(wrong_type);
let mut wrong_handshake = valid.clone();
wrong_handshake[5] = 0x02;
corpus.push(wrong_handshake);
let mut wrong_length = valid.clone();
wrong_length[3] ^= 0x7f;
corpus.push(wrong_length);
for (idx, input) in corpus.iter().enumerate() {
assert!(catch_unwind(|| extract_sni_from_client_hello(input)).is_ok());
assert!(catch_unwind(|| extract_alpn_from_client_hello(input)).is_ok());
if idx == 0 {
continue;
}
assert!(
extract_sni_from_client_hello(input).is_none(),
"corpus item {idx} must fail closed for SNI"
);
assert!(
extract_alpn_from_client_hello(input).is_empty(),
"corpus item {idx} must fail closed for ALPN"
);
}
}
#[test]
fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() {
let secret = b"tls_fuzz_security_secret";
let now: i64 = 1_700_000_000;
let base = make_valid_tls_handshake_with_session_id(secret, now as u32, &[0x42; 32]);
let secrets = vec![("fuzz-user".to_string(), secret.to_vec())];
assert!(validate_tls_handshake_at_time(&base, &secrets, false, now).is_some());
let mut corpus = Vec::new();
let mut truncated = base.clone();
truncated.truncate(TLS_DIGEST_POS + 16);
corpus.push(truncated);
let mut digest_flip = base.clone();
digest_flip[TLS_DIGEST_POS + 7] ^= 0x80;
corpus.push(digest_flip);
let mut session_id_len_overflow = base.clone();
session_id_len_overflow[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 33;
corpus.push(session_id_len_overflow);
let mut timestamp_far_past = base.clone();
timestamp_far_past[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32]
.copy_from_slice(&((now - i64::from(TIME_SKEW_MAX) - 1) as u32).to_le_bytes());
corpus.push(timestamp_far_past);
let mut timestamp_far_future = base.clone();
timestamp_far_future[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32]
.copy_from_slice(&((now - TIME_SKEW_MIN + 1) as u32).to_le_bytes());
corpus.push(timestamp_far_future);
let mut seed = 0xA5A5_5A5A_F00D_BAAD_u64;
for _ in 0..32 {
let mut mutated = base.clone();
for _ in 0..2 {
seed = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN);
mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, handshake) in corpus.iter().enumerate() {
let result =
catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now));
assert!(result.is_ok(), "corpus item {idx} must not panic");
assert!(
result.unwrap().is_none(),
"corpus item {idx} must fail closed"
);
}
}
#[test]
fn tls_boot_time_acceptance_is_capped_by_replay_window() {
let secret = b"tls_boot_time_cap_secret";
let secrets = vec![("boot-user".to_string(), secret.to_vec())];
let boot_ts = 1u32;
let handshake = make_valid_tls_handshake_with_session_id(secret, boot_ts, &[0x42; 32]);
assert!(
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 300).is_some(),
"boot-time timestamp should be accepted while replay window permits it"
);
assert!(
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 0).is_none(),
"boot-time timestamp must be rejected when replay window disables the bypass"
);
}
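The handshake tests above depend on the faketls trick used by `make_valid_tls_handshake_with_session_id`: the little-endian client timestamp is XORed into bytes 28..32 of the 32-byte HMAC digest, and validation recovers it by XORing against the recomputed HMAC. A standalone sketch of that embedding with a dummy digest in place of a real HMAC:

```rust
// Illustration only: a 32-byte digest with the timestamp XORed into
// bytes 28..32, mirroring the `digest[28 + idx] ^= ts[idx]` loop above.
fn embed_ts(digest: &mut [u8; 32], ts: u32) {
    for (i, b) in ts.to_le_bytes().iter().enumerate() {
        digest[28 + i] ^= *b;
    }
}

// Recover the timestamp by XORing the received digest against the
// recomputed (clean) digest over the same bytes.
fn extract_ts(clean: &[u8; 32], received: &[u8; 32]) -> u32 {
    let mut bytes = [0u8; 4];
    for i in 0..4 {
        bytes[i] = clean[28 + i] ^ received[28 + i];
    }
    u32::from_le_bytes(bytes)
}

fn main() {
    let clean = [0xAB_u8; 32];
    let mut wire = clean;
    embed_ts(&mut wire, 1_700_000_000);
    assert_eq!(extract_ts(&clean, &wire), 1_700_000_000);
}
```

Because XOR is self-inverse, flipping any digest byte in 28..32 shifts the recovered timestamp, which is why the fuzz corpus's digest mutations fail the skew check as well as the HMAC check.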


@ -0,0 +1,37 @@
use super::*;
#[test]
fn extension_builder_fails_closed_on_u16_length_overflow() {
let builder = TlsExtensionBuilder {
extensions: vec![0u8; (u16::MAX as usize) + 1],
};
let built = builder.build();
assert!(
built.is_empty(),
"oversized extension blob must fail closed instead of truncating length field"
);
}
#[test]
fn server_hello_builder_fails_closed_on_session_id_len_overflow() {
let builder = ServerHelloBuilder {
random: [0u8; 32],
session_id: vec![0xAB; (u8::MAX as usize) + 1],
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
compression: 0,
extensions: TlsExtensionBuilder::new(),
};
let message = builder.build_message();
let record = builder.build_record();
assert!(
message.is_empty(),
"session_id length overflow must fail closed in message builder"
);
assert!(
record.is_empty(),
"session_id length overflow must fail closed in record builder"
);
}
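The two builder tests above pin down one policy: when a body no longer fits its length field, emit nothing rather than a wrapped or truncated length. A minimal sketch of that fail-closed length-prefix rule (a hypothetical helper, not the crate's builder API):

```rust
// Fail closed: if the body cannot be described by a u16 length prefix,
// return None instead of silently truncating the length field.
fn with_u16_len(body: &[u8]) -> Option<Vec<u8>> {
    let len = u16::try_from(body.len()).ok()?;
    let mut out = Vec::with_capacity(2 + body.len());
    out.extend_from_slice(&len.to_be_bytes());
    out.extend_from_slice(body);
    Some(out)
}

fn main() {
    assert!(with_u16_len(&[0u8; 16]).is_some());
    // One past u16::MAX: must refuse, matching the builder tests above.
    assert!(with_u16_len(&vec![0u8; (u16::MAX as usize) + 1]).is_none());
}
```

The same shape applies to the `u8` session-id length in `ServerHelloBuilder`: `u8::try_from(session_id.len())` instead of an `as u8` cast, which would wrap 256 to 0.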

File diff suppressed because it is too large


@ -0,0 +1,11 @@
use super::{MAX_TLS_CIPHERTEXT_SIZE, MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE};
#[test]
fn tls_size_constants_match_rfc_8446() {
assert_eq!(MAX_TLS_PLAINTEXT_SIZE, 16_384);
assert_eq!(MAX_TLS_CIPHERTEXT_SIZE, 16_640);
assert!(MIN_TLS_CLIENT_HELLO_SIZE < 512);
assert!(MIN_TLS_CLIENT_HELLO_SIZE > 64);
assert!(MAX_TLS_CIPHERTEXT_SIZE > MAX_TLS_PLAINTEXT_SIZE);
}

File diff suppressed because it is too large


@ -1,3 +1,8 @@
#![allow(dead_code)]
// Adaptive buffer policy is staged and retained for deterministic rollout.
// Keep definitions compiled for compatibility and security test scaffolding.
use dashmap::DashMap;
use std::cmp::max;
use std::sync::OnceLock;
@ -170,7 +175,8 @@ impl SessionAdaptiveController {
return self.promote(TierTransitionReason::SoftConfirmed, 0);
}
let demote_candidate = self.throughput_ema_bps < THROUGHPUT_DOWN_BPS && !tier2_now && !hard_now;
let demote_candidate =
self.throughput_ema_bps < THROUGHPUT_DOWN_BPS && !tier2_now && !hard_now;
if demote_candidate {
self.quiet_ticks = self.quiet_ticks.saturating_add(1);
if self.quiet_ticks >= QUIET_DEMOTE_TICKS {
@ -253,10 +259,7 @@ pub fn record_user_tier(user: &str, tier: AdaptiveTier) {
};
return;
}
profiles().insert(
user.to_string(),
UserAdaptiveProfile { tier, seen_at: now },
);
profiles().insert(user.to_string(), UserAdaptiveProfile { tier, seen_at: now });
}
pub fn direct_copy_buffers_for_tier(
@ -339,10 +342,7 @@ mod tests {
sample(
300_000, // ~9.6 Mbps
320_000, // incoming > outgoing to confirm tier2
250_000,
10,
0,
0,
250_000, 10, 0, 0,
),
tick_secs,
);
@ -358,10 +358,7 @@ mod tests {
fn test_hard_promotion_on_pending_pressure() {
let mut ctrl = SessionAdaptiveController::new(AdaptiveTier::Base);
let transition = ctrl
.observe(
sample(10_000, 20_000, 10_000, 4, 1, 3),
0.25,
)
.observe(sample(10_000, 20_000, 10_000, 4, 1, 3), 0.25)
.expect("expected hard promotion");
assert_eq!(transition.reason, TierTransitionReason::HardPressure);
assert_eq!(transition.to, AdaptiveTier::Tier1);
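The adaptive-buffer diff above shows a throughput EMA driving promote/demote decisions with a quiet-tick counter for hysteresis. A self-contained sketch of that shape — thresholds, EMA weight, and tick counts here are illustrative, not the crate's actual constants:

```rust
// Hysteresis sketch: promote immediately on sustained throughput, demote
// only after several consecutive quiet ticks (assumed constants).
const UP_BPS: f64 = 1_000_000.0;
const DOWN_BPS: f64 = 100_000.0;
const QUIET_DEMOTE_TICKS: u32 = 3;

struct Ctl {
    ema: f64,
    quiet: u32,
    tier: u8, // 0 = Base, 2 = highest tier
}

impl Ctl {
    fn observe(&mut self, bps: f64) {
        self.ema = 0.8 * self.ema + 0.2 * bps; // simple EMA
        if self.ema > UP_BPS {
            self.tier = self.tier.saturating_add(1).min(2);
            self.quiet = 0;
        } else if self.ema < DOWN_BPS {
            self.quiet += 1;
            if self.quiet >= QUIET_DEMOTE_TICKS {
                self.tier = self.tier.saturating_sub(1);
                self.quiet = 0;
            }
        } else {
            // In the dead band neither promote nor count toward demotion.
            self.quiet = 0;
        }
    }
}

fn main() {
    let mut c = Ctl { ema: 0.0, quiet: 0, tier: 0 };
    for _ in 0..50 {
        c.observe(5_000_000.0);
    }
    assert_eq!(c.tier, 2);
    for _ in 0..60 {
        c.observe(0.0);
    }
    assert_eq!(c.tier, 0);
}
```

The dead band between the two thresholds is what prevents tier flapping when throughput hovers near a single cutoff.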

File diff suppressed because it is too large


@ -1,7 +1,11 @@
use std::collections::HashSet;
use std::ffi::OsString;
use std::fs::OpenOptions;
use std::io::Write;
use std::net::SocketAddr;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use std::sync::{Mutex, OnceLock};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split};
use tokio::sync::watch;
@ -17,11 +21,209 @@ use crate::proxy::route_mode::{
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state,
cutover_stagger_delay,
};
use crate::proxy::adaptive_buffers;
use crate::proxy::session_eviction::SessionLease;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager;
#[cfg(unix)]
use nix::fcntl::{Flock, FlockArg, OFlag, openat};
#[cfg(unix)]
use nix::sys::stat::Mode;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
const MAX_SCOPE_HINT_LEN: usize = 64;
fn validated_scope_hint(user: &str) -> Option<&str> {
let scope = user.strip_prefix("scope_")?;
if scope.is_empty() || scope.len() > MAX_SCOPE_HINT_LEN {
return None;
}
if scope
.bytes()
.all(|b| b.is_ascii_alphanumeric() || b == b'-')
{
Some(scope)
} else {
None
}
}
#[derive(Clone)]
struct SanitizedUnknownDcLogPath {
resolved_path: PathBuf,
allowed_parent: PathBuf,
file_name: OsString,
}
// In tests, this function shares global mutable state. Callers that also use
// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions
// deterministic under parallel execution.
fn should_log_unknown_dc(dc_idx: i16) -> bool {
let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new()));
should_log_unknown_dc_with_set(set, dc_idx)
}
fn should_log_unknown_dc_with_set(set: &Mutex<HashSet<i16>>, dc_idx: i16) -> bool {
match set.lock() {
Ok(mut guard) => {
if guard.contains(&dc_idx) {
return false;
}
if guard.len() >= UNKNOWN_DC_LOG_DISTINCT_LIMIT {
return false;
}
guard.insert(dc_idx)
}
// Fail closed on poisoned state to avoid unbounded blocking log writes.
Err(_) => false,
}
}
fn sanitize_unknown_dc_log_path(path: &str) -> Option<SanitizedUnknownDcLogPath> {
let candidate = Path::new(path);
if candidate.as_os_str().is_empty() {
return None;
}
if candidate
.components()
.any(|component| matches!(component, Component::ParentDir))
{
return None;
}
let cwd = std::env::current_dir().ok()?;
let file_name = candidate.file_name()?;
let parent = candidate.parent().unwrap_or_else(|| Path::new("."));
let parent_path = if parent.is_absolute() {
parent.to_path_buf()
} else {
cwd.join(parent)
};
let canonical_parent = parent_path.canonicalize().ok()?;
if !canonical_parent.is_dir() {
return None;
}
Some(SanitizedUnknownDcLogPath {
resolved_path: canonical_parent.join(file_name),
allowed_parent: canonical_parent,
file_name: file_name.to_os_string(),
})
}
fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool {
let Some(parent) = path.resolved_path.parent() else {
return false;
};
let Ok(current_parent) = parent.canonicalize() else {
return false;
};
if current_parent != path.allowed_parent {
return false;
}
if let Ok(canonical_target) = path.resolved_path.canonicalize() {
let Some(target_parent) = canonical_target.parent() else {
return false;
};
let Some(target_name) = canonical_target.file_name() else {
return false;
};
if target_parent != path.allowed_parent || target_name != path.file_name {
return false;
}
}
true
}
#[cfg(test)]
fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
#[cfg(unix)]
{
OpenOptions::new()
.create(true)
.append(true)
.custom_flags(libc::O_NOFOLLOW)
.open(path)
}
#[cfg(not(unix))]
{
let _ = path;
Err(std::io::Error::new(
std::io::ErrorKind::PermissionDenied,
"unknown_dc_file_log_enabled requires unix O_NOFOLLOW support",
))
}
}
fn open_unknown_dc_log_append_anchored(
path: &SanitizedUnknownDcLogPath,
) -> std::io::Result<std::fs::File> {
#[cfg(unix)]
{
let parent = OpenOptions::new()
.read(true)
.custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC)
.open(&path.allowed_parent)?;
let oflags = OFlag::O_CREAT
| OFlag::O_APPEND
| OFlag::O_WRONLY
| OFlag::O_NOFOLLOW
| OFlag::O_CLOEXEC;
let mode = Mode::from_bits_truncate(0o600);
let path_component = Path::new(path.file_name.as_os_str());
let fd = openat(&parent, path_component, oflags, mode)
.map_err(|err| std::io::Error::from_raw_os_error(err as i32))?;
let file = std::fs::File::from(fd);
Ok(file)
}
#[cfg(not(unix))]
{
let _ = path;
Err(std::io::Error::new(
std::io::ErrorKind::PermissionDenied,
"unknown_dc_file_log_enabled requires unix O_NOFOLLOW support",
))
}
}
fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> {
#[cfg(unix)]
{
let cloned = file.try_clone()?;
let mut locked = Flock::lock(cloned, FlockArg::LockExclusive)
.map_err(|(_, err)| std::io::Error::from_raw_os_error(err as i32))?;
let write_result = writeln!(&mut *locked, "dc_idx={dc_idx}");
let _ = locked
.unlock()
.map_err(|(_, err)| std::io::Error::from_raw_os_error(err as i32))?;
write_result
}
#[cfg(not(unix))]
{
writeln!(file, "dc_idx={dc_idx}")
}
}
#[cfg(test)]
fn clear_unknown_dc_log_cache_for_testing() {
if let Some(set) = LOGGED_UNKNOWN_DCS.get()
&& let Ok(mut guard) = set.lock()
{
guard.clear();
}
}
#[cfg(test)]
fn unknown_dc_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
pub(crate) async fn handle_via_direct<R, W>(
client_reader: CryptoReader<R>,
@ -35,7 +237,6 @@ pub(crate) async fn handle_via_direct<R, W>(
mut route_rx: watch::Receiver<RouteCutoverState>,
route_snapshot: RouteCutoverState,
session_id: u64,
session_lease: SessionLease,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
@ -54,12 +255,15 @@ where
"Connecting to Telegram DC"
);
let scope_hint = validated_scope_hint(user);
if user.starts_with("scope_") && scope_hint.is_none() {
warn!(
user = %user,
"Ignoring invalid scope hint and falling back to default upstream selection"
);
}
let tg_stream = upstream_manager
.connect(
dc_addr,
Some(success.dc_idx),
user.strip_prefix("scope_").filter(|s| !s.is_empty()),
)
.connect(dc_addr, Some(success.dc_idx), scope_hint)
.await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
@ -70,29 +274,19 @@ where
debug!(peer = %success.peer, "TG handshake complete, starting relay");
stats.increment_user_connects(user);
stats.increment_user_curr_connects(user);
stats.increment_current_connections_direct();
let seed_tier = adaptive_buffers::seed_tier_for_user(user);
let (c2s_copy_buf, s2c_copy_buf) = adaptive_buffers::direct_copy_buffers_for_tier(
seed_tier,
config.general.direct_relay_copy_buf_c2s_bytes,
config.general.direct_relay_copy_buf_s2c_bytes,
);
let _direct_connection_lease = stats.acquire_direct_connection_lease();
let relay_result = relay_bidirectional(
client_reader,
client_writer,
tg_reader,
tg_writer,
c2s_copy_buf,
s2c_copy_buf,
config.general.direct_relay_copy_buf_c2s_bytes,
config.general.direct_relay_copy_buf_s2c_bytes,
user,
success.dc_idx,
Arc::clone(&stats),
config.access.user_data_quota.get(user).copied(),
buffer_pool,
session_lease,
seed_tier,
);
tokio::pin!(relay_result);
let relay_result = loop {
@ -122,9 +316,6 @@ where
}
};
stats.decrement_current_connections_direct();
stats.decrement_user_curr_connects(user);
match &relay_result {
Ok(()) => debug!(user = %user, "Direct relay completed"),
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
@ -181,20 +372,27 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
&& let Some(path) = &config.general.unknown_dc_log_path
&& let Ok(handle) = tokio::runtime::Handle::try_current()
{
let path = path.clone();
if let Some(path) = sanitize_unknown_dc_log_path(path) {
if should_log_unknown_dc(dc_idx) {
handle.spawn_blocking(move || {
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
let _ = writeln!(file, "dc_idx={dc_idx}");
if unknown_dc_log_path_is_still_safe(&path)
&& let Ok(mut file) = open_unknown_dc_log_append_anchored(&path)
{
let _ = append_unknown_dc_line(&mut file, dc_idx);
}
});
}
} else {
warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path");
}
}
}
let default_dc = config.default_dc.unwrap_or(2) as usize;
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
default_dc - 1
} else {
1
0
};
info!(
@ -222,8 +420,6 @@ where
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
success.proto_tag,
success.dc_idx,
&success.dec_key,
success.dec_iv,
&success.enc_key,
success.enc_iv,
rng,
@ -249,3 +445,19 @@ where
CryptoWriter::new(write_half, tg_encryptor, max_pending),
))
}
#[cfg(test)]
#[path = "tests/direct_relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/direct_relay_business_logic_tests.rs"]
mod business_logic_tests;
#[cfg(test)]
#[path = "tests/direct_relay_common_mistakes_tests.rs"]
mod common_mistakes_tests;
#[cfg(test)]
#[path = "tests/direct_relay_subtle_adversarial_tests.rs"]
mod subtle_adversarial_tests;
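The unknown-DC log helpers above layer two defenses: a lexical reject of any `..` component before the path is resolved, and an anchored `openat` with `O_NOFOLLOW` at open time so a swapped symlink cannot redirect the write. The lexical half can be sketched in isolation:

```rust
// Lexical pre-check, as in `sanitize_unknown_dc_log_path`: reject any
// candidate containing a `..` component before touching the filesystem.
use std::path::{Component, Path};

fn has_parent_dir(path: &str) -> bool {
    Path::new(path)
        .components()
        .any(|c| matches!(c, Component::ParentDir))
}

fn main() {
    assert!(has_parent_dir("logs/../../etc/passwd"));
    assert!(!has_parent_dir("logs/unknown_dc.log"));
    assert!(!has_parent_dir("/var/log/telemt/dc.log"));
}
```

The lexical check alone is not sufficient (a symlinked parent still escapes after canonicalization), which is why the diff re-verifies the canonical parent in `unknown_dc_log_path_is_still_safe` and opens through the anchored directory fd.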


@ -2,22 +2,479 @@
#![allow(dead_code)]
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use std::collections::HashSet;
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash, Hasher};
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv6Addr};
use std::sync::Arc;
use std::time::Duration;
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace};
use zeroize::Zeroize;
use tracing::{debug, trace, warn};
use zeroize::{Zeroize, Zeroizing};
use crate::crypto::{sha256, AesCtr, SecureRandom};
use rand::Rng;
use crate::config::ProxyConfig;
use crate::crypto::{AesCtr, SecureRandom, sha256};
use crate::error::{HandshakeResult, ProxyError};
use crate::protocol::constants::*;
use crate::protocol::tls;
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
use crate::error::{ProxyError, HandshakeResult};
use crate::stats::ReplayChecker;
use crate::config::ProxyConfig;
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
use crate::tls_front::{TlsFrontCache, emulator};
use rand::RngExt;
const ACCESS_SECRET_BYTES: usize = 16;
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
#[cfg(test)]
const WARNED_SECRET_MAX_ENTRIES: usize = 64;
#[cfg(not(test))]
const WARNED_SECRET_MAX_ENTRIES: usize = 1_024;
const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60;
#[cfg(test)]
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256;
#[cfg(not(test))]
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536;
const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024;
const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4;
const AUTH_PROBE_SATURATION_GRACE_FAILS: u32 = 2;
#[cfg(test)]
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
#[cfg(not(test))]
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 25;
#[cfg(test)]
const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 16;
#[cfg(not(test))]
const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 1_000;
#[derive(Clone, Copy)]
struct AuthProbeState {
fail_streak: u32,
blocked_until: Instant,
last_seen: Instant,
}
#[derive(Clone, Copy)]
struct AuthProbeSaturationState {
fail_streak: u32,
blocked_until: Instant,
last_seen: Instant,
}
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
static AUTH_PROBE_SATURATION_STATE: OnceLock<Mutex<Option<AuthProbeSaturationState>>> =
OnceLock::new();
static AUTH_PROBE_EVICTION_HASHER: OnceLock<RandomState> = OnceLock::new();
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
AUTH_PROBE_STATE.get_or_init(DashMap::new)
}
fn auth_probe_saturation_state() -> &'static Mutex<Option<AuthProbeSaturationState>> {
AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None))
}
fn auth_probe_saturation_state_lock()
-> std::sync::MutexGuard<'static, Option<AuthProbeSaturationState>> {
auth_probe_saturation_state()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr {
match peer_ip {
IpAddr::V4(ip) => IpAddr::V4(ip),
IpAddr::V6(ip) => {
let [a, b, c, d, _, _, _, _] = ip.segments();
IpAddr::V6(Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0))
}
}
}
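`normalize_auth_probe_ip` above buckets IPv6 probes by their /64 prefix, since a single host typically controls an entire /64 and could otherwise dodge per-address throttling via its 2^64 suffix addresses. The same normalization as a standalone sketch:

```rust
// Keep IPv4 as-is; collapse IPv6 to its top four segments (/64).
use std::net::{IpAddr, Ipv6Addr};

fn normalize(ip: IpAddr) -> IpAddr {
    match ip {
        IpAddr::V4(v4) => IpAddr::V4(v4),
        IpAddr::V6(v6) => {
            let [a, b, c, d, ..] = v6.segments();
            IpAddr::V6(Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0))
        }
    }
}

fn main() {
    let x: IpAddr = "2001:db8:1:2:aaaa:bbbb:cccc:dddd".parse().unwrap();
    let y: IpAddr = "2001:db8:1:2:1:2:3:4".parse().unwrap();
    // Two addresses in the same /64 throttle together.
    assert_eq!(normalize(x), normalize(y));
}
```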
fn auth_probe_backoff(fail_streak: u32) -> Duration {
if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS {
return Duration::ZERO;
}
let shift = (fail_streak - AUTH_PROBE_BACKOFF_START_FAILS).min(10);
let multiplier = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
let ms = AUTH_PROBE_BACKOFF_BASE_MS
.saturating_mul(multiplier)
.min(AUTH_PROBE_BACKOFF_MAX_MS);
Duration::from_millis(ms)
}
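`auth_probe_backoff` above yields zero delay until the fourth consecutive failure, then doubles from the base, with the shift clamped at 10 and the result capped. Re-deriving the schedule as a standalone function using the non-test constants from this diff (start at 4 fails, base 25 ms, cap 1000 ms):

```rust
// Exponential backoff with a grace window, a shift clamp, and a hard cap,
// mirroring `auth_probe_backoff` with its non-test constants.
fn backoff_ms(fail_streak: u32) -> u64 {
    const START: u32 = 4;
    const BASE_MS: u64 = 25;
    const MAX_MS: u64 = 1_000;
    if fail_streak < START {
        return 0;
    }
    let shift = (fail_streak - START).min(10);
    BASE_MS.saturating_mul(1u64 << shift).min(MAX_MS)
}

fn main() {
    assert_eq!(backoff_ms(3), 0);      // still in the grace window
    assert_eq!(backoff_ms(4), 25);     // first throttled failure
    assert_eq!(backoff_ms(6), 100);    // 25 * 2^2
    assert_eq!(backoff_ms(10), 1_000); // 25 * 2^6 = 1600, capped at 1000
}
```

Clamping the shift before multiplying means the cap is reached well before the shift clamp matters here, but the clamp still guards against overflow if the constants are retuned.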
fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
let retention = Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS);
now.duration_since(state.last_seen) > retention
}
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
let hasher_state = AUTH_PROBE_EVICTION_HASHER.get_or_init(RandomState::new);
let mut hasher = hasher_state.build_hasher();
peer_ip.hash(&mut hasher);
now.hash(&mut hasher);
hasher.finish() as usize
}
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
let Some(entry) = state.get(&peer_ip) else {
return false;
};
if auth_probe_state_expired(&entry, now) {
drop(entry);
state.remove(&peer_ip);
return false;
}
now < entry.blocked_until
}
fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
let Some(entry) = state.get(&peer_ip) else {
return false;
};
if auth_probe_state_expired(&entry, now) {
drop(entry);
state.remove(&peer_ip);
return false;
}
entry.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
}
fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bool {
if !auth_probe_is_throttled(peer_ip, now) {
return false;
}
if !auth_probe_saturation_is_throttled(now) {
return true;
}
auth_probe_saturation_grace_exhausted(peer_ip, now)
}
fn auth_probe_saturation_is_throttled(now: Instant) -> bool {
let mut guard = auth_probe_saturation_state_lock();
let Some(state) = guard.as_mut() else {
return false;
};
if now.duration_since(state.last_seen) > Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) {
*guard = None;
return false;
}
if now < state.blocked_until {
return true;
}
false
}
fn auth_probe_note_saturation(now: Instant) {
let mut guard = auth_probe_saturation_state_lock();
match guard.as_mut() {
Some(state)
if now.duration_since(state.last_seen)
<= Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) =>
{
state.fail_streak = state.fail_streak.saturating_add(1);
state.last_seen = now;
state.blocked_until = now + auth_probe_backoff(state.fail_streak);
}
_ => {
let fail_streak = AUTH_PROBE_BACKOFF_START_FAILS;
*guard = Some(AuthProbeSaturationState {
fail_streak,
blocked_until: now + auth_probe_backoff(fail_streak),
last_seen: now,
});
}
}
}
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
auth_probe_record_failure_with_state(state, peer_ip, now);
}
fn auth_probe_record_failure_with_state(
state: &DashMap<IpAddr, AuthProbeState>,
peer_ip: IpAddr,
now: Instant,
) {
let make_new_state = || AuthProbeState {
fail_streak: 1,
blocked_until: now + auth_probe_backoff(1),
last_seen: now,
};
let update_existing = |entry: &mut AuthProbeState| {
if auth_probe_state_expired(entry, now) {
*entry = make_new_state();
} else {
entry.fail_streak = entry.fail_streak.saturating_add(1);
entry.last_seen = now;
entry.blocked_until = now + auth_probe_backoff(entry.fail_streak);
}
};
match state.entry(peer_ip) {
Entry::Occupied(mut entry) => {
update_existing(entry.get_mut());
return;
}
Entry::Vacant(_) => {}
}
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
let mut rounds = 0usize;
while state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
rounds += 1;
if rounds > 8 {
auth_probe_note_saturation(now);
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
}
let Some((evict_key, _, _)) = eviction_candidate else {
return;
};
state.remove(&evict_key);
break;
}
let mut stale_keys = Vec::new();
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
let state_len = state.len();
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
let start_offset = if state_len == 0 {
0
} else {
auth_probe_eviction_offset(peer_ip, now) % state_len
};
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
scanned += 1;
if scanned >= scan_limit {
break;
}
}
if scanned < scan_limit {
for entry in state.iter().take(scan_limit - scanned) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
}
}
for stale_key in stale_keys {
state.remove(&stale_key);
}
if state.len() < AUTH_PROBE_TRACK_MAX_ENTRIES {
break;
}
let Some((evict_key, _, _)) = eviction_candidate else {
auth_probe_note_saturation(now);
return;
};
state.remove(&evict_key);
auth_probe_note_saturation(now);
}
}
match state.entry(peer_ip) {
Entry::Occupied(mut entry) => {
update_existing(entry.get_mut());
}
Entry::Vacant(entry) => {
entry.insert(make_new_state());
}
}
}
fn auth_probe_record_success(peer_ip: IpAddr) {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
state.remove(&peer_ip);
}
#[cfg(test)]
fn clear_auth_probe_state_for_testing() {
if let Some(state) = AUTH_PROBE_STATE.get() {
state.clear();
}
if AUTH_PROBE_SATURATION_STATE.get().is_some() {
let mut guard = auth_probe_saturation_state_lock();
*guard = None;
}
}
#[cfg(test)]
fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option<u32> {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = AUTH_PROBE_STATE.get()?;
state.get(&peer_ip).map(|entry| entry.fail_streak)
}
#[cfg(test)]
fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool {
auth_probe_is_throttled(peer_ip, Instant::now())
}
#[cfg(test)]
fn auth_probe_saturation_is_throttled_for_testing() -> bool {
auth_probe_saturation_is_throttled(Instant::now())
}
#[cfg(test)]
fn auth_probe_saturation_is_throttled_at_for_testing(now: Instant) -> bool {
auth_probe_saturation_is_throttled(now)
}
#[cfg(test)]
fn auth_probe_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn clear_warned_secrets_for_testing() {
if let Some(warned) = INVALID_SECRET_WARNED.get()
&& let Ok(mut guard) = warned.lock()
{
guard.clear();
}
}
#[cfg(test)]
fn warned_secrets_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
let key = (name.to_string(), reason.to_string());
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
let should_warn = match warned.lock() {
Ok(mut guard) => {
if !guard.contains(&key) && guard.len() >= WARNED_SECRET_MAX_ENTRIES {
false
} else {
guard.insert(key)
}
}
Err(_) => true,
};
if !should_warn {
return;
}
match got {
Some(actual) => {
warn!(
user = %name,
expected = expected,
got = actual,
"Skipping user: access secret has unexpected length"
);
}
None => {
warn!(
user = %name,
"Skipping user: access secret is not valid hex"
);
}
}
}
fn decode_user_secret(name: &str, secret_hex: &str) -> Option<Vec<u8>> {
match hex::decode(secret_hex) {
Ok(bytes) if bytes.len() == ACCESS_SECRET_BYTES => Some(bytes),
Ok(bytes) => {
warn_invalid_secret_once(
name,
"invalid_length",
ACCESS_SECRET_BYTES,
Some(bytes.len()),
);
None
}
Err(_) => {
warn_invalid_secret_once(name, "invalid_hex", ACCESS_SECRET_BYTES, None);
None
}
}
}
// Decide whether a client-supplied proto tag is allowed given the configured
// proxy modes and the transport that carried the handshake.
//
// A common mistake is to treat `modes.tls` and `modes.secure` as interchangeable
// even though they correspond to different transport profiles: `modes.tls` is
// for the TLS-fronted (EE-TLS) path, while `modes.secure` is for direct MTProto
// over TCP (DD). Enforcing this separation prevents an attacker from using a
// TLS-capable client to bypass the operator intent for the direct MTProto mode,
// and vice versa.
fn mode_enabled_for_proto(config: &ProxyConfig, proto_tag: ProtoTag, is_tls: bool) -> bool {
match proto_tag {
ProtoTag::Secure => {
if is_tls {
config.general.modes.tls
} else {
config.general.modes.secure
}
}
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
}
}
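The comment above distinguishes `modes.tls` (TLS-fronted path) from `modes.secure` (direct DD over TCP). The decision table it implies can be exercised with a simplified standalone sketch (the `Modes` struct here is a stand-in for the config type, not the crate's own):

```rust
// Sketch of the mode gate above: Secure traffic is gated by the transport
// that carried it; classic protos share one flag regardless of transport.
struct Modes {
    tls: bool,
    secure: bool,
    classic: bool,
}

#[derive(Clone, Copy)]
enum Proto {
    Secure,
    Intermediate,
    Abridged,
}

fn allowed(m: &Modes, proto: Proto, is_tls: bool) -> bool {
    match proto {
        Proto::Secure => {
            if is_tls {
                m.tls
            } else {
                m.secure
            }
        }
        Proto::Intermediate | Proto::Abridged => m.classic,
    }
}

fn main() {
    // TLS-only operator intent: DD over plain TCP must be refused.
    let m = Modes { tls: true, secure: false, classic: false };
    assert!(allowed(&m, Proto::Secure, true));
    assert!(!allowed(&m, Proto::Secure, false));
    assert!(!allowed(&m, Proto::Abridged, true));
}
```

Conflating the two flags would let a TLS-capable client reach the direct MTProto mode the operator disabled, which is exactly the bypass the comment warns about.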
fn decode_user_secrets(
config: &ProxyConfig,
@ -27,7 +484,7 @@ fn decode_user_secrets(
if let Some(preferred) = preferred_user
&& let Some(secret_hex) = config.access.users.get(preferred)
&& let Ok(bytes) = hex::decode(secret_hex)
&& let Some(bytes) = decode_user_secret(preferred, secret_hex)
{
secrets.push((preferred.to_string(), bytes));
}
@ -36,7 +493,7 @@ fn decode_user_secrets(
if preferred_user.is_some_and(|preferred| preferred == name.as_str()) {
continue;
}
if let Ok(bytes) = hex::decode(secret_hex) {
if let Some(bytes) = decode_user_secret(name, secret_hex) {
secrets.push((name.clone(), bytes));
}
}
@ -44,11 +501,29 @@ fn decode_user_secrets(
secrets
}
async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
if config.censorship.server_hello_delay_max_ms == 0 {
return;
}
let min = config.censorship.server_hello_delay_min_ms;
let max = config.censorship.server_hello_delay_max_ms.max(min);
let delay_ms = if max == min {
max
} else {
rand::rng().random_range(min..=max)
};
if delay_ms > 0 {
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
}
/// Result of successful handshake
///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
/// zeroized on drop.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct HandshakeSuccess {
/// Authenticated user name
pub user: String,
@ -94,28 +569,33 @@ where
{
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
let throttle_now = Instant::now();
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient { reader, writer };
}
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
let client_sni = tls::extract_sni_from_client_hello(handshake);
let secrets = decode_user_secrets(config, client_sni.as_deref());
if replay_checker.check_and_add_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secrets = decode_user_secrets(config, None);
let validation = match tls::validate_tls_handshake(
let validation = match tls::validate_tls_handshake_with_replay_window(
handshake,
&secrets,
config.access.ignore_time_skew,
config.access.replay_window_secs,
) {
Some(v) => v,
None => {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
@ -125,16 +605,29 @@ where
}
};
// Replay tracking is applied only after successful authentication to avoid
// letting unauthenticated probes evict valid entries from the replay cache.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => return HandshakeResult::BadClient { reader, writer },
None => {
maybe_apply_server_hello_delay(config).await;
return HandshakeResult::BadClient { reader, writer };
}
};
let cached = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() {
let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) {
if cache.contains_domain(&sni).await {
sni
let selected_domain = if let Some(sni) = client_sni.as_ref() {
if cache.contains_domain(sni).await {
sni.clone()
} else {
config.censorship.tls_domain.clone()
}
@@ -166,6 +659,10 @@ where
Some(b"h2".to_vec())
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
Some(b"http/1.1".to_vec())
} else if !alpn_list.is_empty() {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
return HandshakeResult::BadClient { reader, writer };
} else {
None
}
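The ALPN branch above prefers `h2`, then `http/1.1`, and fails closed when the client offers only unsupported protocols; an empty list simply means no ALPN was negotiated. A standalone sketch of that precedence (function name is illustrative, not this module's API):

```rust
/// Mirror of the ALPN precedence in the handshake path: prefer h2, then
/// http/1.1; a non-empty list containing neither is rejected (fail closed),
/// while an empty list means "no ALPN extension offered".
fn select_alpn(alpn_list: &[Vec<u8>]) -> Result<Option<Vec<u8>>, ()> {
    if alpn_list.iter().any(|p| p == b"h2") {
        Ok(Some(b"h2".to_vec()))
    } else if alpn_list.iter().any(|p| p == b"http/1.1") {
        Ok(Some(b"http/1.1".to_vec()))
    } else if !alpn_list.is_empty() {
        Err(()) // unsupported-only list: caller falls back to masking
    } else {
        Ok(None)
    }
}

fn main() {
    assert_eq!(select_alpn(&[b"h2".to_vec()]), Ok(Some(b"h2".to_vec())));
    assert_eq!(
        select_alpn(&[b"spdy/3".to_vec(), b"http/1.1".to_vec()]),
        Ok(Some(b"http/1.1".to_vec()))
    );
    assert_eq!(select_alpn(&[b"spdy/3".to_vec()]), Err(()));
    assert_eq!(select_alpn(&[]), Ok(None));
}
```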
@@ -196,19 +693,9 @@ where
)
};
// Optional anti-fingerprint delay before sending ServerHello.
if config.censorship.server_hello_delay_max_ms > 0 {
let min = config.censorship.server_hello_delay_min_ms;
let max = config.censorship.server_hello_delay_max_ms.max(min);
let delay_ms = if max == min {
max
} else {
rand::rng().random_range(min..=max)
};
if delay_ms > 0 {
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
}
// Apply the same optional delay budget used by reject paths to reduce
// distinguishability between success and fail-closed handshakes.
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
@@ -228,6 +715,8 @@ where
"TLS handshake successful"
);
auth_probe_record_success(peer.ip());
HandshakeResult::Success((
FakeTlsReader::new(reader),
FakeTlsWriter::new(writer),
@@ -250,75 +739,93 @@ where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
let handshake_fingerprint = {
let digest = sha256(&handshake[..8]);
hex::encode(&digest[..4])
};
trace!(
peer = %peer,
handshake_fingerprint = %handshake_fingerprint,
"MTProto handshake prefix"
);
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected");
let throttle_now = Instant::now();
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
let decoded_users = decode_user_secrets(config, preferred_user);
for (user, secret) in decoded_users {
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
let decrypted = decryptor.decrypt(handshake);
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into()
.unwrap();
let tag_bytes: [u8; 4] = [
decrypted[PROTO_TAG_POS],
decrypted[PROTO_TAG_POS + 1],
decrypted[PROTO_TAG_POS + 2],
decrypted[PROTO_TAG_POS + 3],
];
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag,
None => continue,
};
let mode_ok = match proto_tag {
ProtoTag::Secure => {
if is_tls {
config.general.modes.tls || config.general.modes.secure
} else {
config.general.modes.secure || config.general.modes.tls
}
}
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
};
let mode_ok = mode_enabled_for_proto(config, proto_tag, is_tls);
if !mode_ok {
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled");
continue;
}
let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
);
let dc_idx = i16::from_le_bytes([decrypted[DC_IDX_POS], decrypted[DC_IDX_POS + 1]]);
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
let mut enc_iv_arr = [0u8; IV_LEN];
enc_iv_arr.copy_from_slice(enc_iv_bytes);
let enc_iv = u128::from_be_bytes(enc_iv_arr);
let encryptor = AesCtr::new(&enc_key, enc_iv);
// Apply replay tracking only after successful authentication.
//
// This ordering prevents an attacker from producing invalid handshakes that
// still collide with a valid handshake's replay slot and thus evict a valid
// entry from the cache. We accept the cost of performing the full
// authentication check first to avoid poisoning the replay cache.
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer };
}
let success = HandshakeSuccess {
user: user.clone(),
dc_idx,
@@ -340,6 +847,8 @@ where
"MTProto handshake successful"
);
auth_probe_record_success(peer.ip());
let max_pending = config.general.crypto_pending_buffer;
return HandshakeResult::Success((
CryptoReader::new(reader, decryptor),
@@ -348,6 +857,8 @@ where
));
}
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient { reader, writer }
}
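The comments in this hunk explain the ordering: replay tracking is applied only after a handshake authenticates, so forged probes can never occupy (or evict) replay slots belonging to valid entries. A minimal check-and-add sketch of that contract, with hypothetical names (the real checker in this codebase is bounded and time-windowed, which this toy `HashSet` version is not):

```rust
use std::collections::HashSet;

/// Toy check-and-add replay window. Returns true when the fingerprint was
/// already seen, i.e. the handshake is a replay.
struct ReplayWindow {
    seen: HashSet<Vec<u8>>,
}

impl ReplayWindow {
    fn new() -> Self {
        Self { seen: HashSet::new() }
    }

    /// HashSet::insert returns false when the value was already present.
    fn check_and_add(&mut self, fingerprint: &[u8]) -> bool {
        !self.seen.insert(fingerprint.to_vec())
    }
}

fn main() {
    let mut window = ReplayWindow::new();
    // Key point from the diff: only *authenticated* handshakes reach this
    // call, so an attacker's invalid probes never poison the cache.
    assert!(!window.check_and_add(b"prekey+iv-1")); // first sight: not a replay
    assert!(window.check_and_add(b"prekey+iv-1")); // second sight: replay
    assert!(!window.check_and_add(b"prekey+iv-2")); // distinct entry: fresh
}
```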
@@ -356,8 +867,6 @@ where
pub fn generate_tg_nonce(
proto_tag: ProtoTag,
dc_idx: i16,
_client_dec_key: &[u8; 32],
_client_dec_iv: u128,
client_enc_key: &[u8; 32],
client_enc_iv: u128,
rng: &SecureRandom,
@@ -365,22 +874,30 @@ pub fn generate_tg_nonce(
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
loop {
let bytes = rng.bytes(HANDSHAKE_LEN);
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
let Ok(mut nonce): Result<[u8; HANDSHAKE_LEN], _> = bytes.try_into() else {
continue;
};
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
continue;
}
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]];
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
continue;
}
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]];
if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
continue;
}
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// CRITICAL: write dc_idx so upstream DC knows where to route
nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
if fast_mode {
let mut key_iv = Vec::with_capacity(KEY_LEN + IV_LEN);
let mut key_iv = Zeroizing::new(Vec::with_capacity(KEY_LEN + IV_LEN));
key_iv.extend_from_slice(client_enc_key);
key_iv.extend_from_slice(&client_enc_iv.to_be_bytes());
key_iv.reverse(); // Python/C behavior: reversed enc_key+enc_iv in nonce
@@ -388,13 +905,19 @@ pub fn generate_tg_nonce(
}
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::<Vec<u8>>());
let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let mut tg_enc_key = [0u8; 32];
tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
let mut tg_enc_iv_arr = [0u8; IV_LEN];
tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]);
let tg_enc_iv = u128::from_be_bytes(tg_enc_iv_arr);
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
let mut tg_dec_key = [0u8; 32];
tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]);
let mut tg_dec_iv_arr = [0u8; IV_LEN];
tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]);
let tg_dec_iv = u128::from_be_bytes(tg_dec_iv_arr);
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
}
@@ -403,13 +926,19 @@ pub fn generate_tg_nonce(
/// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state
pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, AesCtr, AesCtr) {
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::<Vec<u8>>());
let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let mut enc_key = [0u8; 32];
enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
let mut enc_iv_arr = [0u8; IV_LEN];
enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]);
let enc_iv = u128::from_be_bytes(enc_iv_arr);
let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
let mut dec_key = [0u8; 32];
dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4
@@ -418,6 +947,8 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
let decryptor = AesCtr::new(&dec_key, dec_iv);
enc_key.zeroize();
dec_key.zeroize();
(result, encryptor, decryptor)
}
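Both nonce helpers above derive the decrypt key/IV by byte-reversing the 48-byte encrypt key||IV region of the nonce, then splitting it into a 32-byte key and a big-endian `u128` CTR counter. That split and the reversal symmetry can be sketched in isolation (helper name is illustrative):

```rust
const KEY_LEN: usize = 32;
const IV_LEN: usize = 16;

/// Split a 48-byte key||iv region into (key, iv), interpreting the IV as a
/// big-endian u128 CTR counter, as the nonce helpers do.
fn split_key_iv(key_iv: &[u8]) -> ([u8; KEY_LEN], u128) {
    let mut key = [0u8; KEY_LEN];
    key.copy_from_slice(&key_iv[..KEY_LEN]);
    let mut iv = [0u8; IV_LEN];
    iv.copy_from_slice(&key_iv[KEY_LEN..KEY_LEN + IV_LEN]);
    (key, u128::from_be_bytes(iv))
}

fn main() {
    // enc region: bytes 0..48; dec region is its byte-wise reversal.
    let enc_key_iv: Vec<u8> = (0u8..48).collect();
    let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();

    let (enc_key, _enc_iv) = split_key_iv(&enc_key_iv);
    let (dec_key, _dec_iv) = split_key_iv(&dec_key_iv);

    // The decrypt key is the reversed *tail* of the enc region, so the two
    // keys start from opposite ends; reversing twice restores the original.
    assert_eq!(enc_key[0], 0);
    assert_eq!(dec_key[0], 47);
    let back: Vec<u8> = dec_key_iv.iter().rev().copied().collect();
    assert_eq!(back, enc_key_iv);
}
```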
@@ -429,80 +960,31 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
}
#[cfg(test)]
mod tests {
use super::*;
#[path = "tests/handshake_security_tests.rs"]
mod security_tests;
#[test]
fn test_generate_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128;
#[cfg(test)]
#[path = "tests/handshake_adversarial_tests.rs"]
mod adversarial_tests;
let rng = SecureRandom::new();
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
generate_tg_nonce(
ProtoTag::Secure,
2,
&client_dec_key,
client_dec_iv,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
#[cfg(test)]
#[path = "tests/handshake_fuzz_security_tests.rs"]
mod fuzz_security_tests;
assert_eq!(nonce.len(), HANDSHAKE_LEN);
#[cfg(test)]
#[path = "tests/handshake_saturation_poison_security_tests.rs"]
mod saturation_poison_security_tests;
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
}
#[test]
fn test_encrypt_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) =
generate_tg_nonce(
ProtoTag::Secure,
2,
&client_dec_key,
client_dec_iv,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
let encrypted = encrypt_tg_nonce(&nonce);
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
}
#[test]
fn test_handshake_success_zeroize_on_drop() {
let success = HandshakeSuccess {
user: "test".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Secure,
dec_key: [0xAA; 32],
dec_iv: 0xBBBBBBBB,
enc_key: [0xCC; 32],
enc_iv: 0xDDDDDDDD,
peer: "127.0.0.1:1234".parse().unwrap(),
is_tls: true,
};
assert_eq!(success.dec_key, [0xAA; 32]);
assert_eq!(success.enc_key, [0xCC; 32]);
drop(success);
// Drop impl zeroizes key material without panic
}
#[cfg(test)]
#[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"]
mod auth_probe_hardening_adversarial_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee.
mod compile_time_security_checks {
use super::HandshakeSuccess;
use static_assertions::assert_not_impl_all;
assert_not_impl_all!(HandshakeSuccess: Copy, Clone);
}


@@ -1,32 +1,231 @@
//! Masking - forward unrecognized traffic to mask host
use std::str;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use tracing::debug;
use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use rand::{Rng, RngExt};
use std::net::SocketAddr;
use std::str;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::time::{Instant, timeout};
use tracing::debug;
#[cfg(not(test))]
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const MASK_TIMEOUT: Duration = Duration::from_millis(50);
/// Maximum duration for the entire masking relay.
/// Limits resource consumption from slow-loris attacks and port scanners.
#[cfg(not(test))]
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
#[cfg(test)]
const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
#[cfg(not(test))]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
const MASK_BUFFER_SIZE: usize = 8192;
struct CopyOutcome {
total: usize,
ended_by_eof: bool,
}
async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W) -> CopyOutcome
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let mut buf = [0u8; MASK_BUFFER_SIZE];
let mut total = 0usize;
let mut ended_by_eof = false;
loop {
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await;
let n = match read_res {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
};
if n == 0 {
ended_by_eof = true;
break;
}
total = total.saturating_add(n);
let write_res = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.write_all(&buf[..n])).await;
match write_res {
Ok(Ok(())) => {}
Ok(Err(_)) | Err(_) => break,
}
}
CopyOutcome {
total,
ended_by_eof,
}
}
fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize {
if total == 0 || floor == 0 || cap < floor {
return total;
}
if total >= cap {
return total;
}
let mut bucket = floor;
while bucket < total {
match bucket.checked_mul(2) {
Some(next) => bucket = next,
None => return total,
}
if bucket > cap {
return cap;
}
}
bucket
}
async fn maybe_write_shape_padding<W>(
mask_write: &mut W,
total_sent: usize,
enabled: bool,
floor: usize,
cap: usize,
above_cap_blur: bool,
above_cap_blur_max_bytes: usize,
aggressive_mode: bool,
) where
W: AsyncWrite + Unpin,
{
if !enabled {
return;
}
let target_total = if total_sent >= cap && above_cap_blur && above_cap_blur_max_bytes > 0 {
let mut rng = rand::rng();
let extra = if aggressive_mode {
rng.random_range(1..=above_cap_blur_max_bytes)
} else {
rng.random_range(0..=above_cap_blur_max_bytes)
};
total_sent.saturating_add(extra)
} else {
next_mask_shape_bucket(total_sent, floor, cap)
};
if target_total <= total_sent {
return;
}
let mut remaining = target_total - total_sent;
let mut pad_chunk = [0u8; 1024];
let deadline = Instant::now() + MASK_TIMEOUT;
while remaining > 0 {
let now = Instant::now();
if now >= deadline {
return;
}
let write_len = remaining.min(pad_chunk.len());
{
let mut rng = rand::rng();
rng.fill_bytes(&mut pad_chunk[..write_len]);
}
let write_budget = deadline.saturating_duration_since(now);
match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await {
Ok(Ok(())) => {}
Ok(Err(_)) | Err(_) => return,
}
remaining -= write_len;
}
let now = Instant::now();
if now >= deadline {
return;
}
let flush_budget = deadline.saturating_duration_since(now);
let _ = timeout(flush_budget, mask_write.flush()).await;
}
async fn write_proxy_header_with_timeout<W>(mask_write: &mut W, header: &[u8]) -> bool
where
W: AsyncWrite + Unpin,
{
match timeout(MASK_TIMEOUT, mask_write.write_all(header)).await {
Ok(Ok(())) => true,
Ok(Err(_)) => false,
Err(_) => {
debug!("Timeout writing proxy protocol header to mask backend");
false
}
}
}
async fn consume_client_data_with_timeout<R>(reader: R)
where
R: AsyncRead + Unpin,
{
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader))
.await
.is_err()
{
debug!("Timed out while consuming client data on masking fallback path");
}
}
async fn wait_mask_connect_budget(started: Instant) {
let elapsed = started.elapsed();
if elapsed < MASK_TIMEOUT {
tokio::time::sleep(MASK_TIMEOUT - elapsed).await;
}
}
fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
if config.censorship.mask_timing_normalization_enabled {
let floor = config.censorship.mask_timing_normalization_floor_ms;
let ceiling = config.censorship.mask_timing_normalization_ceiling_ms;
if ceiling > floor {
let mut rng = rand::rng();
return Duration::from_millis(rng.random_range(floor..=ceiling));
}
return Duration::from_millis(floor);
}
MASK_TIMEOUT
}
async fn wait_mask_connect_budget_if_needed(started: Instant, config: &ProxyConfig) {
if config.censorship.mask_timing_normalization_enabled {
return;
}
wait_mask_connect_budget(started).await;
}
async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
let target = mask_outcome_target_budget(config);
let elapsed = started.elapsed();
if elapsed < target {
tokio::time::sleep(target - elapsed).await;
}
}
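`wait_mask_outcome_budget` sleeps out the remainder of a fixed (or, with timing normalization, randomized) budget so that success and failure paths take similar wall time. The remainder computation can be sketched synchronously (helper name is illustrative, not this module's API):

```rust
use std::time::Duration;

/// Remaining sleep needed so a code path's total duration reaches `target`;
/// zero if the work already overran the budget. Mirrors the
/// elapsed-vs-target check in the budget helpers above.
fn remaining_budget(elapsed: Duration, target: Duration) -> Duration {
    target.saturating_sub(elapsed)
}

fn main() {
    let target = Duration::from_millis(5000);
    // Fast failure path: pad out almost the whole budget so an observer
    // cannot time-fingerprint the early reject.
    assert_eq!(
        remaining_budget(Duration::from_millis(12), target),
        Duration::from_millis(4988)
    );
    // Slow path that overran the budget: no extra sleep.
    assert_eq!(
        remaining_budget(Duration::from_millis(6000), target),
        Duration::ZERO
    );
}
```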
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
if data.len() > 4
&& (data.starts_with(b"GET ") || data.starts_with(b"POST") ||
data.starts_with(b"HEAD") || data.starts_with(b"PUT ") ||
data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS"))
&& (data.starts_with(b"GET ")
|| data.starts_with(b"POST")
|| data.starts_with(b"HEAD")
|| data.starts_with(b"PUT ")
|| data.starts_with(b"DELETE")
|| data.starts_with(b"OPTIONS"))
{
return "HTTP";
}
@@ -49,6 +248,33 @@ fn detect_client_type(data: &[u8]) -> &'static str {
"unknown"
}
fn build_mask_proxy_header(
version: u8,
peer: SocketAddr,
local_addr: SocketAddr,
) -> Option<Vec<u8>> {
match version {
0 => None,
2 => Some(
ProxyProtocolV2Builder::new()
.with_addrs(peer, local_addr)
.build(),
),
_ => {
let header = match (peer, local_addr) {
(SocketAddr::V4(src), SocketAddr::V4(dst)) => ProxyProtocolV1Builder::new()
.tcp4(src.into(), dst.into())
.build(),
(SocketAddr::V6(src), SocketAddr::V6(dst)) => ProxyProtocolV1Builder::new()
.tcp6(src.into(), dst.into())
.build(),
_ => ProxyProtocolV1Builder::new().build(),
};
Some(header)
}
}
}
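`build_mask_proxy_header` selects between PROXY protocol v2 (binary) and v1 (text) preambles via the codebase's builders. For orientation, the v1 wire format from the HAProxy proxy-protocol spec is a single CRLF-terminated ASCII line; a hand-rolled sketch of it, not this codebase's builder:

```rust
use std::net::{IpAddr, SocketAddr};

/// Format a PROXY protocol v1 header line for a TCP connection
/// (HAProxy proxy-protocol spec, human-readable variant).
fn proxy_v1_header(src: SocketAddr, dst: SocketAddr) -> String {
    let fam = match (src.ip(), dst.ip()) {
        (IpAddr::V4(_), IpAddr::V4(_)) => "TCP4",
        (IpAddr::V6(_), IpAddr::V6(_)) => "TCP6",
        // Mixed address families are not representable; the spec's escape
        // hatch tells the backend to ignore the addresses.
        _ => return "PROXY UNKNOWN\r\n".to_string(),
    };
    format!(
        "PROXY {} {} {} {} {}\r\n",
        fam, src.ip(), dst.ip(), src.port(), dst.port()
    )
}

fn main() {
    let src: SocketAddr = "203.0.113.7:51234".parse().unwrap();
    let dst: SocketAddr = "192.0.2.1:443".parse().unwrap();
    assert_eq!(
        proxy_v1_header(src, dst),
        "PROXY TCP4 203.0.113.7 192.0.2.1 51234 443\r\n"
    );
}
```

The fallback arm mirrors the diff's behavior for mixed-family pairs, where the v1 builder emits a header without usable addresses.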
/// Handle a bad client by forwarding to mask host
pub async fn handle_bad_client<R, W>(
reader: R,
@@ -58,8 +284,7 @@ pub async fn handle_bad_client<R, W>(
local_addr: SocketAddr,
config: &ProxyConfig,
beobachten: &BeobachtenStore,
)
where
) where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
@@ -71,13 +296,15 @@ where
if !config.censorship.mask {
// Masking disabled, just consume data
consume_client_data(reader).await;
consume_client_data_with_timeout(reader).await;
return;
}
// Connect via Unix socket or TCP
#[cfg(unix)]
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
let outcome_started = Instant::now();
let connect_started = Instant::now();
debug!(
client_type = client_type,
sock = %sock_path,
@@ -89,45 +316,59 @@ where
match connect_result {
Ok(Ok(stream)) => {
let (mask_read, mut mask_write) = stream.into_split();
let proxy_header: Option<Vec<u8>> = match config.censorship.mask_proxy_protocol {
0 => None,
version => {
let header = match version {
2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(),
_ => match (peer, local_addr) {
(SocketAddr::V4(src), SocketAddr::V4(dst)) =>
ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(),
(SocketAddr::V6(src), SocketAddr::V6(dst)) =>
ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(),
_ =>
ProxyProtocolV1Builder::new().build(),
},
};
Some(header)
}
};
if let Some(header) = proxy_header {
if mask_write.write_all(&header).await.is_err() {
let proxy_header = build_mask_proxy_header(
config.censorship.mask_proxy_protocol,
peer,
local_addr,
);
if let Some(header) = proxy_header
&& !write_proxy_header_with_timeout(&mut mask_write, &header).await
{
wait_mask_outcome_budget(outcome_started, config).await;
return;
}
}
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
if timeout(
MASK_RELAY_TIMEOUT,
relay_to_mask(
reader,
writer,
mask_read,
mask_write,
initial_data,
config.censorship.mask_shape_hardening,
config.censorship.mask_shape_bucket_floor_bytes,
config.censorship.mask_shape_bucket_cap_bytes,
config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
),
)
.await
.is_err()
{
debug!("Mask relay timed out (unix socket)");
}
wait_mask_outcome_budget(outcome_started, config).await;
}
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data(reader).await;
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data(reader).await;
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
return;
}
let mask_host = config.censorship.mask_host.as_deref()
let mask_host = config
.censorship
.mask_host
.as_deref()
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
@@ -143,44 +384,54 @@ where
let mask_addr = resolve_socket_addr(mask_host, mask_port)
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let outcome_started = Instant::now();
let connect_started = Instant::now();
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result {
Ok(Ok(stream)) => {
let proxy_header: Option<Vec<u8>> = match config.censorship.mask_proxy_protocol {
0 => None,
version => {
let header = match version {
2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(),
_ => match (peer, local_addr) {
(SocketAddr::V4(src), SocketAddr::V4(dst)) =>
ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(),
(SocketAddr::V6(src), SocketAddr::V6(dst)) =>
ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(),
_ =>
ProxyProtocolV1Builder::new().build(),
},
};
Some(header)
}
};
let proxy_header =
build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr);
let (mask_read, mut mask_write) = stream.into_split();
if let Some(header) = proxy_header {
if mask_write.write_all(&header).await.is_err() {
if let Some(header) = proxy_header
&& !write_proxy_header_with_timeout(&mut mask_write, &header).await
{
wait_mask_outcome_budget(outcome_started, config).await;
return;
}
}
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
if timeout(
MASK_RELAY_TIMEOUT,
relay_to_mask(
reader,
writer,
mask_read,
mask_write,
initial_data,
config.censorship.mask_shape_hardening,
config.censorship.mask_shape_bucket_floor_bytes,
config.censorship.mask_shape_bucket_cap_bytes,
config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
),
)
.await
.is_err()
{
debug!("Mask relay timed out");
}
wait_mask_outcome_budget(outcome_started, config).await;
}
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask host");
consume_client_data(reader).await;
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data(reader).await;
consume_client_data_with_timeout(reader).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
}
@@ -192,8 +443,13 @@ async fn relay_to_mask<R, W, MR, MW>(
mut mask_read: MR,
mut mask_write: MW,
initial_data: &[u8],
)
where
shape_hardening_enabled: bool,
shape_bucket_floor_bytes: usize,
shape_bucket_cap_bytes: usize,
shape_above_cap_blur: bool,
shape_above_cap_blur_max_bytes: usize,
shape_hardening_aggressive_mode: bool,
) where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
MR: AsyncRead + Unpin + Send + 'static,
@@ -203,47 +459,36 @@ where
if mask_write.write_all(initial_data).await.is_err() {
return;
}
if mask_write.flush().await.is_err() {
return;
}
let (upstream_copy, downstream_copy) = tokio::join!(
async { copy_with_idle_timeout(&mut reader, &mut mask_write).await },
async { copy_with_idle_timeout(&mut mask_read, &mut writer).await }
);
let total_sent = initial_data.len().saturating_add(upstream_copy.total);
let should_shape = shape_hardening_enabled
&& !initial_data.is_empty()
&& (upstream_copy.ended_by_eof
|| (shape_hardening_aggressive_mode && downstream_copy.total == 0));
maybe_write_shape_padding(
&mut mask_write,
total_sent,
should_shape,
shape_bucket_floor_bytes,
shape_bucket_cap_bytes,
shape_above_cap_blur,
shape_above_cap_blur_max_bytes,
shape_hardening_aggressive_mode,
)
.await;
// Relay traffic
let c2m = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop {
match reader.read(&mut buf).await {
Ok(0) | Err(_) => {
let _ = mask_write.shutdown().await;
break;
}
Ok(n) => {
if mask_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
}
}
});
let m2c = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop {
match mask_read.read(&mut buf).await {
Ok(0) | Err(_) => {
let _ = writer.shutdown().await;
break;
}
Ok(n) => {
if writer.write_all(&buf[..n]).await.is_err() {
break;
}
}
}
}
});
// Wait for either to complete
tokio::select! {
_ = c2m => {}
_ = m2c => {}
}
}
/// Just consume all data from client without responding
@@ -255,3 +500,51 @@ async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
}
}
}
#[cfg(test)]
#[path = "tests/masking_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/masking_adversarial_tests.rs"]
mod adversarial_tests;
#[cfg(test)]
#[path = "tests/masking_shape_hardening_adversarial_tests.rs"]
mod masking_shape_hardening_adversarial_tests;
#[cfg(test)]
#[path = "tests/masking_shape_above_cap_blur_security_tests.rs"]
mod masking_shape_above_cap_blur_security_tests;
#[cfg(test)]
#[path = "tests/masking_timing_normalization_security_tests.rs"]
mod masking_timing_normalization_security_tests;
#[cfg(test)]
#[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"]
mod masking_ab_envelope_blur_integration_security_tests;
#[cfg(test)]
#[path = "tests/masking_shape_guard_security_tests.rs"]
mod masking_shape_guard_security_tests;
#[cfg(test)]
#[path = "tests/masking_shape_guard_adversarial_tests.rs"]
mod masking_shape_guard_adversarial_tests;
#[cfg(test)]
#[path = "tests/masking_shape_classifier_resistance_adversarial_tests.rs"]
mod masking_shape_classifier_resistance_adversarial_tests;
#[cfg(test)]
#[path = "tests/masking_shape_bypass_blackhat_tests.rs"]
mod masking_shape_bypass_blackhat_tests;
#[cfg(test)]
#[path = "tests/masking_aggressive_mode_security_tests.rs"]
mod masking_aggressive_mode_security_tests;
#[cfg(test)]
#[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"]
mod masking_timing_sidechannel_redteam_expected_fail_tests;

File diff suppressed because it is too large


@@ -1,13 +1,71 @@
//! Proxy Defs
// Apply strict linting to proxy production code while keeping test builds noise-tolerant.
#![cfg_attr(test, allow(warnings))]
#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))]
#![cfg_attr(
not(test),
deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::todo,
clippy::unimplemented,
clippy::correctness,
clippy::option_if_let_else,
clippy::or_fun_call,
clippy::branches_sharing_code,
clippy::single_option_map,
clippy::useless_let_if_seq,
clippy::redundant_locals,
clippy::cloned_ref_to_slice_refs,
unsafe_code,
clippy::await_holding_lock,
clippy::await_holding_refcell_ref,
clippy::debug_assert_with_mut_call,
clippy::macro_use_imports,
clippy::cast_ptr_alignment,
clippy::cast_lossless,
clippy::ptr_as_ptr,
clippy::large_stack_arrays,
clippy::same_functions_in_if_condition,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
rust_2018_idioms
)
)]
#![cfg_attr(
not(test),
allow(
clippy::use_self,
clippy::redundant_closure,
clippy::too_many_arguments,
clippy::doc_markdown,
clippy::missing_const_for_fn,
clippy::unnecessary_operation,
clippy::redundant_pub_crate,
clippy::derive_partial_eq_without_eq,
clippy::type_complexity,
clippy::new_ret_no_self,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::significant_drop_tightening,
clippy::significant_drop_in_scrutinee,
clippy::float_cmp,
clippy::nursery
)
)]
pub mod adaptive_buffers;
pub mod client;
pub mod direct_relay;
pub mod handshake;
pub mod masking;
pub mod middle_relay;
pub mod route_mode;
pub mod relay;
pub mod route_mode;
pub mod session_eviction;
pub use client::ClientHandler;


@@ -51,24 +51,19 @@
//! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to`
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{
AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes,
};
use tokio::time::Instant;
use tracing::{debug, trace, warn};
use crate::error::Result;
use crate::proxy::adaptive_buffers::{
self, AdaptiveTier, RelaySignalSample, SessionAdaptiveController, TierTransitionReason,
};
use crate::proxy::session_eviction::SessionLease;
use crate::error::{ProxyError, Result};
use crate::stats::Stats;
use crate::stream::BufferPool;
use dashmap::DashMap;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
use tokio::time::Instant;
use tracing::{debug, trace, warn};
// ============= Constants =============
@ -83,7 +78,11 @@ const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800);
/// 10 seconds gives responsive timeout detection (±10s accuracy)
/// without measurable overhead from atomic reads.
const WATCHDOG_INTERVAL: Duration = Duration::from_secs(10);
const ADAPTIVE_TICK: Duration = Duration::from_millis(250);
/// Saturating delta between two cumulative counter samples; clamps to 0
/// if a counter ever appears to move backwards.
#[inline]
fn watchdog_delta(current: u64, previous: u64) -> u64 {
current.saturating_sub(previous)
}
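A small standalone illustration (hypothetical values, not taken from the relay) of why the delta helper saturates instead of subtracting directly:

```rust
// Saturating delta: if a cumulative counter is reset or a stale snapshot
// is read, a plain subtraction would underflow and wrap to a huge u64;
// saturating_sub clamps the delta to zero instead.
fn watchdog_delta(current: u64, previous: u64) -> u64 {
    current.saturating_sub(previous)
}

fn main() {
    // Normal forward progress: 500 bytes moved this interval.
    assert_eq!(watchdog_delta(1500, 1000), 500);
    // Counter went backwards between samples: clamp to 0 instead of
    // wrapping, which would otherwise log an absurd transfer rate.
    assert_eq!(watchdog_delta(500, 1000), 0);
}
```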
// ============= CombinedStream =============
@ -160,16 +159,6 @@ struct SharedCounters {
s2c_ops: AtomicU64,
/// Milliseconds since relay epoch of last I/O activity
last_activity_ms: AtomicU64,
/// Bytes requested to write to client (S→C direction).
s2c_requested_bytes: AtomicU64,
/// Total write operations for S→C direction.
s2c_write_ops: AtomicU64,
/// Number of partial writes to client.
s2c_partial_writes: AtomicU64,
/// Number of times S→C poll_write returned Pending.
s2c_pending_writes: AtomicU64,
/// Consecutive pending writes in S→C direction.
s2c_consecutive_pending_writes: AtomicU64,
}
impl SharedCounters {
@ -180,11 +169,6 @@ impl SharedCounters {
c2s_ops: AtomicU64::new(0),
s2c_ops: AtomicU64::new(0),
last_activity_ms: AtomicU64::new(0),
s2c_requested_bytes: AtomicU64::new(0),
s2c_write_ops: AtomicU64::new(0),
s2c_partial_writes: AtomicU64::new(0),
s2c_pending_writes: AtomicU64::new(0),
s2c_consecutive_pending_writes: AtomicU64::new(0),
}
}
@ -225,6 +209,12 @@ struct StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
quota_read_wake_scheduled: bool,
quota_write_wake_scheduled: bool,
quota_read_retry_active: Arc<AtomicBool>,
quota_write_retry_active: Arc<AtomicBool>,
epoch: Instant,
}
@ -234,11 +224,136 @@ impl<S> StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
epoch: Instant,
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
Self { inner, counters, stats, user, epoch }
Self {
inner,
counters,
stats,
user,
quota_limit,
quota_exceeded,
quota_read_wake_scheduled: false,
quota_write_wake_scheduled: false,
quota_read_retry_active: Arc::new(AtomicBool::new(false)),
quota_write_retry_active: Arc::new(AtomicBool::new(false)),
epoch,
}
}
}
impl<S> Drop for StatsIo<S> {
fn drop(&mut self) {
self.quota_read_retry_active.store(false, Ordering::Relaxed);
self.quota_write_retry_active
.store(false, Ordering::Relaxed);
}
}
#[derive(Debug)]
struct QuotaIoSentinel;
impl std::fmt::Display for QuotaIoSentinel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("user data quota exceeded")
}
}
impl std::error::Error for QuotaIoSentinel {}
fn quota_io_error() -> io::Error {
io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel)
}
fn is_quota_io_error(err: &io::Error) -> bool {
err.kind() == io::ErrorKind::PermissionDenied
&& err
.get_ref()
.and_then(|source| source.downcast_ref::<QuotaIoSentinel>())
.is_some()
}
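The sentinel pattern above can be exercised in isolation; the point is that `ErrorKind` alone is not enough to identify a quota shutdown, only a successful downcast of the error's source is. A minimal self-contained sketch:

```rust
use std::error::Error;
use std::fmt;
use std::io;

// Marker type carried as the io::Error source so quota shutdowns can be
// told apart from ordinary PermissionDenied errors via downcast.
#[derive(Debug)]
struct QuotaIoSentinel;

impl fmt::Display for QuotaIoSentinel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("user data quota exceeded")
    }
}

impl Error for QuotaIoSentinel {}

fn quota_io_error() -> io::Error {
    io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel)
}

fn is_quota_io_error(err: &io::Error) -> bool {
    err.kind() == io::ErrorKind::PermissionDenied
        && err
            .get_ref()
            .and_then(|src| src.downcast_ref::<QuotaIoSentinel>())
            .is_some()
}

fn main() {
    let quota = quota_io_error();
    let plain = io::Error::new(io::ErrorKind::PermissionDenied, "denied");
    assert!(is_quota_io_error(&quota));
    // Same kind, but no sentinel source: not a quota error.
    assert!(!is_quota_io_error(&plain));
}
```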
#[cfg(test)]
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
#[cfg(not(test))]
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
fn spawn_quota_retry_waker(retry_active: Arc<AtomicBool>, waker: std::task::Waker) {
tokio::task::spawn(async move {
loop {
if !retry_active.load(Ordering::Relaxed) {
break;
}
tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await;
if !retry_active.load(Ordering::Relaxed) {
break;
}
waker.wake_by_ref();
}
});
}
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
quota_user_lock_test_guard()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
.map(|_| Arc::new(Mutex::new(())))
.collect()
});
let hash = crc32fast::hash(user.as_bytes()) as usize;
Arc::clone(&stripes[hash % stripes.len()])
}
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
return quota_overflow_user_lock(user);
}
let created = Arc::new(Mutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
}
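When the per-user lock map is full, the overflow path hashes the user name onto a fixed set of lock stripes, so memory stays bounded under unbounded distinct user names at the cost of occasional contention between users sharing a stripe. A std-only sketch of the striping idea (using `DefaultHasher` in place of `crc32fast`, which is an assumption for illustration):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex, OnceLock};

const STRIPES: usize = 16;

static OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();

// Map a user name onto one of STRIPES shared locks. Two distinct users
// may land on the same stripe; that only costs contention, never
// correctness, because each stripe still serializes its holders.
fn overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
    let stripes = OVERFLOW_LOCKS
        .get_or_init(|| (0..STRIPES).map(|_| Arc::new(Mutex::new(()))).collect());
    let mut hasher = DefaultHasher::new();
    user.hash(&mut hasher);
    Arc::clone(&stripes[(hasher.finish() as usize) % STRIPES])
}

fn main() {
    // The mapping is deterministic: the same user always reaches the
    // same stripe, so its accesses stay mutually exclusive.
    let a1 = overflow_user_lock("alice");
    let a2 = overflow_user_lock("alice");
    assert!(Arc::ptr_eq(&a1, &a2));
}
```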
@ -249,20 +364,82 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_read_wake_scheduled = false;
this.quota_read_retry_active.store(false, Ordering::Relaxed);
Some(guard)
}
Err(_) => {
if !this.quota_read_wake_scheduled {
this.quota_read_wake_scheduled = true;
this.quota_read_retry_active.store(true, Ordering::Relaxed);
spawn_quota_retry_waker(
Arc::clone(&this.quota_read_retry_active),
cx.waker().clone(),
);
}
return Poll::Pending;
}
}
} else {
None
};
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let n = buf.filled().len() - before;
if n > 0 {
let mut reached_quota_boundary = false;
if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let remaining = limit - used;
if (n as u64) > remaining {
// Fail closed: when a single read chunk would cross quota,
// stop relay immediately without accounting beyond the cap.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
reached_quota_boundary = (n as u64) == remaining;
}
// C→S: client sent data
this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed);
this.counters
.c2s_bytes
.fetch_add(n as u64, Ordering::Relaxed);
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_from(&this.user, n as u64);
this.stats.increment_user_msgs_from(&this.user);
if reached_quota_boundary {
this.quota_exceeded.store(true, Ordering::Relaxed);
}
trace!(user = %this.user, bytes = n, "C->S");
}
Poll::Ready(Ok(()))
@ -279,43 +456,81 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.counters
.s2c_requested_bytes
.fetch_add(buf.len() as u64, Ordering::Relaxed);
match Pin::new(&mut this.inner).poll_write(cx, buf) {
Poll::Ready(Ok(n)) => {
this.counters.s2c_write_ops.fetch_add(1, Ordering::Relaxed);
this.counters
.s2c_consecutive_pending_writes
.store(0, Ordering::Relaxed);
if n < buf.len() {
this.counters
.s2c_partial_writes
.fetch_add(1, Ordering::Relaxed);
if this.quota_exceeded.load(Ordering::Relaxed) {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_write_wake_scheduled = false;
this.quota_write_retry_active
.store(false, Ordering::Relaxed);
Some(guard)
}
Err(_) => {
if !this.quota_write_wake_scheduled {
this.quota_write_wake_scheduled = true;
this.quota_write_retry_active.store(true, Ordering::Relaxed);
spawn_quota_retry_waker(
Arc::clone(&this.quota_write_retry_active),
cx.waker().clone(),
);
}
return Poll::Pending;
}
}
} else {
None
};
let write_buf = if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let remaining = (limit - used) as usize;
if buf.len() > remaining {
// Fail closed: do not emit partial S->C payload when remaining
// quota cannot accommodate the pending write request.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
buf
} else {
buf
};
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
// S→C: data written to client
this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed);
this.counters
.s2c_bytes
.fetch_add(n as u64, Ordering::Relaxed);
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_to(&this.user, n as u64);
this.stats.increment_user_msgs_to(&this.user);
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
trace!(user = %this.user, bytes = n, "S->C");
}
Poll::Ready(Ok(n))
}
Poll::Pending => {
this.counters
.s2c_pending_writes
.fetch_add(1, Ordering::Relaxed);
this.counters
.s2c_consecutive_pending_writes
.fetch_add(1, Ordering::Relaxed);
Poll::Pending
}
other => other,
}
}
@ -348,7 +563,8 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
/// - Per-user stats: bytes and ops counted per direction
/// - Periodic rate logging: every 10 seconds when active
/// - Clean shutdown: both write sides are shut down on exit
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
/// - Error propagation: quota exits return `ProxyError::DataQuotaExceeded`,
/// other I/O failures are returned as `ProxyError::Io`
pub async fn relay_bidirectional<CR, CW, SR, SW>(
client_reader: CR,
client_writer: CW,
@ -357,11 +573,9 @@ pub async fn relay_bidirectional<CR, CW, SR, SW>(
c2s_buf_size: usize,
s2c_buf_size: usize,
user: &str,
dc_idx: i16,
stats: Arc<Stats>,
quota_limit: Option<u64>,
_buffer_pool: Arc<BufferPool>,
session_lease: SessionLease,
seed_tier: AdaptiveTier,
) -> Result<()>
where
CR: AsyncRead + Unpin + Send + 'static,
@ -371,6 +585,7 @@ where
{
let epoch = Instant::now();
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let user_owned = user.to_string();
// ── Combine split halves into bidirectional streams ──────────────
@ -383,43 +598,31 @@ where
Arc::clone(&counters),
Arc::clone(&stats),
user_owned.clone(),
quota_limit,
Arc::clone(&quota_exceeded),
epoch,
);
// ── Watchdog: activity timeout + periodic rate logging ──────────
let wd_counters = Arc::clone(&counters);
let wd_user = user_owned.clone();
let wd_dc = dc_idx;
let wd_stats = Arc::clone(&stats);
let wd_session = session_lease.clone();
let wd_quota_exceeded = Arc::clone(&quota_exceeded);
let watchdog = async {
let mut prev_c2s_log: u64 = 0;
let mut prev_s2c_log: u64 = 0;
let mut prev_c2s_sample: u64 = 0;
let mut prev_s2c_requested_sample: u64 = 0;
let mut prev_s2c_written_sample: u64 = 0;
let mut prev_s2c_write_ops_sample: u64 = 0;
let mut prev_s2c_partial_sample: u64 = 0;
let mut accumulated_log = Duration::ZERO;
let mut adaptive = SessionAdaptiveController::new(seed_tier);
let mut prev_c2s: u64 = 0;
let mut prev_s2c: u64 = 0;
loop {
tokio::time::sleep(ADAPTIVE_TICK).await;
if wd_session.is_stale() {
wd_stats.increment_reconnect_stale_close_total();
warn!(
user = %wd_user,
dc = wd_dc,
"Session evicted by reconnect"
);
return;
}
tokio::time::sleep(WATCHDOG_INTERVAL).await;
let now = Instant::now();
let idle = wd_counters.idle_duration(now, epoch);
if wd_quota_exceeded.load(Ordering::Relaxed) {
warn!(user = %wd_user, "User data quota reached, closing relay");
return;
}
// ── Activity timeout ────────────────────────────────────
if idle >= ACTIVITY_TIMEOUT {
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
@ -434,80 +637,11 @@ where
return; // Causes select! to cancel copy_bidirectional
}
let c2s_total = wd_counters.c2s_bytes.load(Ordering::Relaxed);
let s2c_requested_total = wd_counters
.s2c_requested_bytes
.load(Ordering::Relaxed);
let s2c_written_total = wd_counters.s2c_bytes.load(Ordering::Relaxed);
let s2c_write_ops_total = wd_counters
.s2c_write_ops
.load(Ordering::Relaxed);
let s2c_partial_total = wd_counters
.s2c_partial_writes
.load(Ordering::Relaxed);
let consecutive_pending = wd_counters
.s2c_consecutive_pending_writes
.load(Ordering::Relaxed) as u32;
let sample = RelaySignalSample {
c2s_bytes: c2s_total.saturating_sub(prev_c2s_sample),
s2c_requested_bytes: s2c_requested_total
.saturating_sub(prev_s2c_requested_sample),
s2c_written_bytes: s2c_written_total
.saturating_sub(prev_s2c_written_sample),
s2c_write_ops: s2c_write_ops_total
.saturating_sub(prev_s2c_write_ops_sample),
s2c_partial_writes: s2c_partial_total
.saturating_sub(prev_s2c_partial_sample),
s2c_consecutive_pending_writes: consecutive_pending,
};
if let Some(transition) = adaptive.observe(sample, ADAPTIVE_TICK.as_secs_f64()) {
match transition.reason {
TierTransitionReason::SoftConfirmed => {
wd_stats.increment_relay_adaptive_promotions_total();
}
TierTransitionReason::HardPressure => {
wd_stats.increment_relay_adaptive_promotions_total();
wd_stats.increment_relay_adaptive_hard_promotions_total();
}
TierTransitionReason::QuietDemotion => {
wd_stats.increment_relay_adaptive_demotions_total();
}
}
adaptive_buffers::record_user_tier(&wd_user, adaptive.max_tier_seen());
debug!(
user = %wd_user,
dc = wd_dc,
from_tier = transition.from.as_u8(),
to_tier = transition.to.as_u8(),
reason = ?transition.reason,
throughput_ema_bps = sample
.c2s_bytes
.max(sample.s2c_written_bytes)
.saturating_mul(8)
.saturating_mul(4),
"Adaptive relay tier transition"
);
}
prev_c2s_sample = c2s_total;
prev_s2c_requested_sample = s2c_requested_total;
prev_s2c_written_sample = s2c_written_total;
prev_s2c_write_ops_sample = s2c_write_ops_total;
prev_s2c_partial_sample = s2c_partial_total;
accumulated_log = accumulated_log.saturating_add(ADAPTIVE_TICK);
if accumulated_log < WATCHDOG_INTERVAL {
continue;
}
accumulated_log = Duration::ZERO;
// ── Periodic rate logging ───────────────────────────────
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
let c2s_delta = c2s.saturating_sub(prev_c2s_log);
let s2c_delta = s2c.saturating_sub(prev_s2c_log);
let c2s_delta = watchdog_delta(c2s, prev_c2s);
let s2c_delta = watchdog_delta(s2c, prev_s2c);
if c2s_delta > 0 || s2c_delta > 0 {
let secs = WATCHDOG_INTERVAL.as_secs_f64();
@ -521,8 +655,8 @@ where
);
}
prev_c2s_log = c2s;
prev_s2c_log = s2c;
prev_c2s = c2s;
prev_s2c = s2c;
}
};
@ -557,7 +691,6 @@ where
let c2s_ops = counters.c2s_ops.load(Ordering::Relaxed);
let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed);
let duration = epoch.elapsed();
adaptive_buffers::record_user_tier(&user_owned, seed_tier);
match copy_result {
Some(Ok((c2s, s2c))) => {
@ -573,6 +706,22 @@ where
);
Ok(())
}
Some(Err(e)) if is_quota_io_error(&e) => {
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
warn!(
user = %user_owned,
c2s_bytes = c2s,
s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
"Data quota reached, closing relay"
);
Err(ProxyError::DataQuotaExceeded {
user: user_owned.clone(),
})
}
Some(Err(e)) => {
// I/O error in one of the directions
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
@ -606,3 +755,39 @@ where
}
}
}
#[cfg(test)]
#[path = "tests/relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/relay_adversarial_tests.rs"]
mod adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"]
mod relay_quota_lock_pressure_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_boundary_blackhat_tests.rs"]
mod relay_quota_boundary_blackhat_tests;
#[cfg(test)]
#[path = "tests/relay_quota_model_adversarial_tests.rs"]
mod relay_quota_model_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_overflow_regression_tests.rs"]
mod relay_quota_overflow_regression_tests;
#[cfg(test)]
#[path = "tests/relay_watchdog_delta_security_tests.rs"]
mod relay_watchdog_delta_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"]
mod relay_quota_waker_storm_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
mod relay_quota_wake_liveness_regression_tests;

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::watch;
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover";
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated";
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
@ -14,17 +14,6 @@ pub(crate) enum RelayRouteMode {
}
impl RelayRouteMode {
pub(crate) fn as_u8(self) -> u8 {
self as u8
}
pub(crate) fn from_u8(value: u8) -> Self {
match value {
1 => Self::Middle,
_ => Self::Direct,
}
}
pub(crate) fn as_str(self) -> &'static str {
match self {
Self::Direct => "direct",
@ -41,8 +30,6 @@ pub(crate) struct RouteCutoverState {
#[derive(Clone)]
pub(crate) struct RouteRuntimeController {
mode: Arc<AtomicU8>,
generation: Arc<AtomicU64>,
direct_since_epoch_secs: Arc<AtomicU64>,
tx: watch::Sender<RouteCutoverState>,
}
@ -60,18 +47,13 @@ impl RouteRuntimeController {
0
};
Self {
mode: Arc::new(AtomicU8::new(initial_mode.as_u8())),
generation: Arc::new(AtomicU64::new(0)),
direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)),
tx,
}
}
pub(crate) fn snapshot(&self) -> RouteCutoverState {
RouteCutoverState {
mode: RelayRouteMode::from_u8(self.mode.load(Ordering::Relaxed)),
generation: self.generation.load(Ordering::Relaxed),
}
*self.tx.borrow()
}
pub(crate) fn subscribe(&self) -> watch::Receiver<RouteCutoverState> {
@ -84,9 +66,10 @@ impl RouteRuntimeController {
}
pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option<RouteCutoverState> {
let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed);
if previous == mode.as_u8() {
return None;
let mut next = None;
let changed = self.tx.send_if_modified(|state| {
if state.mode == mode {
return false;
}
if matches!(mode, RelayRouteMode::Direct) {
self.direct_since_epoch_secs
@ -94,10 +77,17 @@ impl RouteRuntimeController {
} else {
self.direct_since_epoch_secs.store(0, Ordering::Relaxed);
}
let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1;
let next = RouteCutoverState { mode, generation };
self.tx.send_replace(next);
Some(next)
state.mode = mode;
state.generation = state.generation.saturating_add(1);
next = Some(*state);
true
});
if !changed {
return None;
}
next
}
}
@ -110,10 +100,10 @@ fn now_epoch_secs() -> u64 {
pub(crate) fn is_session_affected_by_cutover(
current: RouteCutoverState,
_session_mode: RelayRouteMode,
session_mode: RelayRouteMode,
session_generation: u64,
) -> bool {
current.generation > session_generation
current.generation > session_generation && current.mode != session_mode
}
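The corrected predicate means a session is torn down only when a newer generation actually changed its mode; a generation bump that lands back on the session's own mode leaves it alone. A standalone sketch with local stand-in types:

```rust
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Mode {
    Direct,
    Middle,
}

#[derive(Clone, Copy)]
struct CutoverState {
    mode: Mode,
    generation: u64,
}

// A session is affected only if a *newer* cutover generation moved the
// route to a *different* mode than the one the session is running.
fn affected(current: CutoverState, session_mode: Mode, session_gen: u64) -> bool {
    current.generation > session_gen && current.mode != session_mode
}

fn main() {
    let cur = CutoverState { mode: Mode::Middle, generation: 3 };
    assert!(affected(cur, Mode::Direct, 2)); // newer gen, mode changed
    assert!(!affected(cur, Mode::Middle, 2)); // newer gen, same mode
    assert!(!affected(cur, Mode::Direct, 3)); // generation not newer
}
```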
pub(crate) fn affected_cutover_state(
@ -129,9 +119,7 @@ pub(crate) fn affected_cutover_state(
}
pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duration {
let mut value = session_id
^ generation.rotate_left(17)
^ 0x9e37_79b9_7f4a_7c15;
let mut value = session_id ^ generation.rotate_left(17) ^ 0x9e37_79b9_7f4a_7c15;
value ^= value >> 30;
value = value.wrapping_mul(0xbf58_476d_1ce4_e5b9);
value ^= value >> 27;
@ -140,3 +128,11 @@ pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duratio
let ms = 1000 + (value % 1000);
Duration::from_millis(ms)
}
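The stagger delay mixes session id and generation through a splitmix64-style finalizer and clamps the result into a one-second window, so sessions affected by the same cutover reconnect spread over [1000, 2000) ms instead of thundering at once. A self-contained sketch of the same shape (the middle mixing steps are partially elided by the diff hunk, so this fills them with the standard splitmix64 finalizer constants as an assumption):

```rust
use std::time::Duration;

// splitmix64-style finalizer: a cheap, stateless avalanche of the
// (session_id, generation) pair into a well-distributed u64.
fn stagger_delay(session_id: u64, generation: u64) -> Duration {
    let mut value = session_id ^ generation.rotate_left(17) ^ 0x9e37_79b9_7f4a_7c15;
    value ^= value >> 30;
    value = value.wrapping_mul(0xbf58_476d_1ce4_e5b9);
    value ^= value >> 27;
    // Assumed final mixing steps (standard splitmix64 constants).
    value = value.wrapping_mul(0x94d0_49bb_1331_11eb);
    value ^= value >> 31;
    // Clamp into [1000, 2000) ms: never immediate, never unbounded.
    Duration::from_millis(1000 + (value % 1000))
}

fn main() {
    for sid in 0..64u64 {
        let d = stagger_delay(sid, 7);
        assert!(d >= Duration::from_millis(1000));
        assert!(d < Duration::from_millis(2000));
    }
    // Deterministic per (session, generation) pair.
    assert_eq!(stagger_delay(42, 7), stagger_delay(42, 7));
}
```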
#[cfg(test)]
#[path = "tests/route_mode_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/route_mode_coherence_adversarial_tests.rs"]
mod coherence_adversarial_tests;

View File

@ -1,3 +1,5 @@
#![allow(dead_code)]
/// Session eviction is intentionally disabled in runtime.
///
/// The initial `user+dc` single-lease model caused valid parallel client

View File

@ -0,0 +1,714 @@
use super::*;
use crate::config::ProxyConfig;
use crate::error::ProxyError;
use crate::ip_tracker::UserIpTracker;
use crate::stats::Stats;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
// ------------------------------------------------------------------
// Priority 3: Massive Concurrency Stress (OWASP ASVS 5.1.6)
// ------------------------------------------------------------------
#[tokio::test]
async fn client_stress_10k_connections_limit_strict() {
let user = "stress-user";
let limit = 512;
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
let iterations = 1000;
let mut tasks = Vec::new();
for i in 0..iterations {
let stats = Arc::clone(&stats);
let ip_tracker = Arc::clone(&ip_tracker);
let config = config.clone();
let user_str = user.to_string();
tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, (i % 254 + 1) as u8)),
10000 + (i % 1000) as u16,
);
match RunningClientHandler::acquire_user_connection_reservation_static(
&user_str, &config, stats, peer, ip_tracker,
)
.await
{
Ok(res) => Ok(res),
Err(ProxyError::ConnectionLimitExceeded { .. }) => Err(()),
Err(e) => panic!("Unexpected error: {:?}", e),
}
}));
}
let results = futures::future::join_all(tasks).await;
let mut successes = 0;
let mut failures = 0;
let mut reservations = Vec::new();
for res in results {
match res.unwrap() {
Ok(r) => {
successes += 1;
reservations.push(r);
}
Err(_) => failures += 1,
}
}
assert_eq!(successes, limit, "Should allow exactly 'limit' connections");
assert_eq!(
failures,
iterations - limit,
"Should fail the rest with LimitExceeded"
);
assert_eq!(stats.get_user_curr_connects(user), limit as u64);
drop(reservations);
ip_tracker.drain_cleanup_queue().await;
assert_eq!(
stats.get_user_curr_connects(user),
0,
"Stats must converge to 0 after all drops"
);
assert_eq!(
ip_tracker.get_active_ip_count(user).await,
0,
"IP tracker must converge to 0"
);
}
// ------------------------------------------------------------------
// Priority 3: IP Tracker Race Stress
// ------------------------------------------------------------------
#[tokio::test]
async fn client_ip_tracker_race_condition_stress() {
let user = "race-user";
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 100).await;
let iterations = 1000;
let mut tasks = Vec::new();
for i in 0..iterations {
let ip_tracker = Arc::clone(&ip_tracker);
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 254 + 1) as u8));
tasks.push(tokio::spawn(async move {
for _ in 0..10 {
if let Ok(()) = ip_tracker.check_and_add("race-user", ip).await {
ip_tracker.remove_ip("race-user", ip).await;
}
}
}));
}
futures::future::join_all(tasks).await;
assert_eq!(
ip_tracker.get_active_ip_count(user).await,
0,
"IP count must be zero after balanced add/remove burst"
);
}
#[tokio::test]
async fn client_limit_burst_peak_never_exceeds_cap() {
let user = "peak-cap-user";
let limit = 32;
let attempts = 256;
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
let peak = Arc::new(AtomicU64::new(0));
let mut tasks = Vec::with_capacity(attempts);
for i in 0..attempts {
let stats = Arc::clone(&stats);
let ip_tracker = Arc::clone(&ip_tracker);
let config = config.clone();
let peak = Arc::clone(&peak);
tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(203, 0, 113, (i % 250 + 1) as u8)),
20000 + i as u16,
);
let acquired = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker,
)
.await;
if let Ok(reservation) = acquired {
let now = stats.get_user_curr_connects(user);
loop {
let prev = peak.load(Ordering::Relaxed);
if now <= prev {
break;
}
if peak
.compare_exchange(prev, now, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
break;
}
}
tokio::time::sleep(Duration::from_millis(2)).await;
drop(reservation);
}
}));
}
futures::future::join_all(tasks).await;
ip_tracker.drain_cleanup_queue().await;
assert!(
peak.load(Ordering::Relaxed) <= limit as u64,
"peak concurrent reservations must not exceed configured cap"
);
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
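The compare-exchange loop in the test above records a running maximum without locks; `AtomicU64::fetch_max` expresses the same thing in one call. A minimal sketch of both forms:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Record `now` as the new peak iff it exceeds the stored value, racing
// safely with other threads attempting the same update.
fn record_peak_cas(peak: &AtomicU64, now: u64) {
    loop {
        let prev = peak.load(Ordering::Relaxed);
        if now <= prev {
            break; // current peak already at least as large
        }
        if peak
            .compare_exchange(prev, now, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok()
        {
            break; // we installed the new peak
        }
        // Another thread raced us; reload and retry.
    }
}

fn main() {
    let peak = AtomicU64::new(0);
    record_peak_cas(&peak, 5);
    record_peak_cas(&peak, 3); // lower value leaves the peak untouched
    assert_eq!(peak.load(Ordering::Relaxed), 5);
    // std's fetch_max is the one-line equivalent of the CAS loop.
    peak.fetch_max(9, Ordering::Relaxed);
    assert_eq!(peak.load(Ordering::Relaxed), 9);
}
```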
#[tokio::test]
async fn client_quota_rejection_never_mutates_live_counters() {
let user = "quota-reject-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default();
config.access.user_data_quota.insert(user.to_string(), 0);
let peer: SocketAddr = "198.51.100.201:31111".parse().unwrap();
let res = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await;
assert!(matches!(res, Err(ProxyError::DataQuotaExceeded { .. })));
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn client_expiration_rejection_never_mutates_live_counters() {
let user = "expired-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default();
config.access.user_expirations.insert(
user.to_string(),
chrono::Utc::now() - chrono::Duration::seconds(1),
);
let peer: SocketAddr = "198.51.100.202:31112".parse().unwrap();
let res = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await;
assert!(matches!(res, Err(ProxyError::UserExpired { .. })));
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn client_ip_limit_failure_rolls_back_counter_exactly() {
let user = "ip-limit-rollback-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert(user.to_string(), 16);
let first_peer: SocketAddr = "198.51.100.203:31113".parse().unwrap();
let first = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
first_peer,
ip_tracker.clone(),
)
.await
.unwrap();
let second_peer: SocketAddr = "198.51.100.204:31114".parse().unwrap();
let second = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
second_peer,
ip_tracker.clone(),
)
.await;
assert!(matches!(
second,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
assert_eq!(stats.get_user_curr_connects(user), 1);
drop(first);
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn client_parallel_limit_checks_success_path_leaves_no_residue() {
let user = "parallel-check-success-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 128).await;
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert(user.to_string(), 128);
let mut tasks = Vec::new();
for i in 0..128u16 {
let stats = Arc::clone(&stats);
let ip_tracker = Arc::clone(&ip_tracker);
let config = config.clone();
tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(10, 10, (i / 255) as u8, (i % 255 + 1) as u8)),
32000 + i,
);
RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker)
.await
}));
}
for result in futures::future::join_all(tasks).await {
assert!(result.unwrap().is_ok());
}
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn client_parallel_limit_checks_failure_path_leaves_no_residue() {
let user = "parallel-check-failure-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 0).await;
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert(user.to_string(), 512);
let mut tasks = Vec::new();
for i in 0..64u16 {
let stats = Arc::clone(&stats);
let ip_tracker = Arc::clone(&ip_tracker);
let config = config.clone();
tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(172, 16, 0, (i % 250 + 1) as u8)),
33000 + i,
);
RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker)
.await
}));
}
let mut _denied = 0usize;
for result in futures::future::join_all(tasks).await {
match result.unwrap() {
Ok(()) => {}
Err(ProxyError::ConnectionLimitExceeded { .. }) => _denied += 1,
Err(other) => panic!("unexpected error: {other}"),
}
}
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn client_churn_mixed_success_failure_converges_to_zero_state() {
let user = "mixed-churn-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 4).await;
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 8);
let mut tasks = Vec::new();
for i in 0..200u16 {
let stats = Arc::clone(&stats);
let ip_tracker = Arc::clone(&ip_tracker);
let config = config.clone();
tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(192, 0, 2, (i % 16 + 1) as u8)),
34000 + (i % 32),
);
let maybe_res = RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker,
)
.await;
if let Ok(reservation) = maybe_res {
tokio::time::sleep(Duration::from_millis((i % 3) as u64)).await;
drop(reservation);
}
}));
}
futures::future::join_all(tasks).await;
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn client_same_ip_parallel_attempts_allow_at_most_one_when_limit_is_one() {
let user = "same-ip-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1);
let peer: SocketAddr = "203.0.113.44:35555".parse().unwrap();
let mut tasks = Vec::new();
for _ in 0..64 {
let stats = Arc::clone(&stats);
let ip_tracker = Arc::clone(&ip_tracker);
let config = config.clone();
tasks.push(tokio::spawn(async move {
RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker,
)
.await
}));
}
let mut granted = 0usize;
let mut reservations = Vec::new();
for result in futures::future::join_all(tasks).await {
match result.unwrap() {
Ok(reservation) => {
granted += 1;
reservations.push(reservation);
}
Err(ProxyError::ConnectionLimitExceeded { .. }) => {}
Err(other) => panic!("unexpected error: {other}"),
}
}
assert_eq!(
granted, 1,
"only one reservation may be granted for same IP with limit=1"
);
drop(reservations);
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn client_repeat_acquire_release_cycles_never_accumulate_state() {
let user = "repeat-cycle-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 32).await;
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert(user.to_string(), 32);
for i in 0..500u16 {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(198, 18, (i / 250) as u8, (i % 250 + 1) as u8)),
36000 + (i % 128),
);
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await
.unwrap();
drop(reservation);
}
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn client_multi_user_isolation_under_parallel_limit_exhaustion() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert("u1".to_string(), 8);
config.access.user_max_tcp_conns.insert("u2".to_string(), 8);
let mut tasks = Vec::new();
for i in 0..128u16 {
let stats = Arc::clone(&stats);
let ip_tracker = Arc::clone(&ip_tracker);
let config = config.clone();
tasks.push(tokio::spawn(async move {
let user = if i % 2 == 0 { "u1" } else { "u2" };
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(100, 64, (i / 64) as u8, (i % 64 + 1) as u8)),
37000 + i,
);
RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker,
)
.await
}));
}
let mut u1_success = 0usize;
let mut u2_success = 0usize;
let mut reservations = Vec::new();
for (idx, result) in futures::future::join_all(tasks)
.await
.into_iter()
.enumerate()
{
let user = if idx % 2 == 0 { "u1" } else { "u2" };
match result.unwrap() {
Ok(reservation) => {
if user == "u1" {
u1_success += 1;
} else {
u2_success += 1;
}
reservations.push(reservation);
}
Err(ProxyError::ConnectionLimitExceeded { .. }) => {}
Err(other) => panic!("unexpected error: {other}"),
}
}
assert_eq!(u1_success, 8, "u1 must get exactly its own configured cap");
assert_eq!(u2_success, 8, "u2 must get exactly its own configured cap");
drop(reservations);
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects("u1"), 0);
assert_eq!(stats.get_user_curr_connects("u2"), 0);
}
#[tokio::test]
async fn client_limit_recovery_after_full_rejection_wave() {
let user = "recover-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1);
let first_peer: SocketAddr = "198.51.100.50:38001".parse().unwrap();
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
first_peer,
ip_tracker.clone(),
)
.await
.unwrap();
for i in 0..64u16 {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(198, 51, 100, (i % 60 + 1) as u8)),
38002 + i,
);
let denied = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await;
assert!(matches!(
denied,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
}
drop(reservation);
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0);
let recovery_peer: SocketAddr = "198.51.100.200:38999".parse().unwrap();
let recovered = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
recovery_peer,
ip_tracker.clone(),
)
.await;
assert!(
recovered.is_ok(),
"capacity must recover after prior holder drops"
);
}
#[tokio::test]
async fn client_dual_limit_cross_product_never_leaks_on_reject() {
let user = "dual-limit-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 2).await;
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 2);
let p1: SocketAddr = "203.0.113.10:39001".parse().unwrap();
let p2: SocketAddr = "203.0.113.11:39002".parse().unwrap();
let r1 = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
p1,
ip_tracker.clone(),
)
.await
.unwrap();
let r2 = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
p2,
ip_tracker.clone(),
)
.await
.unwrap();
for i in 0..32u16 {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(203, 0, 113, (50 + i) as u8)),
39010 + i,
);
let denied = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await;
assert!(matches!(
denied,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
}
assert_eq!(stats.get_user_curr_connects(user), 2);
drop((r1, r2));
ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn client_check_user_limits_concurrent_churn_no_counter_drift() {
let user = "check-drift-user";
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 64).await;
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert(user.to_string(), 64);
let mut tasks = Vec::new();
for i in 0..512u16 {
let stats = Arc::clone(&stats);
let ip_tracker = Arc::clone(&ip_tracker);
let config = config.clone();
tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(172, 20, (i / 255) as u8, (i % 255 + 1) as u8)),
40000 + (i % 500),
);
let _ = RunningClientHandler::check_user_limits_static(
user,
&config,
&stats,
peer,
&ip_tracker,
)
.await;
}));
}
for task in futures::future::join_all(tasks).await {
task.unwrap();
}
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}


@@ -0,0 +1,126 @@
use super::*;
const BEOBACHTEN_TTL_MAX_MINUTES: u64 = 24 * 60;
#[test]
fn beobachten_ttl_exact_upper_bound_is_preserved() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES;
let ttl = beobachten_ttl(&config);
assert_eq!(
ttl,
Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60),
"upper-bound TTL should remain unchanged"
);
}
#[test]
fn beobachten_ttl_above_upper_bound_is_clamped() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES + 1;
let ttl = beobachten_ttl(&config);
assert_eq!(
ttl,
Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60),
"TTL above security cap must be clamped"
);
}
#[test]
fn beobachten_ttl_u64_max_is_clamped_fail_safe() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = u64::MAX;
let ttl = beobachten_ttl(&config);
assert_eq!(
ttl,
Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60),
"extreme configured TTL must not become multi-century retention"
);
}
#[test]
fn positive_one_minute_maps_to_exact_60_seconds() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
assert_eq!(beobachten_ttl(&config), Duration::from_secs(60));
}
#[test]
fn adversarial_boundary_triplet_behaves_deterministically() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES - 1;
assert_eq!(
beobachten_ttl(&config),
Duration::from_secs((BEOBACHTEN_TTL_MAX_MINUTES - 1) * 60)
);
config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES;
assert_eq!(
beobachten_ttl(&config),
Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60)
);
config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES + 1;
assert_eq!(
beobachten_ttl(&config),
Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60)
);
}
#[test]
fn light_fuzz_random_minutes_match_fail_safe_model() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
let mut seed = 0xD15E_A5E5_F00D_BAADu64;
for _ in 0..8192 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
config.general.beobachten_minutes = seed;
let ttl = beobachten_ttl(&config);
let expected = if seed == 0 {
Duration::from_secs(60)
} else {
Duration::from_secs(seed.min(BEOBACHTEN_TTL_MAX_MINUTES) * 60)
};
assert_eq!(ttl, expected, "ttl mismatch for minutes={seed}");
assert!(ttl <= Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60));
}
}
#[test]
fn stress_monotonic_minutes_remain_monotonic_until_cap_then_flat() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
let mut prev = Duration::from_secs(0);
for minutes in 0..=(BEOBACHTEN_TTL_MAX_MINUTES + 4096) {
config.general.beobachten_minutes = minutes;
let ttl = beobachten_ttl(&config);
assert!(ttl >= prev, "ttl must be non-decreasing as minutes grow");
assert!(ttl <= Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60));
if minutes > BEOBACHTEN_TTL_MAX_MINUTES {
assert_eq!(
ttl,
Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60),
"ttl must stay clamped once cap is exceeded"
);
}
prev = ttl;
}
}


@@ -0,0 +1,904 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{
HANDSHAKE_LEN, MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE, TLS_RECORD_APPLICATION,
TLS_VERSION,
};
use crate::protocol::tls;
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};
struct CampaignHarness {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
route_runtime: Arc<RouteRuntimeController>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
}
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn build_mask_harness(secret_hex: &str, mask_port: u16) -> CampaignHarness {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_port;
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
CampaignHarness {
config,
stats: stats.clone(),
upstream_manager: new_upstream_manager(stats),
replay_checker: Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))),
buffer_pool: Arc::new(BufferPool::new()),
rng: Arc::new(SecureRandom::new()),
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
ip_tracker: Arc::new(UserIpTracker::new()),
beobachten: Arc::new(BeobachtenStore::new()),
}
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn wrap_tls_record(record_type: u8, payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len());
record.push(record_type);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.extend_from_slice(payload);
record
}
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
wrap_tls_record(TLS_RECORD_APPLICATION, payload)
}
async fn read_and_discard_tls_record_body<T>(stream: &mut T, header: [u8; 5])
where
T: tokio::io::AsyncRead + Unpin,
{
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut body = vec![0u8; len];
stream.read_exact(&mut body).await.unwrap();
}
async fn run_tls_success_mtproto_fail_capture(
harness: CampaignHarness,
peer: SocketAddr,
client_hello: Vec<u8>,
bad_mtproto_record: Vec<u8>,
trailing_records: Vec<Vec<u8>>,
expected_forward: Vec<u8>,
) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = (*harness.config).clone();
cfg.censorship.mask_port = backend_addr.port();
let cfg = Arc::new(cfg);
let expected = expected_forward.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected.len()];
stream.read_exact(&mut got).await.unwrap();
got
});
let (server_side, mut client_side) = duplex(262144);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
cfg,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5];
client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side.write_all(&bad_mtproto_record).await.unwrap();
for record in trailing_records {
client_side.write_all(&record).await.unwrap();
}
let got = tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(got, expected_forward);
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
}
async fn run_invalid_tls_capture(config: Arc<ProxyConfig>, payload: Vec<u8>, expected: Vec<u8>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = (*config).clone();
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
let cfg = Arc::new(cfg);
let expected_probe = expected.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected_probe.len()];
stream.read_exact(&mut got).await.unwrap();
got
});
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.77:45001".parse().unwrap(),
cfg,
stats,
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side.write_all(&payload).await.unwrap();
client_side.shutdown().await.unwrap();
let got = tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(got, expected);
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
}
#[tokio::test]
async fn blackhat_campaign_01_tail_only_record_is_forwarded_after_tls_success_mtproto_fail() {
let secret = [0xA1u8; 16];
let harness = build_mask_harness("a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1", 1);
let client_hello = make_valid_tls_client_hello(&secret, 11, 600, 0x41);
let bad_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let tail = wrap_tls_application_data(b"blackhat-tail-01");
run_tls_success_mtproto_fail_capture(
harness,
"198.51.100.1:55001".parse().unwrap(),
client_hello,
bad_record,
vec![tail.clone()],
tail,
)
.await;
}
#[tokio::test]
async fn blackhat_campaign_02_two_ordered_records_preserved_after_fallback() {
let secret = [0xA2u8; 16];
let harness = build_mask_harness("a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2", 1);
let client_hello = make_valid_tls_client_hello(&secret, 12, 600, 0x42);
let bad_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let r1 = wrap_tls_application_data(b"first");
let r2 = wrap_tls_application_data(b"second");
let expected = [r1.clone(), r2.clone()].concat();
run_tls_success_mtproto_fail_capture(
harness,
"198.51.100.2:55002".parse().unwrap(),
client_hello,
bad_record,
vec![r1, r2],
expected,
)
.await;
}
#[tokio::test]
async fn blackhat_campaign_03_large_tls_application_record_survives_fallback() {
let secret = [0xA3u8; 16];
let harness = build_mask_harness("a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3", 1);
let client_hello = make_valid_tls_client_hello(&secret, 13, 600, 0x43);
let bad_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let big_payload = vec![0x5Au8; MAX_TLS_PLAINTEXT_SIZE];
let big_record = wrap_tls_application_data(&big_payload);
run_tls_success_mtproto_fail_capture(
harness,
"198.51.100.3:55003".parse().unwrap(),
client_hello,
bad_record,
vec![big_record.clone()],
big_record,
)
.await;
}
#[tokio::test]
async fn blackhat_campaign_04_coalesced_tail_in_failed_record_is_reframed_and_forwarded() {
let secret = [0xA4u8; 16];
let harness = build_mask_harness("a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4", 1);
let client_hello = make_valid_tls_client_hello(&secret, 14, 600, 0x44);
let coalesced_tail = b"coalesced-tail-blackhat".to_vec();
let mut bad_payload = vec![0u8; HANDSHAKE_LEN];
bad_payload.extend_from_slice(&coalesced_tail);
let bad_record = wrap_tls_application_data(&bad_payload);
let expected = wrap_tls_application_data(&coalesced_tail);
run_tls_success_mtproto_fail_capture(
harness,
"198.51.100.4:55004".parse().unwrap(),
client_hello,
bad_record,
Vec::new(),
expected,
)
.await;
}
#[tokio::test]
async fn blackhat_campaign_05_coalesced_tail_plus_next_record_keep_wire_order() {
let secret = [0xA5u8; 16];
let harness = build_mask_harness("a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5", 1);
let client_hello = make_valid_tls_client_hello(&secret, 15, 600, 0x45);
let coalesced_tail = b"inline-tail".to_vec();
let mut bad_payload = vec![0u8; HANDSHAKE_LEN];
bad_payload.extend_from_slice(&coalesced_tail);
let bad_record = wrap_tls_application_data(&bad_payload);
let next_record = wrap_tls_application_data(b"next-record");
let expected = [
wrap_tls_application_data(&coalesced_tail),
next_record.clone(),
]
.concat();
run_tls_success_mtproto_fail_capture(
harness,
"198.51.100.5:55005".parse().unwrap(),
client_hello,
bad_record,
vec![next_record],
expected,
)
.await;
}
#[tokio::test]
async fn blackhat_campaign_06_replayed_tls_hello_is_masked_without_serverhello() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let harness = build_mask_harness("a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6", backend_addr.port());
let replay_checker = harness.replay_checker.clone();
let client_hello = make_valid_tls_client_hello(&[0xA6; 16], 16, 600, 0x46);
let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let first_tail = wrap_tls_application_data(b"seed-tail");
let expected_hello = client_hello.clone();
let expected_tail = first_tail.clone();
let accept_task = tokio::spawn(async move {
let (mut s1, _) = listener.accept().await.unwrap();
let mut got_tail = vec![0u8; expected_tail.len()];
s1.read_exact(&mut got_tail).await.unwrap();
assert_eq!(got_tail, expected_tail);
drop(s1);
let (mut s2, _) = listener.accept().await.unwrap();
let mut got_hello = vec![0u8; expected_hello.len()];
s2.read_exact(&mut got_hello).await.unwrap();
got_hello
});
let run_one = |checker: Arc<ReplayChecker>, send_mtproto: bool| {
let mut cfg = (*harness.config).clone();
cfg.censorship.mask_port = backend_addr.port();
let cfg = Arc::new(cfg);
let hello = client_hello.clone();
let invalid_mtproto_record = invalid_mtproto_record.clone();
let first_tail = first_tail.clone();
let stats = harness.stats.clone();
let upstream = harness.upstream_manager.clone();
let pool = harness.buffer_pool.clone();
let rng = harness.rng.clone();
let route = harness.route_runtime.clone();
let ipt = harness.ip_tracker.clone();
let beob = harness.beobachten.clone();
async move {
let (server_side, mut client_side) = duplex(131072);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.6:55006".parse().unwrap(),
cfg,
stats,
upstream,
checker,
pool,
rng,
None,
route,
None,
ipt,
beob,
false,
));
client_side.write_all(&hello).await.unwrap();
if send_mtproto {
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&first_tail).await.unwrap();
} else {
let mut one = [0u8; 1];
let no_server_hello = tokio::time::timeout(
Duration::from_millis(300),
client_side.read_exact(&mut one),
)
.await;
assert!(no_server_hello.is_err() || no_server_hello.unwrap().is_err());
}
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
}
};
run_one(replay_checker.clone(), true).await;
run_one(replay_checker, false).await;
let got = tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(got, client_hello);
}
#[tokio::test]
async fn blackhat_campaign_07_truncated_clienthello_exact_prefix_is_forwarded() {
let mut payload = vec![0u8; 5 + 37];
payload[0] = 0x16;
payload[1] = 0x03;
payload[2] = 0x01;
payload[3..5].copy_from_slice(&600u16.to_be_bytes());
payload[5..].fill(0x71);
run_invalid_tls_capture(Arc::new(ProxyConfig::default()), payload.clone(), payload).await;
}
#[tokio::test]
async fn blackhat_campaign_08_out_of_bounds_len_forwards_header_only() {
let header = vec![0x16, 0x03, 0x01, 0xFF, 0xFF];
run_invalid_tls_capture(Arc::new(ProxyConfig::default()), header.clone(), header).await;
}
#[tokio::test]
async fn blackhat_campaign_09_fragmented_header_then_partial_body_masks_seen_bytes_only() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.censorship.mask = true;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_unix_sock = None;
let expected = {
let mut x = vec![0u8; 5 + 11];
x[0] = 0x16;
x[1] = 0x03;
x[2] = 0x01;
x[3..5].copy_from_slice(&600u16.to_be_bytes());
x[5..].fill(0xCC);
x
};
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected.len()];
stream.read_exact(&mut got).await.unwrap();
got
});
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.9:55009".parse().unwrap(),
Arc::new(cfg),
Arc::new(Stats::new()),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side.write_all(&[0x16, 0x03]).await.unwrap();
client_side.write_all(&[0x01, 0x02, 0x58]).await.unwrap();
client_side.write_all(&vec![0xCC; 11]).await.unwrap();
client_side.shutdown().await.unwrap();
let got = tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(got.len(), 16);
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
}
#[tokio::test]
async fn blackhat_campaign_10_zero_handshake_timeout_with_delay_still_avoids_timeout_counter() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1;
cfg.timeouts.client_handshake = 0;
cfg.censorship.server_hello_delay_min_ms = 700;
cfg.censorship.server_hello_delay_max_ms = 700;
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let started = Instant::now();
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.10:55010".parse().unwrap(),
Arc::new(cfg),
stats.clone(),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut invalid = vec![0u8; 5 + 700];
invalid[0] = 0x16;
invalid[1] = 0x03;
invalid[2] = 0x01;
invalid[3..5].copy_from_slice(&700u16.to_be_bytes());
invalid[5..].fill(0x66);
client_side.write_all(&invalid).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
assert_eq!(stats.get_handshake_timeouts(), 0);
assert!(started.elapsed() >= Duration::from_millis(650));
}
#[tokio::test]
async fn blackhat_campaign_11_parallel_bad_tls_probes_all_masked_without_timeouts() {
let n = 24usize;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.censorship.mask = true;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_port = backend_addr.port();
let stats = Arc::new(Stats::new());
let accept_task = tokio::spawn(async move {
let mut seen = HashSet::new();
for _ in 0..n {
let (mut stream, _) = listener.accept().await.unwrap();
let mut hdr = [0u8; 5];
stream.read_exact(&mut hdr).await.unwrap();
seen.insert(hdr.to_vec());
}
seen
});
let mut tasks = Vec::new();
for i in 0..n {
let mut hdr = [0u8; 5];
hdr[0] = 0x16;
hdr[1] = 0x03;
hdr[2] = 0x01;
hdr[3] = 0xFF;
hdr[4] = i as u8;
let cfg = Arc::new(cfg.clone());
let stats = stats.clone();
tasks.push(tokio::spawn(async move {
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
format!("198.51.100.11:{}", 56000 + i).parse().unwrap(),
cfg,
stats,
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side.write_all(&hdr).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
hdr.to_vec()
}));
}
let mut expected = HashSet::new();
for t in tasks {
expected.insert(t.await.unwrap());
}
let seen = tokio::time::timeout(Duration::from_secs(6), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(seen, expected);
assert_eq!(stats.get_handshake_timeouts(), 0);
}
#[tokio::test]
async fn blackhat_campaign_12_parallel_tls_success_mtproto_fail_sessions_keep_isolation() {
let sessions = 16usize;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut expected = HashSet::new();
for i in 0..sessions {
let rec = wrap_tls_application_data(&vec![i as u8; 8 + i]);
expected.insert(rec);
}
let accept_task = tokio::spawn(async move {
let mut got_set = HashSet::new();
for _ in 0..sessions {
let (mut stream, _) = listener.accept().await.unwrap();
let mut head = [0u8; 5];
stream.read_exact(&mut head).await.unwrap();
let len = u16::from_be_bytes([head[3], head[4]]) as usize;
let mut rec = vec![0u8; 5 + len];
rec[..5].copy_from_slice(&head);
stream.read_exact(&mut rec[5..]).await.unwrap();
got_set.insert(rec);
}
got_set
});
let mut tasks = Vec::new();
for i in 0..sessions {
let mut harness =
build_mask_harness("abababababababababababababababab", backend_addr.port());
let mut cfg = (*harness.config).clone();
cfg.censorship.mask_port = backend_addr.port();
harness.config = Arc::new(cfg);
tasks.push(tokio::spawn(async move {
let secret = [0xABu8; 16];
let hello =
make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x40 + (i as u8 % 10));
let bad = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let tail = wrap_tls_application_data(&vec![i as u8; 8 + i]);
let (server_side, mut client_side) = duplex(131072);
let handler = tokio::spawn(handle_client_stream(
server_side,
format!("198.51.100.12:{}", 56100 + i).parse().unwrap(),
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&hello).await.unwrap();
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
read_and_discard_tls_record_body(&mut client_side, head).await;
client_side.write_all(&bad).await.unwrap();
client_side.write_all(&tail).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(5), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
tail
}));
}
let mut produced = HashSet::new();
for t in tasks {
produced.insert(t.await.unwrap());
}
let observed = tokio::time::timeout(Duration::from_secs(8), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(produced, expected);
assert_eq!(observed, expected);
}
#[tokio::test]
async fn blackhat_campaign_13_backend_down_does_not_escalate_to_handshake_timeout() {
let mut cfg = ProxyConfig::default();
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1;
cfg.timeouts.client_handshake = 1;
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.13:55013".parse().unwrap(),
Arc::new(cfg),
stats.clone(),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let bad = vec![0x16, 0x03, 0x01, 0xFF, 0x00];
client_side.write_all(&bad).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
assert_eq!(stats.get_handshake_timeouts(), 0);
}
#[tokio::test]
async fn blackhat_campaign_14_masking_disabled_path_finishes_cleanly() {
let mut cfg = ProxyConfig::default();
cfg.censorship.mask = false;
cfg.timeouts.client_handshake = 1;
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.14:55014".parse().unwrap(),
Arc::new(cfg),
stats.clone(),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let bad = vec![0x16, 0x03, 0x01, 0xFF, 0xF0];
client_side.write_all(&bad).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
assert_eq!(stats.get_handshake_timeouts(), 0);
}
#[tokio::test]
async fn blackhat_campaign_15_light_fuzz_tls_lengths_and_fragmentation() {
let mut seed = 0x9E3779B97F4A7C15u64;
for idx in 0..20u16 {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
let mut tls_len = (seed as usize) % 20000;
if idx % 3 == 0 {
tls_len = MAX_TLS_PLAINTEXT_SIZE + 1 + (tls_len % 1024);
}
let body_to_send =
if (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) {
(seed as usize % 29).min(tls_len.saturating_sub(1))
} else {
0
};
let mut probe = vec![0u8; 5 + body_to_send];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
for b in &mut probe[5..] {
seed = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
*b = (seed >> 24) as u8;
}
let expected = probe.clone();
run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe, expected).await;
}
}
#[tokio::test]
async fn blackhat_campaign_16_mixed_probe_burst_stress_finishes_without_panics() {
let cases = 18usize;
let mut tasks = Vec::new();
for i in 0..cases {
tasks.push(tokio::spawn(async move {
if i % 2 == 0 {
let mut probe = vec![0u8; 5 + (i % 13)];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&600u16.to_be_bytes());
probe[5..].fill((0x90 + i as u8) ^ 0x5A);
run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe.clone(), probe)
.await;
} else {
let hdr = vec![0x16, 0x03, 0x01, 0xFF, i as u8];
run_invalid_tls_capture(Arc::new(ProxyConfig::default()), hdr.clone(), hdr).await;
}
}));
}
for task in tasks {
task.await.unwrap();
}
}


@ -0,0 +1,255 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};
struct PipelineHarness {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
route_runtime: Arc<RouteRuntimeController>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
}
fn build_harness(config: ProxyConfig) -> PipelineHarness {
let config = Arc::new(config);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
PipelineHarness {
config,
stats,
upstream_manager,
replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
buffer_pool: Arc::new(BufferPool::new()),
rng: Arc::new(SecureRandom::new()),
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
ip_tracker: Arc::new(UserIpTracker::new()),
beobachten: Arc::new(BeobachtenStore::new()),
}
}
/// Build a fake-TLS ClientHello of `tls_len` record bytes whose 32-byte digest
/// field carries an HMAC-SHA256 over the record, with the little-endian
/// timestamp XORed into the digest's last 4 bytes.
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
/// Wrap `payload` in a TLS ApplicationData record (type 0x17) with a
/// big-endian 16-bit length header.
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len());
record.push(0x17);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.extend_from_slice(payload);
record
}
/// Read and discard the body of a TLS record whose 5-byte header has already
/// been consumed; the body length comes from the header's last two bytes.
async fn read_and_discard_tls_record_body<T>(stream: &mut T, header: [u8; 5])
where
T: tokio::io::AsyncRead + Unpin,
{
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut body = vec![0u8; len];
stream.read_exact(&mut body).await.unwrap();
}
#[tokio::test]
async fn masking_runs_outside_handshake_timeout_budget_with_high_reject_delay() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 1;
config.timeouts.client_handshake = 0;
config.censorship.server_hello_delay_min_ms = 730;
config.censorship.server_hello_delay_max_ms = 730;
let harness = build_harness(config);
let stats = harness.stats.clone();
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "198.51.100.241:56541".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
let mut invalid_hello = vec![0u8; 5 + 600];
invalid_hello[0] = 0x16;
invalid_hello[1] = 0x03;
invalid_hello[2] = 0x01;
invalid_hello[3..5].copy_from_slice(&600u16.to_be_bytes());
invalid_hello[5..].fill(0x44);
let started = Instant::now();
client_side.write_all(&invalid_hello).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(
result.is_ok(),
"bad-client fallback must not be canceled by handshake timeout"
);
assert_eq!(
stats.get_handshake_timeouts(),
0,
"masking fallback path must not increment handshake timeout counter"
);
assert!(
started.elapsed() >= Duration::from_millis(700),
"configured reject delay should still be visible before masking"
);
}
#[tokio::test]
async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 0;
config.access.ignore_time_skew = true;
config.access.users.insert(
"user".to_string(),
"d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0".to_string(),
);
let harness = build_harness(config);
let secret = [0xD0u8; 16];
let client_hello = make_valid_tls_client_hello(&secret, 0, 600, 0x41);
let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let trailing_record = wrap_tls_application_data(b"no-clienthello-reinject");
let expected_trailing = trailing_record.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected_trailing.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(
got, expected_trailing,
"mask backend must receive only post-handshake trailing TLS records"
);
});
let (server_side, mut client_side) = duplex(131072);
let peer: SocketAddr = "198.51.100.242:56542".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5];
client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
drop(client_side);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
}


@ -0,0 +1,208 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
/// Nearest-rank percentile (p_num/p_den) over the sorted samples;
/// returns 0 for an empty sample set.
fn percentile_ms(mut values: Vec<u128>, p_num: usize, p_den: usize) -> u128 {
values.sort_unstable();
if values.is_empty() {
return 0;
}
let idx = ((values.len() - 1) * p_num) / p_den;
values[idx]
}
async fn measure_reject_duration_ms(body_sent: usize) -> u128 {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1;
cfg.timeouts.client_handshake = 1;
cfg.censorship.server_hello_delay_min_ms = 700;
cfg.censorship.server_hello_delay_max_ms = 700;
let (server_side, mut client_side) = duplex(65536);
let started = Instant::now();
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.170:56170".parse().unwrap(),
Arc::new(cfg),
Arc::new(Stats::new()),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&600u16.to_be_bytes());
probe[5..].fill(0xA7);
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
started.elapsed().as_millis()
}
async fn capture_forwarded_len(body_sent: usize) -> usize {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_shape_hardening = false;
cfg.timeouts.client_handshake = 1;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got.len()
});
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.171:56171".parse().unwrap(),
Arc::new(cfg),
Arc::new(Stats::new()),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&600u16.to_be_bytes());
probe[5..].fill(0xB4);
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
#[tokio::test]
async fn diagnostic_timing_profiles_are_within_realistic_guardrails() {
let classes = [17usize, 511usize, 1023usize, 4095usize];
for class in classes {
let mut samples = Vec::new();
for _ in 0..8 {
samples.push(measure_reject_duration_ms(class).await);
}
let p50 = percentile_ms(samples.clone(), 50, 100);
let p95 = percentile_ms(samples.clone(), 95, 100);
let max = *samples.iter().max().unwrap();
println!(
"diagnostic_timing class={} p50={}ms p95={}ms max={}ms",
class, p50, p95, max
);
assert!(p50 >= 650, "p50 too low for delayed reject class={}", class);
assert!(
p95 <= 1200,
"p95 too high for delayed reject class={}",
class
);
assert!(
max <= 1500,
"max too high for delayed reject class={}",
class
);
}
}
#[tokio::test]
async fn diagnostic_forwarded_size_profiles_by_probe_class() {
let classes = [
0usize, 1usize, 7usize, 17usize, 63usize, 511usize, 1023usize, 2047usize,
];
let mut observed = Vec::new();
for class in classes {
let len = capture_forwarded_len(class).await;
println!("diagnostic_shape class={} forwarded_len={}", class, len);
observed.push(len as u128);
assert_eq!(
len,
5 + class,
"unexpected forwarded len for class={}",
class
);
}
let p50 = percentile_ms(observed.clone(), 50, 100);
let p95 = percentile_ms(observed.clone(), 95, 100);
let max = *observed.iter().max().unwrap();
println!(
"diagnostic_shape_summary p50={}bytes p95={}bytes max={}bytes",
p50, p95, max
);
assert!(p95 >= p50);
assert!(max >= p95);
}


@ -0,0 +1,767 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION};
use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};
struct Harness {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
route_runtime: Arc<RouteRuntimeController>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
}
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn build_harness(secret_hex: &str, mask_port: u16) -> Harness {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_port;
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
Harness {
config,
stats: stats.clone(),
upstream_manager: new_upstream_manager(stats),
replay_checker: Arc::new(ReplayChecker::new(512, Duration::from_secs(60))),
buffer_pool: Arc::new(BufferPool::new()),
rng: Arc::new(SecureRandom::new()),
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
ip_tracker: Arc::new(UserIpTracker::new()),
beobachten: Arc::new(BeobachtenStore::new()),
}
}
/// Build a fake-TLS ClientHello of `tls_len` record bytes whose 32-byte digest
/// field carries an HMAC-SHA256 over the record, with the little-endian
/// timestamp XORed into the digest's last 4 bytes.
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
/// Wrap `payload` in a TLS ApplicationData record with a big-endian
/// 16-bit length header.
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len());
record.push(TLS_RECORD_APPLICATION);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.extend_from_slice(payload);
record
}
/// Read (and discard) the body of a TLS record whose 5-byte header has
/// already been consumed.
async fn read_tls_record_body<T>(stream: &mut T, header: [u8; 5])
where
T: tokio::io::AsyncRead + Unpin,
{
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut body = vec![0u8; len];
stream.read_exact(&mut body).await.unwrap();
}
async fn run_tls_success_mtproto_fail_capture(
secret_hex: &str,
secret: [u8; 16],
timestamp: u32,
trailing_records: Vec<Vec<u8>>,
) -> Vec<u8> {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let expected_len = trailing_records.iter().map(Vec::len).sum::<usize>();
let expected_concat = trailing_records.concat();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected_len];
stream.read_exact(&mut got).await.unwrap();
got
});
let harness = build_harness(secret_hex, backend_addr.port());
let client_hello = make_valid_tls_client_hello(&secret, timestamp, 600, 0x42);
let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let (server_side, mut client_side) = duplex(262144);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.210:56010".parse().unwrap(),
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5];
client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16);
read_tls_record_body(&mut client_side, tls_response_head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for record in trailing_records {
client_side.write_all(&record).await.unwrap();
}
let got = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(got, expected_concat);
drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
got
}
#[tokio::test]
async fn masking_budget_survives_zero_handshake_timeout_with_delay() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.timeouts.client_handshake = 0;
cfg.censorship.server_hello_delay_min_ms = 720;
cfg.censorship.server_hello_delay_max_ms = 720;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; 605];
stream.read_exact(&mut got).await.unwrap();
got
});
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.211:56011".parse().unwrap(),
config,
stats.clone(),
new_upstream_manager(stats.clone()),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut invalid_hello = vec![0u8; 605];
invalid_hello[0] = 0x16;
invalid_hello[1] = 0x03;
invalid_hello[2] = 0x01;
invalid_hello[3..5].copy_from_slice(&600u16.to_be_bytes());
invalid_hello[5..].fill(0xA1);
let started = Instant::now();
client_side.write_all(&invalid_hello).await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
assert_eq!(stats.get_handshake_timeouts(), 0);
assert!(started.elapsed() >= Duration::from_millis(680));
}
#[tokio::test]
async fn tls_mtproto_fail_forwards_only_trailing_record() {
let tail = wrap_tls_application_data(b"tail-only");
let got = run_tls_success_mtproto_fail_capture(
"c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1",
[0xC1; 16],
1,
vec![tail.clone()],
)
.await;
assert_eq!(got, tail);
}
#[tokio::test]
async fn replayed_tls_hello_gets_no_serverhello_and_is_masked() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let harness = build_harness("c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2", backend_addr.port());
let secret = [0xC2u8; 16];
let hello = make_valid_tls_client_hello(&secret, 2, 600, 0x41);
let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let first_tail = wrap_tls_application_data(b"seed");
let expected_hello = hello.clone();
let expected_tail = first_tail.clone();
let accept_task = tokio::spawn(async move {
let (mut s1, _) = listener.accept().await.unwrap();
let mut got_tail = vec![0u8; expected_tail.len()];
s1.read_exact(&mut got_tail).await.unwrap();
assert_eq!(got_tail, expected_tail);
drop(s1);
let (mut s2, _) = listener.accept().await.unwrap();
let mut got_hello = vec![0u8; expected_hello.len()];
s2.read_exact(&mut got_hello).await.unwrap();
assert_eq!(got_hello, expected_hello);
});
let run_session = |send_mtproto: bool| {
let (server_side, mut client_side) = duplex(131072);
let config = harness.config.clone();
let stats = harness.stats.clone();
let upstream = harness.upstream_manager.clone();
let replay = harness.replay_checker.clone();
let pool = harness.buffer_pool.clone();
let rng = harness.rng.clone();
let route = harness.route_runtime.clone();
let ipt = harness.ip_tracker.clone();
let beob = harness.beobachten.clone();
let hello = hello.clone();
let invalid_mtproto_record = invalid_mtproto_record.clone();
let first_tail = first_tail.clone();
async move {
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.212:56012".parse().unwrap(),
config,
stats,
upstream,
replay,
pool,
rng,
None,
route,
None,
ipt,
beob,
false,
));
client_side.write_all(&hello).await.unwrap();
if send_mtproto {
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
read_tls_record_body(&mut client_side, head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&first_tail).await.unwrap();
} else {
let mut one = [0u8; 1];
let no_server_hello = tokio::time::timeout(
Duration::from_millis(300),
client_side.read_exact(&mut one),
)
.await;
assert!(no_server_hello.is_err() || no_server_hello.unwrap().is_err());
}
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}
};
run_session(true).await;
run_session(false).await;
tokio::time::timeout(Duration::from_secs(5), accept_task)
.await
.unwrap()
.unwrap();
}
#[tokio::test]
async fn connects_bad_increments_once_per_invalid_mtproto() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let harness = build_harness("c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3", backend_addr.port());
let stats = harness.stats.clone();
let bad_before = stats.get_connects_bad();
let tail = wrap_tls_application_data(b"accounting");
let expected_tail = tail.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected_tail.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, expected_tail);
});
let hello = make_valid_tls_client_hello(&[0xC3; 16], 3, 600, 0x42);
let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let (server_side, mut client_side) = duplex(131072);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.213:56013".parse().unwrap(),
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&hello).await.unwrap();
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
read_tls_record_body(&mut client_side, head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&tail).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert_eq!(stats.get_connects_bad(), bad_before + 1);
}
#[tokio::test]
async fn truncated_clienthello_forwards_only_seen_prefix() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_unix_sock = None;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let expected_prefix_len = 5 + 17;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected_prefix_len];
stream.read_exact(&mut got).await.unwrap();
got
});
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.214:56014".parse().unwrap(),
config,
stats,
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut hello = vec![0u8; 5 + 17];
hello[0] = 0x16;
hello[1] = 0x03;
hello[2] = 0x01;
hello[3..5].copy_from_slice(&600u16.to_be_bytes());
hello[5..].fill(0x55);
client_side.write_all(&hello).await.unwrap();
client_side.shutdown().await.unwrap();
let got = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(got, hello);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}
#[tokio::test]
async fn out_of_bounds_tls_len_forwards_header_only() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_unix_sock = None;
let config = Arc::new(cfg);
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = [0u8; 5];
stream.read_exact(&mut got).await.unwrap();
got
});
let (server_side, mut client_side) = duplex(8192);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.215:56015".parse().unwrap(),
config,
Arc::new(Stats::new()),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let hdr = [0x16, 0x03, 0x01, 0x42, 0x69];
client_side.write_all(&hdr).await.unwrap();
client_side.shutdown().await.unwrap();
let got = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(got, hdr);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}
#[tokio::test]
async fn non_tls_with_modes_disabled_is_masked() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_unix_sock = None;
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = [0u8; 5];
stream.read_exact(&mut got).await.unwrap();
got
});
let (server_side, mut client_side) = duplex(8192);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.216:56016".parse().unwrap(),
config,
Arc::new(Stats::new()),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let probe = *b"HELLO";
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let got = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(got, probe);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}
#[tokio::test]
async fn concurrent_tls_mtproto_fail_sessions_are_isolated() {
let sessions = 12usize;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut expected = std::collections::HashSet::new();
for idx in 0..sessions {
let payload = vec![idx as u8; 32 + idx];
expected.insert(wrap_tls_application_data(&payload));
}
let accept_task = tokio::spawn(async move {
let mut remaining = expected;
for _ in 0..sessions {
let (mut stream, _) = listener.accept().await.unwrap();
let mut header = [0u8; 5];
stream.read_exact(&mut header).await.unwrap();
assert_eq!(header[0], TLS_RECORD_APPLICATION);
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut record = vec![0u8; 5 + len];
record[..5].copy_from_slice(&header);
stream.read_exact(&mut record[5..]).await.unwrap();
assert!(remaining.remove(&record));
}
assert!(remaining.is_empty());
});
let mut tasks = Vec::with_capacity(sessions);
for idx in 0..sessions {
let secret_hex = "c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4";
let harness = build_harness(secret_hex, backend_addr.port());
let hello =
make_valid_tls_client_hello(&[0xC4; 16], 20 + idx as u32, 600, 0x40 + idx as u8);
let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let trailing = wrap_tls_application_data(&vec![idx as u8; 32 + idx]);
let peer: SocketAddr = format!("198.51.100.217:{}", 56100 + idx as u16)
.parse()
.unwrap();
tasks.push(tokio::spawn(async move {
let (server_side, mut client_side) = duplex(131072);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&hello).await.unwrap();
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
read_tls_record_body(&mut client_side, head).await;
client_side.write_all(&invalid_mtproto).await.unwrap();
client_side.write_all(&trailing).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}));
}
for task in tasks {
task.await.unwrap();
}
tokio::time::timeout(Duration::from_secs(6), accept_task)
.await
.unwrap()
.unwrap();
}
macro_rules! tail_length_case {
($name:ident, $hex:expr, $secret:expr, $ts:expr, $len:expr) => {
#[tokio::test]
async fn $name() {
let mut payload = vec![0u8; $len];
for (i, b) in payload.iter_mut().enumerate() {
*b = (i as u8).wrapping_mul(17).wrapping_add(5);
}
let record = wrap_tls_application_data(&payload);
let got =
run_tls_success_mtproto_fail_capture($hex, $secret, $ts, vec![record.clone()])
.await;
assert_eq!(got, record);
}
};
}
tail_length_case!(
tail_len_1_preserved,
"d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1",
[0xD1; 16],
30,
1
);
tail_length_case!(
tail_len_2_preserved,
"d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2",
[0xD2; 16],
31,
2
);
tail_length_case!(
tail_len_3_preserved,
"d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3",
[0xD3; 16],
32,
3
);
tail_length_case!(
tail_len_7_preserved,
"d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4",
[0xD4; 16],
33,
7
);
tail_length_case!(
tail_len_31_preserved,
"d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5",
[0xD5; 16],
34,
31
);
tail_length_case!(
tail_len_127_preserved,
"d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6",
[0xD6; 16],
35,
127
);
tail_length_case!(
tail_len_511_preserved,
"d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7",
[0xD7; 16],
36,
511
);
tail_length_case!(
tail_len_1023_preserved,
"d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8",
[0xD8; 16],
37,
1023
);


@@ -0,0 +1,358 @@
//! Probe-masking tests: verify that non-MTProto probes (plain HTTP,
//! TLS-like junk, fuzzed and concurrent variants) are proxied to the
//! masking backend with their byte prefix preserved, and that the client
//! observes only the backend's reply.
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream};
const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
fn make_test_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn masking_config(mask_port: u16) -> Arc<ProxyConfig> {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_port;
cfg.censorship.mask_proxy_protocol = 0;
Arc::new(cfg)
}
async fn run_generic_probe_and_capture_prefix(payload: Vec<u8>, expected_prefix: Vec<u8>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let reply = REPLY_404.to_vec();
let prefix_len = expected_prefix.len();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; prefix_len];
stream.read_exact(&mut got).await.unwrap();
stream.write_all(&reply).await.unwrap();
got
});
let config = masking_config(backend_addr.port());
let stats = Arc::new(Stats::new());
let upstream_manager = make_test_upstream_manager(stats.clone());
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "203.0.113.210:55110".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&payload).await.unwrap();
client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout(
Duration::from_secs(2),
client_side.read_exact(&mut observed),
)
.await
.unwrap()
.unwrap();
assert_eq!(observed, REPLY_404);
let got = tokio::time::timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(got, expected_prefix);
let result = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
}
async fn read_http_probe_header(stream: &mut TcpStream) -> Vec<u8> {
let mut out = Vec::with_capacity(96);
let mut one = [0u8; 1];
loop {
stream.read_exact(&mut one).await.unwrap();
out.push(one[0]);
if out.ends_with(b"\r\n\r\n") {
break;
}
assert!(
out.len() <= 512,
"probe header exceeded sane limit while waiting for terminator"
);
}
out
}
#[tokio::test]
async fn blackhat_fragmented_plain_http_probe_masks_and_preserves_prefix() {
let payload = b"GET /probe-evasion HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
run_generic_probe_and_capture_prefix(payload.clone(), payload).await;
}
#[tokio::test]
async fn blackhat_invalid_tls_like_probe_masks_and_preserves_header_prefix() {
let payload = vec![0x16, 0x03, 0x03, 0x00, 0x64, 0x01, 0x00];
run_generic_probe_and_capture_prefix(payload.clone(), payload).await;
}
#[tokio::test]
async fn integration_client_handler_plain_probe_masks_and_preserves_prefix() {
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = mask_listener.local_addr().unwrap();
let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let front_addr = front_listener.local_addr().unwrap();
let payload = b"GET /integration-probe HTTP/1.1\r\nHost: a.example\r\n\r\n".to_vec();
let expected_prefix = payload.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = mask_listener.accept().await.unwrap();
let mut got = vec![0u8; expected_prefix.len()];
stream.read_exact(&mut got).await.unwrap();
stream.write_all(REPLY_404).await.unwrap();
got
});
let config = masking_config(backend_addr.port());
let stats = Arc::new(Stats::new());
let upstream_manager = make_test_upstream_manager(stats.clone());
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let server_task = {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let route_runtime = route_runtime.clone();
let ip_tracker = ip_tracker.clone();
let beobachten = beobachten.clone();
tokio::spawn(async move {
let (stream, peer) = front_listener.accept().await.unwrap();
let real_peer_report = Arc::new(std::sync::Mutex::new(None));
ClientHandler::new(
stream,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
real_peer_report,
)
.run()
.await
})
};
let mut client = TcpStream::connect(front_addr).await.unwrap();
client.write_all(&payload).await.unwrap();
client.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed))
.await
.unwrap()
.unwrap();
assert_eq!(observed, REPLY_404);
let got = tokio::time::timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(got, payload);
let result = tokio::time::timeout(Duration::from_secs(2), server_task)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
}
#[tokio::test]
async fn light_fuzz_small_probe_variants_always_mask_and_preserve_declared_prefix() {
let mut rng = StdRng::seed_from_u64(0xA11E_5EED_F0F0_CAFE);
for i in 0..24usize {
let mut payload = if rng.random::<bool>() {
b"GET /fuzz HTTP/1.1\r\nHost: fuzz.example\r\n\r\n".to_vec()
} else {
vec![0x16, 0x03, 0x03, 0x00, 0x64]
};
let tail_len = rng.random_range(0..=8usize);
for _ in 0..tail_len {
payload.push(rng.random::<u8>());
}
let expected_prefix = payload.clone();
run_generic_probe_and_capture_prefix(payload, expected_prefix).await;
if i % 6 == 0 {
tokio::task::yield_now().await;
}
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() {
let session_count = 12usize;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut expected = std::collections::HashSet::new();
for idx in 0..session_count {
let probe =
format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes();
expected.insert(probe);
}
let accept_task = tokio::spawn(async move {
let mut remaining = expected;
for _ in 0..session_count {
let (mut stream, _) = listener.accept().await.unwrap();
let head = read_http_probe_header(&mut stream).await;
stream.write_all(REPLY_404).await.unwrap();
assert!(
remaining.remove(&head),
"backend received unexpected or duplicated probe prefix"
);
}
assert!(
remaining.is_empty(),
"all session prefixes must be observed exactly once"
);
});
let mut tasks = Vec::with_capacity(session_count);
for idx in 0..session_count {
let config = masking_config(backend_addr.port());
let stats = Arc::new(Stats::new());
let upstream_manager = make_test_upstream_manager(stats.clone());
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let probe =
format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes();
let peer: SocketAddr = format!("203.0.113.{}:{}", 30 + idx, 56000 + idx)
.parse()
.unwrap();
tasks.push(tokio::spawn(async move {
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout(
Duration::from_secs(2),
client_side.read_exact(&mut observed),
)
.await
.unwrap()
.unwrap();
assert_eq!(observed, REPLY_404);
let result = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
}));
}
for task in tasks {
task.await.unwrap();
}
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap();
}


@@ -0,0 +1,645 @@
//! Red-team expected-fail tests: each case encodes a stricter anti-probing
//! policy than the implementation provides (no post-TLS fallback forwarding,
//! sub-millisecond timing envelopes, constant backend fingerprint shape).
//! They are marked #[ignore] so their failures document the gap instead of
//! breaking CI.
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};
struct RedTeamHarness {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
route_runtime: Arc<RouteRuntimeController>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
}
fn build_harness(secret_hex: &str, mask_port: u16) -> RedTeamHarness {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_port;
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
RedTeamHarness {
config,
stats,
upstream_manager,
replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
buffer_pool: Arc::new(BufferPool::new()),
rng: Arc::new(SecureRandom::new()),
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
ip_tracker: Arc::new(UserIpTracker::new()),
beobachten: Arc::new(BeobachtenStore::new()),
}
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
// FakeTLS digest: HMAC over the hello (digest field zeroed above), with the
// little-endian timestamp XORed into its last 4 bytes.
let mut digest = sha256_hmac(secret, &handshake);
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len());
record.push(0x17);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.extend_from_slice(payload);
record
}
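As a standalone reference for the framing above, the following sketch duplicates the record layout so it compiles on its own. The version bytes are assumed to be TLS 1.2's `0x03 0x03` (what `TLS_VERSION` in this codebase most likely holds; that is an assumption of this sketch, not taken from the source).

```rust
// Standalone sketch of the TLS application-data framing used by the tests:
// 1 type byte, 2 version bytes, 2-byte big-endian length, then the payload.
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
    let mut record = Vec::with_capacity(5 + payload.len());
    record.push(0x17); // record type: ApplicationData
    record.extend_from_slice(&[0x03, 0x03]); // assumed legacy record version
    record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
    record.extend_from_slice(payload);
    record
}

fn main() {
    let rec = wrap_tls_application_data(&[0xAB; 300]);
    assert_eq!(rec[0], 0x17); // type byte
    assert_eq!(u16::from_be_bytes([rec[3], rec[4]]), 300); // declared length
    assert_eq!(rec.len(), 305); // 5-byte header + payload
}
```

This mirrors why the accept tasks read a 5-byte header first and then `len` body bytes: the declared length fully determines the record boundary.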
async fn run_tls_success_mtproto_fail_session(
secret_hex: &str,
secret: [u8; 16],
timestamp: u32,
tail: Vec<u8>,
) -> Vec<u8> {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let harness = build_harness(secret_hex, backend_addr.port());
let client_hello = make_valid_tls_client_hello(&secret, timestamp, 600, 0x42);
let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let trailing_record = wrap_tls_application_data(&tail);
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; trailing_record.len()];
stream.read_exact(&mut got).await.unwrap();
got
});
let (server_side, mut client_side) = duplex(262144);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.250:56900".parse().unwrap(),
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&client_hello).await.unwrap();
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
let body_len = u16::from_be_bytes([head[3], head[4]]) as usize;
let mut body = vec![0u8; body_len];
client_side.read_exact(&mut body).await.unwrap();
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
// `trailing_record` was moved into the accept task above, so rebuild the
// identical record for the client side.
client_side
.write_all(&wrap_tls_application_data(&tail))
.await
.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
forwarded
}
#[tokio::test]
#[ignore = "red-team expected-fail: demonstrates that post-TLS fallback still forwards data to backend"]
async fn redteam_01_backend_receives_no_data_after_mtproto_fail() {
let forwarded = run_tls_success_mtproto_fail_session(
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
[0xAA; 16],
1,
b"probe-a".to_vec(),
)
.await;
assert!(
forwarded.is_empty(),
"backend unexpectedly received fallback bytes"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: strict no-fallback policy hypothesis"]
async fn redteam_02_backend_must_never_receive_tls_records_after_mtproto_fail() {
let forwarded = run_tls_success_mtproto_fail_session(
"abababababababababababababababab",
[0xAB; 16],
2,
b"probe-b".to_vec(),
)
.await;
assert_ne!(
forwarded[0], 0x17,
"received TLS application record despite strict policy"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: impossible timing uniformity target"]
async fn redteam_03_masking_duration_must_be_less_than_1ms_when_backend_down() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1;
cfg.access.ignore_time_skew = true;
cfg.access.users.insert(
"user".to_string(),
"acacacacacacacacacacacacacacacac".to_string(),
);
let harness = RedTeamHarness {
config: Arc::new(cfg),
stats: Arc::new(Stats::new()),
upstream_manager: Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
Arc::new(Stats::new()),
)),
replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
buffer_pool: Arc::new(BufferPool::new()),
rng: Arc::new(SecureRandom::new()),
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
ip_tracker: Arc::new(UserIpTracker::new()),
beobachten: Arc::new(BeobachtenStore::new()),
};
let hello = make_valid_tls_client_hello(&[0xAC; 16], 3, 600, 0x42);
let (server_side, mut client_side) = duplex(131072);
let started = Instant::now();
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.251:56901".parse().unwrap(),
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&hello).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(
started.elapsed() < Duration::from_millis(1),
"fallback path took longer than 1ms"
);
}
macro_rules! redteam_tail_must_not_forward_case {
($name:ident, $hex:expr, $secret:expr, $ts:expr, $len:expr) => {
#[tokio::test]
#[ignore = "red-team expected-fail: strict no-forwarding hypothesis"]
async fn $name() {
let mut tail = vec![0u8; $len];
for (i, b) in tail.iter_mut().enumerate() {
*b = (i as u8).wrapping_mul(31).wrapping_add(7);
}
let forwarded = run_tls_success_mtproto_fail_session($hex, $secret, $ts, tail).await;
assert!(
forwarded.is_empty(),
"strict model expects zero forwarded bytes, got {}",
forwarded.len()
);
}
};
}
redteam_tail_must_not_forward_case!(
redteam_04_tail_len_1_not_forwarded,
"adadadadadadadadadadadadadadadad",
[0xAD; 16],
4,
1
);
redteam_tail_must_not_forward_case!(
redteam_05_tail_len_2_not_forwarded,
"aeaeaeaeaeaeaeaeaeaeaeaeaeaeaeae",
[0xAE; 16],
5,
2
);
redteam_tail_must_not_forward_case!(
redteam_06_tail_len_3_not_forwarded,
"afafafafafafafafafafafafafafafaf",
[0xAF; 16],
6,
3
);
redteam_tail_must_not_forward_case!(
redteam_07_tail_len_7_not_forwarded,
"b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0",
[0xB0; 16],
7,
7
);
redteam_tail_must_not_forward_case!(
redteam_08_tail_len_15_not_forwarded,
"b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1",
[0xB1; 16],
8,
15
);
redteam_tail_must_not_forward_case!(
redteam_09_tail_len_63_not_forwarded,
"b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2",
[0xB2; 16],
9,
63
);
redteam_tail_must_not_forward_case!(
redteam_10_tail_len_127_not_forwarded,
"b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3",
[0xB3; 16],
10,
127
);
redteam_tail_must_not_forward_case!(
redteam_11_tail_len_255_not_forwarded,
"b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4",
[0xB4; 16],
11,
255
);
redteam_tail_must_not_forward_case!(
redteam_12_tail_len_511_not_forwarded,
"b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5",
[0xB5; 16],
12,
511
);
redteam_tail_must_not_forward_case!(
redteam_13_tail_len_1023_not_forwarded,
"b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6",
[0xB6; 16],
13,
1023
);
redteam_tail_must_not_forward_case!(
redteam_14_tail_len_2047_not_forwarded,
"b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7",
[0xB7; 16],
14,
2047
);
redteam_tail_must_not_forward_case!(
redteam_15_tail_len_4095_not_forwarded,
"b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8",
[0xB8; 16],
15,
4095
);
#[tokio::test]
#[ignore = "red-team expected-fail: impossible indistinguishability envelope"]
async fn redteam_16_timing_delta_between_paths_must_be_sub_1ms_under_concurrency() {
let runs = 20usize;
let mut durations = Vec::with_capacity(runs);
for i in 0..runs {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let secret = [0xB9u8; 16];
let harness = build_harness("b9b9b9b9b9b9b9b9b9b9b9b9b9b9b9b9", backend_addr.port());
let hello = make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x42);
let accept_task = tokio::spawn(async move {
let (_stream, _) = listener.accept().await.unwrap();
});
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.252:56902".parse().unwrap(),
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
let started = Instant::now();
client_side.write_all(&hello).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
durations.push(started.elapsed());
}
let min = durations.iter().copied().min().unwrap();
let max = durations.iter().copied().max().unwrap();
assert!(
max - min <= Duration::from_millis(1),
"timing spread too wide for strict anti-probing envelope"
);
}
async fn measure_invalid_probe_duration_ms(delay_ms: u64, tls_len: u16, body_sent: usize) -> u128 {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1;
cfg.timeouts.client_handshake = 1;
cfg.censorship.server_hello_delay_min_ms = delay_ms;
cfg.censorship.server_hello_delay_max_ms = delay_ms;
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.253:56903".parse().unwrap(),
Arc::new(cfg),
Arc::new(Stats::new()),
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
Arc::new(Stats::new()),
)),
Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&tls_len.to_be_bytes());
probe[5..].fill(0xD7);
let started = Instant::now();
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
started.elapsed().as_millis()
}
async fn capture_forwarded_probe_len(tls_len: u16, body_sent: usize) -> usize {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.timeouts.client_handshake = 1;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got.len()
});
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.254:56904".parse().unwrap(),
Arc::new(cfg),
Arc::new(Stats::new()),
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
Arc::new(Stats::new()),
)),
Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&tls_len.to_be_bytes());
probe[5..].fill(0xBC);
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
macro_rules! redteam_timing_envelope_case {
($name:ident, $delay_ms:expr, $tls_len:expr, $body_sent:expr, $max_ms:expr) => {
#[tokio::test]
#[ignore = "red-team expected-fail: unrealistically tight reject timing envelope"]
async fn $name() {
let elapsed_ms =
measure_invalid_probe_duration_ms($delay_ms, $tls_len, $body_sent).await;
assert!(
elapsed_ms <= $max_ms,
"timing envelope violated: elapsed={}ms, max={}ms",
elapsed_ms,
$max_ms
);
}
};
}
macro_rules! redteam_constant_shape_case {
($name:ident, $tls_len:expr, $body_sent:expr, $expected_len:expr) => {
#[tokio::test]
#[ignore = "red-team expected-fail: strict constant-shape backend fingerprint hypothesis"]
async fn $name() {
let got = capture_forwarded_probe_len($tls_len, $body_sent).await;
assert_eq!(
got, $expected_len,
"fingerprint shape mismatch: got={} expected={} (strict constant-shape model)",
got, $expected_len
);
}
};
}
redteam_timing_envelope_case!(redteam_17_timing_env_very_tight_00, 700, 600, 0, 3);
redteam_timing_envelope_case!(redteam_18_timing_env_very_tight_01, 700, 600, 1, 3);
redteam_timing_envelope_case!(redteam_19_timing_env_very_tight_02, 700, 600, 7, 3);
redteam_timing_envelope_case!(redteam_20_timing_env_very_tight_03, 700, 600, 17, 3);
redteam_timing_envelope_case!(redteam_21_timing_env_very_tight_04, 700, 600, 31, 3);
redteam_timing_envelope_case!(redteam_22_timing_env_very_tight_05, 700, 600, 63, 3);
redteam_timing_envelope_case!(redteam_23_timing_env_very_tight_06, 700, 600, 127, 3);
redteam_timing_envelope_case!(redteam_24_timing_env_very_tight_07, 700, 600, 255, 3);
redteam_timing_envelope_case!(redteam_25_timing_env_very_tight_08, 700, 600, 511, 3);
redteam_timing_envelope_case!(redteam_26_timing_env_very_tight_09, 700, 600, 1023, 3);
redteam_timing_envelope_case!(redteam_27_timing_env_very_tight_10, 700, 600, 2047, 3);
redteam_timing_envelope_case!(redteam_28_timing_env_very_tight_11, 700, 600, 4095, 3);
redteam_constant_shape_case!(redteam_29_constant_shape_00, 600, 0, 517);
redteam_constant_shape_case!(redteam_30_constant_shape_01, 600, 1, 517);
redteam_constant_shape_case!(redteam_31_constant_shape_02, 600, 7, 517);
redteam_constant_shape_case!(redteam_32_constant_shape_03, 600, 17, 517);
redteam_constant_shape_case!(redteam_33_constant_shape_04, 600, 31, 517);
redteam_constant_shape_case!(redteam_34_constant_shape_05, 600, 63, 517);
redteam_constant_shape_case!(redteam_35_constant_shape_06, 600, 127, 517);
redteam_constant_shape_case!(redteam_36_constant_shape_07, 600, 255, 517);
redteam_constant_shape_case!(redteam_37_constant_shape_08, 600, 511, 517);
redteam_constant_shape_case!(redteam_38_constant_shape_09, 600, 1023, 517);
redteam_constant_shape_case!(redteam_39_constant_shape_10, 600, 2047, 517);
redteam_constant_shape_case!(redteam_40_constant_shape_11, 600, 4095, 517);


@@ -0,0 +1,246 @@
//! Shape-hardening fuzz tests: measure the correlation between probe body
//! size and the byte count forwarded to the masking backend, with and
//! without bucketed length hardening enabled.
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_probe_capture(
body_sent: usize,
tls_len: u16,
enable_shape_hardening: bool,
floor: usize,
cap: usize,
) -> usize {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_shape_hardening = enable_shape_hardening;
cfg.censorship.mask_shape_bucket_floor_bytes = floor;
cfg.censorship.mask_shape_bucket_cap_bytes = cap;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got.len()
});
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.214:57014".parse().unwrap(),
Arc::new(cfg),
Arc::new(Stats::new()),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&tls_len.to_be_bytes());
probe[5..].fill(0x66);
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
fn pearson_corr(xs: &[f64], ys: &[f64]) -> f64 {
if xs.len() != ys.len() || xs.is_empty() {
return 0.0;
}
let n = xs.len() as f64;
let mean_x = xs.iter().sum::<f64>() / n;
let mean_y = ys.iter().sum::<f64>() / n;
let mut cov = 0.0;
let mut var_x = 0.0;
let mut var_y = 0.0;
for (&x, &y) in xs.iter().zip(ys.iter()) {
let dx = x - mean_x;
let dy = y - mean_y;
cov += dx * dy;
var_x += dx * dx;
var_y += dy * dy;
}
if var_x == 0.0 || var_y == 0.0 {
return 0.0;
}
cov / (var_x.sqrt() * var_y.sqrt())
}
fn lcg_sizes(count: usize, floor: usize, cap: usize) -> Vec<usize> {
let mut x = 0x9E3779B97F4A7C15u64;
let span = cap.saturating_mul(3);
let mut out = Vec::with_capacity(count + 8);
for _ in 0..count {
x = x
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
let v = (x as usize) % span.max(1);
out.push(v);
}
// Inject edge and boundary-heavy probes.
out.extend_from_slice(&[
0,
floor.saturating_sub(1),
floor,
floor.saturating_add(1),
cap.saturating_sub(1),
cap,
cap.saturating_add(1),
cap.saturating_mul(2),
]);
out
}
async fn collect_distribution(
sizes: &[usize],
hardening: bool,
floor: usize,
cap: usize,
) -> Vec<usize> {
let mut out = Vec::with_capacity(sizes.len());
for &body in sizes {
out.push(run_probe_capture(body, 1200, hardening, floor, cap).await);
}
out
}
#[tokio::test]
#[ignore = "red-team expected-fail: strict decorrelation target for hardened output lengths"]
async fn redteam_fuzz_01_hardened_output_length_correlation_should_be_below_0_2() {
let floor = 512usize;
let cap = 4096usize;
let sizes = lcg_sizes(24, floor, cap);
let hardened = collect_distribution(&sizes, true, floor, cap).await;
let x: Vec<f64> = sizes.iter().map(|v| *v as f64).collect();
let y_hard: Vec<f64> = hardened.iter().map(|v| *v as f64).collect();
let corr_hard = pearson_corr(&x, &y_hard).abs();
println!(
"redteam_fuzz corr_hardened={corr_hard:.4} samples={}",
sizes.len()
);
assert!(
corr_hard < 0.2,
"strict model expects near-zero size correlation; observed corr={corr_hard:.4}"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: strict class-collapse ratio target"]
async fn redteam_fuzz_02_hardened_unique_output_ratio_should_be_below_5pct() {
let floor = 512usize;
let cap = 4096usize;
let sizes = lcg_sizes(24, floor, cap);
let hardened = collect_distribution(&sizes, true, floor, cap).await;
let in_unique = {
let mut s = std::collections::BTreeSet::new();
for v in &sizes {
s.insert(*v);
}
s.len()
};
let out_unique = {
let mut s = std::collections::BTreeSet::new();
for v in &hardened {
s.insert(*v);
}
s.len()
};
let ratio = out_unique as f64 / in_unique as f64;
println!(
"redteam_fuzz unique_ratio_hardened={ratio:.4} out_unique={} in_unique={}",
out_unique, in_unique
);
assert!(
ratio <= 0.05,
"strict model expects near-total collapse; observed ratio={ratio:.4}"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: strict separability improvement target"]
async fn redteam_fuzz_03_hardened_signal_must_be_10x_lower_than_plain() {
let floor = 512usize;
let cap = 4096usize;
let sizes = lcg_sizes(24, floor, cap);
let plain = collect_distribution(&sizes, false, floor, cap).await;
let hardened = collect_distribution(&sizes, true, floor, cap).await;
let x: Vec<f64> = sizes.iter().map(|v| *v as f64).collect();
let y_plain: Vec<f64> = plain.iter().map(|v| *v as f64).collect();
let y_hard: Vec<f64> = hardened.iter().map(|v| *v as f64).collect();
let corr_plain = pearson_corr(&x, &y_plain).abs();
let corr_hard = pearson_corr(&x, &y_hard).abs();
println!("redteam_fuzz corr_plain={corr_plain:.4} corr_hardened={corr_hard:.4}");
assert!(
corr_hard <= corr_plain * 0.1,
"strict model expects 10x suppression; plain={corr_plain:.4} hardened={corr_hard:.4}"
);
}
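The correlation metric driving these fuzz assertions can be exercised standalone. A minimal sketch restating the `pearson_corr` helper outside the test harness, with hypothetical sample data (the degenerate-input guards mirror the version above):

```rust
// Standalone restatement of the Pearson correlation helper used by the
// red-team fuzz tests: returns 0.0 for mismatched, empty, or zero-variance
// inputs so callers never divide by zero.
fn pearson_corr(xs: &[f64], ys: &[f64]) -> f64 {
    if xs.len() != ys.len() || xs.is_empty() {
        return 0.0;
    }
    let n = xs.len() as f64;
    let mean_x = xs.iter().sum::<f64>() / n;
    let mean_y = ys.iter().sum::<f64>() / n;
    let (mut cov, mut var_x, mut var_y) = (0.0, 0.0, 0.0);
    for (&x, &y) in xs.iter().zip(ys.iter()) {
        let (dx, dy) = (x - mean_x, y - mean_y);
        cov += dx * dy;
        var_x += dx * dx;
        var_y += dy * dy;
    }
    if var_x == 0.0 || var_y == 0.0 {
        return 0.0;
    }
    cov / (var_x.sqrt() * var_y.sqrt())
}

fn main() {
    // Perfectly linear data correlates at 1.0; a constant series yields 0.0.
    let xs = [1.0, 2.0, 3.0, 4.0];
    let ys = [2.0, 4.0, 6.0, 8.0];
    assert!((pearson_corr(&xs, &ys) - 1.0).abs() < 1e-12);
    assert_eq!(pearson_corr(&xs, &[5.0, 5.0, 5.0, 5.0]), 0.0);
    println!("ok");
}
```

The |corr| < 0.2 threshold in `redteam_fuzz_01` is a strict decorrelation target against this metric: hardened output lengths should carry almost no linear signal about input sizes.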


@ -0,0 +1,179 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn expected_bucket(total: usize, floor: usize, cap: usize) -> usize {
if total == 0 || floor == 0 || cap < floor {
return total;
}
if total >= cap {
return total;
}
let mut bucket = floor;
while bucket < total {
match bucket.checked_mul(2) {
Some(next) => bucket = next,
None => return total,
}
if bucket > cap {
return cap;
}
}
bucket
}
async fn run_probe_capture(
body_sent: usize,
tls_len: u16,
enable_shape_hardening: bool,
floor: usize,
cap: usize,
) -> Vec<u8> {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_shape_hardening = enable_shape_hardening;
cfg.censorship.mask_shape_bucket_floor_bytes = floor;
cfg.censorship.mask_shape_bucket_cap_bytes = cap;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got
});
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.199:56999".parse().unwrap(),
Arc::new(cfg),
Arc::new(Stats::new()),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&tls_len.to_be_bytes());
probe[5..].fill(0x66);
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
#[tokio::test]
async fn shape_hardening_non_power_of_two_cap_collapses_probe_classes() {
let floor = 1000usize;
let cap = 1500usize;
let low = run_probe_capture(1195, 700, true, floor, cap).await;
let high = run_probe_capture(1494, 700, true, floor, cap).await;
assert_eq!(low.len(), 1500);
assert_eq!(high.len(), 1500);
}
#[tokio::test]
async fn shape_hardening_disabled_keeps_non_power_of_two_cap_lengths_distinct() {
let floor = 1000usize;
let cap = 1500usize;
let low = run_probe_capture(1195, 700, false, floor, cap).await;
let high = run_probe_capture(1494, 700, false, floor, cap).await;
assert_eq!(low.len(), 1200);
assert_eq!(high.len(), 1499);
}
#[tokio::test]
async fn shape_hardening_parallel_stress_collapses_sub_cap_probes() {
let floor = 1000usize;
let cap = 1500usize;
let mut tasks = Vec::new();
for idx in 0..24usize {
let body = 1001 + (idx * 19 % 480);
tasks.push(tokio::spawn(async move {
run_probe_capture(body, 1200, true, floor, cap).await.len()
}));
}
for task in tasks {
let observed = task.await.unwrap();
assert_eq!(observed, 1500);
}
}
#[tokio::test]
async fn shape_hardening_light_fuzz_matches_bucket_oracle() {
let floor = 512usize;
let cap = 4096usize;
for step in 1usize..=36usize {
let total = 1 + (((step * 313) ^ (step << 7)) % (cap + 300));
let body = total.saturating_sub(5);
let got = run_probe_capture(body, 650, true, floor, cap).await;
let expected = expected_bucket(total, floor, cap);
assert_eq!(
got.len(),
expected,
"step={step} total={total} expected={expected} got={} ",
got.len()
);
}
}
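The bucket oracle these tests check against is simple to state: a total length below the cap is padded up by doubling from the floor until the bucket covers it, clamping at the cap; totals at or above the cap pass through unchanged. A self-contained sketch mirroring the `expected_bucket` helper above, with worked cases drawn from the tests:

```rust
// Sketch of the shape-hardening bucket oracle: double from `floor` until the
// bucket covers `total`, clamp at `cap`, and pass through anything already at
// or above the cap (or any degenerate floor/cap configuration).
fn expected_bucket(total: usize, floor: usize, cap: usize) -> usize {
    if total == 0 || floor == 0 || cap < floor {
        return total;
    }
    if total >= cap {
        return total;
    }
    let mut bucket = floor;
    while bucket < total {
        match bucket.checked_mul(2) {
            Some(next) => bucket = next,
            None => return total,
        }
        if bucket > cap {
            return cap;
        }
    }
    bucket
}

fn main() {
    // floor=512, cap=4096: a 22-byte probe pads to 512; 513 pads to 1024.
    assert_eq!(expected_bucket(22, 512, 4096), 512);
    assert_eq!(expected_bucket(513, 512, 4096), 1024);
    // Non-power-of-two cap (floor=1000, cap=1500): sub-cap totals collapse
    // onto the cap itself, which is what the non-power-of-two tests assert.
    assert_eq!(expected_bucket(1200, 1000, 1500), 1500);
    // At or above the cap, lengths pass through unchanged.
    assert_eq!(expected_bucket(5005, 512, 4096), 5005);
    println!("ok");
}
```

This explains why `shape_hardening_non_power_of_two_cap_collapses_probe_classes` sees both 1200- and 1499-byte totals emerge as exactly 1500 bytes: the first doubling from the floor overshoots the cap and is clamped.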


@ -0,0 +1,238 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_probe_capture(
body_sent: usize,
tls_len: u16,
enable_shape_hardening: bool,
floor: usize,
cap: usize,
) -> Vec<u8> {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_shape_hardening = enable_shape_hardening;
cfg.censorship.mask_shape_bucket_floor_bytes = floor;
cfg.censorship.mask_shape_bucket_cap_bytes = cap;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got
});
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.211:57011".parse().unwrap(),
Arc::new(cfg),
Arc::new(Stats::new()),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&tls_len.to_be_bytes());
probe[5..].fill(0x66);
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
async fn measure_reject_ms(body_sent: usize) -> u128 {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1;
cfg.censorship.server_hello_delay_min_ms = 700;
cfg.censorship.server_hello_delay_max_ms = 700;
let (server_side, mut client_side) = duplex(65536);
let started = Instant::now();
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.212:57012".parse().unwrap(),
Arc::new(cfg),
Arc::new(Stats::new()),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&600u16.to_be_bytes());
probe[5..].fill(0x44);
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
started.elapsed().as_millis()
}
#[tokio::test]
#[ignore = "red-team expected-fail: above-cap exact length still leaks classifier signal"]
async fn redteam_shape_01_above_cap_flows_should_collapse_to_single_class() {
let floor = 512usize;
let cap = 4096usize;
let a = run_probe_capture(5000, 7000, true, floor, cap).await;
let b = run_probe_capture(6000, 7000, true, floor, cap).await;
assert_eq!(
a.len(),
b.len(),
"strict anti-classifier model expects same backend length class above cap"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: current padding bytes are deterministic zeros"]
async fn redteam_shape_02_padding_tail_must_be_non_deterministic() {
let floor = 512usize;
let cap = 4096usize;
let got = run_probe_capture(17, 600, true, floor, cap).await;
assert!(got.len() > 22, "test requires padding tail to exist");
let tail = &got[22..];
assert!(
tail.iter().any(|b| *b != 0),
"padding tail is fully zeroed and thus deterministic"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: exact-floor probes still expose boundary class"]
async fn redteam_shape_03_exact_floor_input_should_not_be_fixed_point() {
let floor = 512usize;
let cap = 4096usize;
let got = run_probe_capture(507, 600, true, floor, cap).await;
assert!(
got.len() > floor,
"strict model expects extra blur even when input lands exactly on floor"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: strict one-bucket collapse hypothesis"]
async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() {
let floor = 512usize;
let cap = 4096usize;
let classes = [
17usize, 63usize, 255usize, 511usize, 1023usize, 2047usize, 3071usize,
];
let mut observed = Vec::new();
for body in classes {
observed.push(run_probe_capture(body, 1200, true, floor, cap).await.len());
}
let first = observed[0];
for v in observed {
assert_eq!(
v, first,
"strict model expects one collapsed class across all sub-cap probes"
);
}
}
#[tokio::test]
#[ignore = "red-team expected-fail: over-strict micro-timing invariance"]
async fn redteam_shape_05_reject_timing_spread_should_be_under_2ms() {
let classes = [17usize, 511usize, 1023usize, 2047usize, 4095usize];
let mut values = Vec::new();
for class in classes {
values.push(measure_reject_ms(class).await);
}
let min = *values.iter().min().unwrap();
let max = *values.iter().max().unwrap();
assert!(
min == 700 && max == 700,
"strict model requires exact 700ms for every malformed class: min={min}ms max={max}ms"
);
}
#[test]
#[ignore = "red-team expected-fail: secure-by-default hypothesis"]
fn redteam_shape_06_shape_hardening_should_be_secure_by_default() {
let cfg = ProxyConfig::default();
assert!(
cfg.censorship.mask_shape_hardening,
"strict model expects shape hardening enabled by default"
);
}


@ -0,0 +1,122 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_probe_capture(
body_sent: usize,
tls_len: u16,
enable_shape_hardening: bool,
floor: usize,
cap: usize,
) -> Vec<u8> {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_shape_hardening = enable_shape_hardening;
cfg.censorship.mask_shape_bucket_floor_bytes = floor;
cfg.censorship.mask_shape_bucket_cap_bytes = cap;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got
});
let (server_side, mut client_side) = duplex(65536);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.188:56888".parse().unwrap(),
Arc::new(cfg),
Arc::new(Stats::new()),
new_upstream_manager(Arc::new(Stats::new())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&tls_len.to_be_bytes());
probe[5..].fill(0x66);
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
#[tokio::test]
async fn shape_hardening_disabled_keeps_original_probe_length() {
let got = run_probe_capture(17, 600, false, 512, 4096).await;
assert_eq!(got.len(), 22);
assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x02, 0x58]);
}
#[tokio::test]
async fn shape_hardening_enabled_pads_small_probe_to_floor_bucket() {
let got = run_probe_capture(17, 600, true, 512, 4096).await;
assert_eq!(got.len(), 512);
assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x02, 0x58]);
}
#[tokio::test]
async fn shape_hardening_enabled_pads_mid_probe_to_next_bucket() {
let got = run_probe_capture(511, 600, true, 512, 4096).await;
assert_eq!(got.len(), 1024);
assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x02, 0x58]);
}
#[tokio::test]
async fn shape_hardening_respects_cap_and_avoids_padding_above_cap() {
let got = run_probe_capture(5000, 7000, true, 512, 4096).await;
assert_eq!(got.len(), 5005);
assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x1b, 0x58]);
}


@ -0,0 +1,256 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION};
use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
struct StressHarness {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
route_runtime: Arc<RouteRuntimeController>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
}
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn build_harness(mask_port: u16, secret_hex: &str) -> StressHarness {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_port;
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
StressHarness {
config,
stats: stats.clone(),
upstream_manager: new_upstream_manager(stats),
replay_checker: Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))),
buffer_pool: Arc::new(BufferPool::new()),
rng: Arc::new(SecureRandom::new()),
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
ip_tracker: Arc::new(UserIpTracker::new()),
beobachten: Arc::new(BeobachtenStore::new()),
}
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len());
record.push(TLS_RECORD_APPLICATION);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.extend_from_slice(payload);
record
}
async fn read_tls_record_body<T>(stream: &mut T, header: [u8; 5])
where
T: tokio::io::AsyncRead + Unpin,
{
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut body = vec![0u8; len];
stream.read_exact(&mut body).await.unwrap();
}
async fn run_parallel_tail_fallback_case(
sessions: usize,
payload_len: usize,
write_chunk: usize,
ts_base: u32,
peer_port_base: u16,
) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut expected = std::collections::HashSet::new();
for idx in 0..sessions {
let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3];
expected.insert(wrap_tls_application_data(&payload));
}
let accept_task = tokio::spawn(async move {
let mut remaining = expected;
for _ in 0..sessions {
let (mut stream, _) = listener.accept().await.unwrap();
let mut header = [0u8; 5];
stream.read_exact(&mut header).await.unwrap();
assert_eq!(header[0], TLS_RECORD_APPLICATION);
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut record = vec![0u8; 5 + len];
record[..5].copy_from_slice(&header);
stream.read_exact(&mut record[5..]).await.unwrap();
assert!(remaining.remove(&record));
}
assert!(remaining.is_empty());
});
let mut tasks = Vec::with_capacity(sessions);
for idx in 0..sessions {
let harness = build_harness(backend_addr.port(), "e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0");
let hello =
make_valid_tls_client_hello(&[0xE0; 16], ts_base + idx as u32, 600, 0x40 + (idx as u8));
let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3];
let trailing = wrap_tls_application_data(&payload);
// Keep source IPs unique across stress cases so global pre-auth probe state
// cannot contaminate unrelated sessions and make this test nondeterministic.
let peer_ip_third = 100 + ((ts_base as u8) / 10);
let peer_ip_fourth = (idx as u8).saturating_add(1);
let peer: SocketAddr = format!(
"198.51.{}.{}:{}",
peer_ip_third,
peer_ip_fourth,
peer_port_base + idx as u16
)
.parse()
.unwrap();
tasks.push(tokio::spawn(async move {
let (server_side, mut client_side) = duplex(262144);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&hello).await.unwrap();
let mut server_hello_head = [0u8; 5];
client_side
.read_exact(&mut server_hello_head)
.await
.unwrap();
assert_eq!(server_hello_head[0], 0x16);
read_tls_record_body(&mut client_side, server_hello_head).await;
client_side.write_all(&invalid_mtproto).await.unwrap();
for chunk in trailing.chunks(write_chunk.max(1)) {
client_side.write_all(chunk).await.unwrap();
}
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), handler)
.await
.unwrap()
.unwrap();
}));
}
for task in tasks {
task.await.unwrap();
}
tokio::time::timeout(Duration::from_secs(8), accept_task)
.await
.unwrap()
.unwrap();
}
macro_rules! stress_case {
($name:ident, $sessions:expr, $payload_len:expr, $chunk:expr, $ts:expr, $port:expr) => {
#[tokio::test]
async fn $name() {
run_parallel_tail_fallback_case($sessions, $payload_len, $chunk, $ts, $port).await;
}
};
}
stress_case!(stress_masking_parallel_s01, 4, 16, 1, 1000, 57000);
stress_case!(stress_masking_parallel_s02, 5, 24, 2, 1010, 57010);
stress_case!(stress_masking_parallel_s03, 6, 32, 3, 1020, 57020);
stress_case!(stress_masking_parallel_s04, 7, 40, 4, 1030, 57030);
stress_case!(stress_masking_parallel_s05, 8, 48, 5, 1040, 57040);
stress_case!(stress_masking_parallel_s06, 9, 56, 6, 1050, 57050);
stress_case!(stress_masking_parallel_s07, 10, 64, 7, 1060, 57060);
stress_case!(stress_masking_parallel_s08, 11, 72, 8, 1070, 57070);
stress_case!(stress_masking_parallel_s09, 12, 80, 9, 1080, 57080);
stress_case!(stress_masking_parallel_s10, 13, 88, 10, 1090, 57090);
stress_case!(stress_masking_parallel_s11, 6, 128, 11, 1100, 57100);
stress_case!(stress_masking_parallel_s12, 7, 160, 12, 1110, 57110);
stress_case!(stress_masking_parallel_s13, 8, 192, 13, 1120, 57120);
stress_case!(stress_masking_parallel_s14, 9, 224, 14, 1130, 57130);
stress_case!(stress_masking_parallel_s15, 10, 256, 15, 1140, 57140);
stress_case!(stress_masking_parallel_s16, 11, 288, 16, 1150, 57150);
stress_case!(stress_masking_parallel_s17, 12, 320, 17, 1160, 57160);
stress_case!(stress_masking_parallel_s18, 13, 352, 18, 1170, 57170);
stress_case!(stress_masking_parallel_s19, 14, 384, 19, 1180, 57180);
stress_case!(stress_masking_parallel_s20, 15, 416, 20, 1190, 57190);
stress_case!(stress_masking_parallel_s21, 16, 448, 21, 1200, 57200);
stress_case!(stress_masking_parallel_s22, 17, 480, 22, 1210, 57210);
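The stress cases above all round-trip payloads through fake-TLS application-data framing: a 5-byte record header (content type, legacy version, big-endian length) followed by the raw payload. A standalone sketch of that framing, assuming the constant values 0x17 and 0x03 0x03 for `TLS_RECORD_APPLICATION` and `TLS_VERSION` (the source imports these from `crate::protocol::constants`, so the exact values here are an assumption):

```rust
// Sketch of the fake-TLS application-data framing used by the stress harness.
// Constant values are assumptions mirroring TLS_RECORD_APPLICATION and
// TLS_VERSION in the source.
const TLS_RECORD_APPLICATION: u8 = 0x17;
const TLS_VERSION: [u8; 2] = [0x03, 0x03];

fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
    let mut record = Vec::with_capacity(5 + payload.len());
    record.push(TLS_RECORD_APPLICATION); // content type: application data
    record.extend_from_slice(&TLS_VERSION); // legacy record version
    record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); // length
    record.extend_from_slice(payload);
    record
}

fn main() {
    let rec = wrap_tls_application_data(&[0xAA; 16]);
    assert_eq!(rec.len(), 5 + 16);
    assert_eq!(rec[0], TLS_RECORD_APPLICATION);
    assert_eq!(&rec[3..5], &16u16.to_be_bytes());
    println!("ok");
}
```

The backend-side accept task parses the same header back out (`u16::from_be_bytes([header[3], header[4]])`) to reassemble each record, which is why the stress cases can match records by exact byte content regardless of how finely the client chunks its writes.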

File diff suppressed because it is too large.

@ -0,0 +1,370 @@
//! Differential timing-profile adversarial tests.
//! Compare malformed in-range TLS truncation probes with plain web baselines,
//! ensuring masking behavior stays in similar latency buckets.
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream};
const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
#[derive(Clone, Copy, Debug)]
enum ProbeClass {
MalformedTlsTruncation,
PlainWebBaseline,
}
fn make_test_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn malformed_tls_probe() -> Vec<u8> {
vec![
0x16,
0x03,
0x03,
((MIN_TLS_CLIENT_HELLO_SIZE >> 8) & 0xff) as u8,
(MIN_TLS_CLIENT_HELLO_SIZE & 0xff) as u8,
0x41,
]
}
fn plain_web_probe() -> Vec<u8> {
b"GET /timing-profile HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec()
}
fn summarize(samples_ms: &[u128]) -> (f64, u128, u128, u128) {
let mut sorted = samples_ms.to_vec();
sorted.sort_unstable();
let sum: u128 = sorted.iter().copied().sum();
let mean = sum as f64 / sorted.len() as f64;
let min = sorted[0];
let p95_idx = ((sorted.len() as f64) * 0.95).floor() as usize;
let p95 = sorted[p95_idx.min(sorted.len() - 1)];
let max = sorted[sorted.len() - 1];
(mean, min, p95, max)
}
async fn run_generic_once(class: ProbeClass) -> u128 {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let backend_reply = REPLY_404.to_vec();
let accept_task = tokio::spawn({
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 5];
stream.read_exact(&mut buf).await.unwrap();
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
if matches!(class, ProbeClass::PlainWebBaseline) {
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
}
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = make_test_upstream_manager(stats.clone());
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "203.0.113.210:55110".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
let probe = match class {
ProbeClass::MalformedTlsTruncation => malformed_tls_probe(),
ProbeClass::PlainWebBaseline => plain_web_probe(),
};
let started = Instant::now();
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout(
Duration::from_secs(2),
client_side.read_exact(&mut observed),
)
.await
.unwrap()
.unwrap();
assert_eq!(observed, REPLY_404);
tokio::time::timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap()
.unwrap();
started.elapsed().as_millis()
}
async fn run_client_handler_once(class: ProbeClass) -> u128 {
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = mask_listener.local_addr().unwrap();
let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let front_addr = front_listener.local_addr().unwrap();
let backend_reply = REPLY_404.to_vec();
let mask_accept_task = tokio::spawn({
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = mask_listener.accept().await.unwrap();
let mut buf = [0u8; 5];
stream.read_exact(&mut buf).await.unwrap();
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
if matches!(class, ProbeClass::PlainWebBaseline) {
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
}
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = make_test_upstream_manager(stats.clone());
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let server_task = {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let route_runtime = route_runtime.clone();
let ip_tracker = ip_tracker.clone();
let beobachten = beobachten.clone();
tokio::spawn(async move {
let (stream, peer) = front_listener.accept().await.unwrap();
let real_peer_report = Arc::new(std::sync::Mutex::new(None));
ClientHandler::new(
stream,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
real_peer_report,
)
.run()
.await
})
};
let probe = match class {
ProbeClass::MalformedTlsTruncation => malformed_tls_probe(),
ProbeClass::PlainWebBaseline => plain_web_probe(),
};
let mut client = TcpStream::connect(front_addr).await.unwrap();
let started = Instant::now();
client.write_all(&probe).await.unwrap();
client.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed))
.await
.unwrap()
.unwrap();
assert_eq!(observed, REPLY_404);
tokio::time::timeout(Duration::from_secs(2), mask_accept_task)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), server_task)
.await
.unwrap()
.unwrap();
started.elapsed().as_millis()
}
#[tokio::test]
async fn differential_timing_generic_malformed_tls_vs_plain_web_mask_profile_similar() {
const ITER: usize = 24;
const BUCKET_MS: u128 = 20;
let mut malformed = Vec::with_capacity(ITER);
let mut plain = Vec::with_capacity(ITER);
for _ in 0..ITER {
malformed.push(run_generic_once(ProbeClass::MalformedTlsTruncation).await);
plain.push(run_generic_once(ProbeClass::PlainWebBaseline).await);
}
let (m_mean, m_min, m_p95, m_max) = summarize(&malformed);
let (p_mean, p_min, p_p95, p_max) = summarize(&plain);
println!(
"TIMING_DIFF generic class=malformed mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}",
m_mean,
m_min,
m_p95,
m_max,
(m_mean as u128) / BUCKET_MS,
m_p95 / BUCKET_MS
);
println!(
"TIMING_DIFF generic class=plain_web mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}",
p_mean,
p_min,
p_p95,
p_max,
(p_mean as u128) / BUCKET_MS,
p_p95 / BUCKET_MS
);
let mean_bucket_delta = ((m_mean as i128) - (p_mean as i128)).abs() / (BUCKET_MS as i128);
let p95_bucket_delta = ((m_p95 as i128) - (p_p95 as i128)).abs() / (BUCKET_MS as i128);
assert!(
mean_bucket_delta <= 1,
"generic timing mean diverged: malformed_mean_ms={:.2}, plain_mean_ms={:.2}",
m_mean,
p_mean
);
assert!(
p95_bucket_delta <= 2,
"generic timing p95 diverged: malformed_p95_ms={}, plain_p95_ms={}",
m_p95,
p_p95
);
}
#[tokio::test]
async fn differential_timing_client_handler_malformed_tls_vs_plain_web_mask_profile_similar() {
const ITER: usize = 16;
const BUCKET_MS: u128 = 20;
let mut malformed = Vec::with_capacity(ITER);
let mut plain = Vec::with_capacity(ITER);
for _ in 0..ITER {
malformed.push(run_client_handler_once(ProbeClass::MalformedTlsTruncation).await);
plain.push(run_client_handler_once(ProbeClass::PlainWebBaseline).await);
}
let (m_mean, m_min, m_p95, m_max) = summarize(&malformed);
let (p_mean, p_min, p_p95, p_max) = summarize(&plain);
println!(
"TIMING_DIFF handler class=malformed mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}",
m_mean,
m_min,
m_p95,
m_max,
(m_mean as u128) / BUCKET_MS,
m_p95 / BUCKET_MS
);
println!(
"TIMING_DIFF handler class=plain_web mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}",
p_mean,
p_min,
p_p95,
p_max,
(p_mean as u128) / BUCKET_MS,
p_p95 / BUCKET_MS
);
let mean_bucket_delta = ((m_mean as i128) - (p_mean as i128)).abs() / (BUCKET_MS as i128);
let p95_bucket_delta = ((m_p95 as i128) - (p_p95 as i128)).abs() / (BUCKET_MS as i128);
assert!(
mean_bucket_delta <= 1,
"handler timing mean diverged: malformed_mean_ms={:.2}, plain_mean_ms={:.2}",
m_mean,
p_mean
);
assert!(
p95_bucket_delta <= 2,
"handler timing p95 diverged: malformed_p95_ms={}, plain_p95_ms={}",
m_p95,
p_p95
);
}
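The two timing tests above collapse each latency distribution to mean/p95 and compare the classes in 20 ms buckets. A minimal sketch of that comparison logic, using made-up sample latencies; this `summarize` is an illustrative stand-in for the helper the tests call, not the real implementation:

```rust
// Sketch of the bucketed timing comparison used by the differential tests.
// The sample data and `summarize` body are illustrative stand-ins.
const BUCKET_MS: u128 = 20;

fn summarize(samples: &[u128]) -> (f64, u128, u128, u128) {
    let mut sorted = samples.to_vec();
    sorted.sort_unstable();
    let mean = sorted.iter().sum::<u128>() as f64 / sorted.len() as f64;
    // p95: element at the 95th percentile rank, clamped to the last index.
    let p95_idx = ((sorted.len() as f64) * 0.95).ceil() as usize - 1;
    let p95 = sorted[p95_idx.min(sorted.len() - 1)];
    (mean, sorted[0], p95, *sorted.last().unwrap())
}

fn bucket_delta(a: f64, b: f64) -> i128 {
    ((a as i128) - (b as i128)).abs() / (BUCKET_MS as i128)
}

fn main() {
    let malformed = [12u128, 14, 15, 18, 22, 25];
    let plain = [11u128, 13, 16, 19, 21, 24];
    let (m_mean, _, m_p95, _) = summarize(&malformed);
    let (p_mean, _, p_p95, _) = summarize(&plain);
    // Means within one 20 ms bucket => indistinguishable at this granularity.
    assert!(bucket_delta(m_mean, p_mean) <= 1);
    assert!(((m_p95 as i128) - (p_p95 as i128)).abs() / (BUCKET_MS as i128) <= 2);
}
```

Bucketing deliberately discards sub-20 ms jitter, so the assertions tolerate scheduler noise while still catching a gross timing side channel.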


@@ -0,0 +1,209 @@
//! TLS ClientHello size validation tests for proxy anti-censorship security
//! Covers positive, negative, edge, adversarial, and fuzz cases.
//! Ensures the proxy does not reveal itself on probe failures.
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE};
use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
fn test_probe_for_len(len: usize) -> [u8; 5] {
[
0x16,
0x03,
0x03,
((len >> 8) & 0xff) as u8,
(len & 0xff) as u8,
]
}
fn make_test_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_probe_and_assert_masking(len: usize, expect_bad_increment: bool) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = test_probe_for_len(len);
let backend_reply = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = [0u8; 5];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe, "mask backend must receive original probe bytes");
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let bad_before = stats.get_connects_bad();
let upstream_manager = make_test_upstream_manager(stats.clone());
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "203.0.113.123:55123".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&probe).await.unwrap();
let mut observed = vec![0u8; backend_reply.len()];
client_side.read_exact(&mut observed).await.unwrap();
assert_eq!(
observed, backend_reply,
"invalid TLS path must be masked as a real site"
);
drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
accept_task.await.unwrap();
let expected_bad = if expect_bad_increment {
bad_before + 1
} else {
bad_before
};
assert_eq!(
stats.get_connects_bad(),
expected_bad,
"unexpected connects_bad classification for tls_len={len}"
);
}
#[tokio::test]
async fn tls_client_hello_lower_bound_minus_one_is_masked_and_counted_bad() {
run_probe_and_assert_masking(MIN_TLS_CLIENT_HELLO_SIZE - 1, true).await;
}
#[tokio::test]
async fn tls_client_hello_upper_bound_plus_one_is_masked_and_counted_bad() {
run_probe_and_assert_masking(MAX_TLS_PLAINTEXT_SIZE + 1, true).await;
}
#[tokio::test]
async fn tls_client_hello_header_zero_len_is_masked_and_counted_bad() {
run_probe_and_assert_masking(0, true).await;
}
#[test]
fn tls_client_hello_len_bounds_unit_adversarial_sweep() {
let cases = [
(0usize, false),
(1usize, false),
(99usize, false),
(100usize, true),
(101usize, true),
(511usize, true),
(512usize, true),
(MAX_TLS_PLAINTEXT_SIZE - 1, true),
(MAX_TLS_PLAINTEXT_SIZE, true),
(MAX_TLS_PLAINTEXT_SIZE + 1, false),
(u16::MAX as usize, false),
(usize::MAX, false),
];
for (len, expected) in cases {
assert_eq!(
tls_clienthello_len_in_bounds(len),
expected,
"unexpected bounds result for tls_len={len}"
);
}
}
#[test]
fn tls_client_hello_len_bounds_light_fuzz_deterministic_lcg() {
let mut x: u32 = 0xA5A5_5A5A;
for _ in 0..2_048 {
x = x.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
let base = (x as usize) & 0x3fff;
let len = match x & 0x7 {
0 => MIN_TLS_CLIENT_HELLO_SIZE - 1,
1 => MIN_TLS_CLIENT_HELLO_SIZE,
2 => MIN_TLS_CLIENT_HELLO_SIZE + 1,
3 => MAX_TLS_PLAINTEXT_SIZE - 1,
4 => MAX_TLS_PLAINTEXT_SIZE,
5 => MAX_TLS_PLAINTEXT_SIZE + 1,
_ => base,
};
let expect_bad = !(MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&len);
assert_eq!(
tls_clienthello_len_in_bounds(len),
!expect_bad,
"deterministic fuzz mismatch for tls_len={len}"
);
}
}
#[test]
fn tls_client_hello_len_bounds_stress_many_evaluations() {
for _ in 0..100_000 {
assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE));
assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE));
assert!(!tls_clienthello_len_in_bounds(
MIN_TLS_CLIENT_HELLO_SIZE - 1
));
assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1));
}
}
#[tokio::test]
async fn tls_client_hello_masking_integration_repeated_small_probes() {
for _ in 0..25 {
run_probe_and_assert_masking(MIN_TLS_CLIENT_HELLO_SIZE - 1, true).await;
}
}
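The adversarial sweep above pins down the accepted range of the length predicate. A hedged sketch of a predicate consistent with those assertions; the constant values here are assumptions inferred from the sweep (99 is rejected, 100 accepted, 16 384 accepted, 16 385 rejected), and the real `MIN_TLS_CLIENT_HELLO_SIZE` / `MAX_TLS_PLAINTEXT_SIZE` definitions may differ:

```rust
// Assumed constants, inferred from the sweep cases above.
const MIN_TLS_CLIENT_HELLO_SIZE: usize = 100;
const MAX_TLS_PLAINTEXT_SIZE: usize = 16_384; // 2^14, the TLS 1.3 record limit

// Accept only declared lengths inside the inclusive plausible-ClientHello range.
fn tls_clienthello_len_in_bounds(len: usize) -> bool {
    (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&len)
}

fn main() {
    assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1));
    assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE));
    assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE));
    assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1));
}
```

An inclusive-range `contains` check keeps both boundary cases (exactly MIN, exactly MAX) on the accept side, which is what the boundary tests assert.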


@@ -0,0 +1,572 @@
//! Black-hat adversarial tests for truncated in-range TLS ClientHello probes.
//! These tests encode a strict anti-probing expectation: malformed TLS traffic
//! should still be masked as a legitimate website response.
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::sleep;
fn in_range_probe_header() -> [u8; 5] {
[
0x16,
0x03,
0x03,
((MIN_TLS_CLIENT_HELLO_SIZE >> 8) & 0xff) as u8,
(MIN_TLS_CLIENT_HELLO_SIZE & 0xff) as u8,
]
}
fn make_test_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn truncated_in_range_record(actual_body_len: usize) -> Vec<u8> {
let mut out = in_range_probe_header().to_vec();
out.extend(std::iter::repeat_n(0x41, actual_body_len));
out
}
async fn write_fragmented<W: AsyncWriteExt + Unpin>(
writer: &mut W,
bytes: &[u8],
chunks: &[usize],
delay_ms: u64,
) {
let mut offset = 0usize;
for &chunk in chunks {
if offset >= bytes.len() {
break;
}
let end = (offset + chunk).min(bytes.len());
writer.write_all(&bytes[offset..end]).await.unwrap();
offset = end;
if delay_ms > 0 {
sleep(Duration::from_millis(delay_ms)).await;
}
}
if offset < bytes.len() {
writer.write_all(&bytes[offset..]).await.unwrap();
}
}
async fn run_blackhat_generic_fragmented_probe_should_mask(
payload: Vec<u8>,
chunks: &[usize],
delay_ms: u64,
backend_reply: Vec<u8>,
) {
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let mask_addr = mask_listener.local_addr().unwrap();
let probe_header = in_range_probe_header();
let mask_accept_task = tokio::spawn({
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = mask_listener.accept().await.unwrap();
let mut got = [0u8; 5];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe_header);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = make_test_upstream_manager(stats.clone());
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "203.0.113.202:55002".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
write_fragmented(&mut client_side, &payload, chunks, delay_ms).await;
client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; backend_reply.len()];
tokio::time::timeout(
Duration::from_secs(2),
client_side.read_exact(&mut observed),
)
.await
.unwrap()
.unwrap();
assert_eq!(observed, backend_reply);
tokio::time::timeout(Duration::from_secs(2), mask_accept_task)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap()
.unwrap();
}
async fn run_blackhat_client_handler_fragmented_probe_should_mask(
payload: Vec<u8>,
chunks: &[usize],
delay_ms: u64,
backend_reply: Vec<u8>,
) {
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let mask_addr = mask_listener.local_addr().unwrap();
let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let front_addr = front_listener.local_addr().unwrap();
let probe_header = in_range_probe_header();
let mask_accept_task = tokio::spawn({
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = mask_listener.accept().await.unwrap();
let mut got = [0u8; 5];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe_header);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = make_test_upstream_manager(stats.clone());
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let server_task = {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let route_runtime = route_runtime.clone();
let ip_tracker = ip_tracker.clone();
let beobachten = beobachten.clone();
tokio::spawn(async move {
let (stream, peer) = front_listener.accept().await.unwrap();
let real_peer_report = Arc::new(std::sync::Mutex::new(None));
ClientHandler::new(
stream,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
real_peer_report,
)
.run()
.await
})
};
let mut client = TcpStream::connect(front_addr).await.unwrap();
write_fragmented(&mut client, &payload, chunks, delay_ms).await;
client.shutdown().await.unwrap();
let mut observed = vec![0u8; backend_reply.len()];
tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed))
.await
.unwrap()
.unwrap();
assert_eq!(observed, backend_reply);
tokio::time::timeout(Duration::from_secs(2), mask_accept_task)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), server_task)
.await
.unwrap()
.unwrap();
}
#[tokio::test]
async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask() {
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let mask_addr = mask_listener.local_addr().unwrap();
let backend_reply = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec();
let probe = in_range_probe_header();
let mask_accept_task = tokio::spawn({
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = mask_listener.accept().await.unwrap();
let mut got = [0u8; 5];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = make_test_upstream_manager(stats.clone());
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "203.0.113.201:55001".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&probe).await.unwrap();
client_side.shutdown().await.unwrap();
// Security expectation: even malformed in-range TLS should be masked.
// This invariant must hold to avoid probe-distinguishable EOF/timeout behavior.
let mut observed = vec![0u8; backend_reply.len()];
tokio::time::timeout(
Duration::from_secs(2),
client_side.read_exact(&mut observed),
)
.await
.unwrap()
.unwrap();
assert_eq!(observed, backend_reply);
tokio::time::timeout(Duration::from_secs(2), mask_accept_task)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap()
.unwrap();
}
#[tokio::test]
async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask() {
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let mask_addr = mask_listener.local_addr().unwrap();
let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let front_addr = front_listener.local_addr().unwrap();
let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec();
let probe = in_range_probe_header();
let mask_accept_task = tokio::spawn({
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = mask_listener.accept().await.unwrap();
let mut got = [0u8; 5];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = make_test_upstream_manager(stats.clone());
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let server_task = {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let route_runtime = route_runtime.clone();
let ip_tracker = ip_tracker.clone();
let beobachten = beobachten.clone();
tokio::spawn(async move {
let (stream, peer) = front_listener.accept().await.unwrap();
let real_peer_report = Arc::new(std::sync::Mutex::new(None));
ClientHandler::new(
stream,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
real_peer_report,
)
.run()
.await
})
};
let mut client = TcpStream::connect(front_addr).await.unwrap();
client.write_all(&probe).await.unwrap();
client.shutdown().await.unwrap();
// Security expectation: malformed in-range TLS should still be masked.
let mut observed = vec![0u8; backend_reply.len()];
tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed))
.await
.unwrap()
.unwrap();
assert_eq!(observed, backend_reply);
tokio::time::timeout(Duration::from_secs(2), mask_accept_task)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), server_task)
.await
.unwrap()
.unwrap();
}
#[tokio::test]
async fn blackhat_generic_truncated_min_body_1_should_mask() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(1),
&[6],
0,
b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_generic_truncated_min_body_8_should_mask() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(8),
&[13],
0,
b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_generic_truncated_min_body_99_should_mask() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1),
&[5, MIN_TLS_CLIENT_HELLO_SIZE - 1],
0,
b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_generic_fragmented_header_then_close_should_mask() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(0),
&[1, 1, 1, 1, 1],
0,
b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_generic_fragmented_header_plus_partial_body_should_mask() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(5),
&[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
0,
b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_generic_slowloris_fragmented_min_probe_should_mask_but_times_out() {
run_blackhat_generic_fragmented_probe_should_mask(
truncated_in_range_record(1),
&[1, 1, 1, 1, 1, 1],
250,
b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_client_handler_truncated_min_body_1_should_mask() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(1),
&[6],
0,
b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_client_handler_truncated_min_body_8_should_mask() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(8),
&[13],
0,
b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_client_handler_truncated_min_body_99_should_mask() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1),
&[5, MIN_TLS_CLIENT_HELLO_SIZE - 1],
0,
b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_client_handler_fragmented_header_then_close_should_mask() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(0),
&[1, 1, 1, 1, 1],
0,
b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_client_handler_fragmented_header_plus_partial_body_should_mask() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(5),
&[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
0,
b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
#[tokio::test]
async fn blackhat_client_handler_slowloris_fragmented_min_probe_should_mask_but_times_out() {
run_blackhat_client_handler_fragmented_probe_should_mask(
truncated_in_range_record(1),
&[1, 1, 1, 1, 1, 1],
250,
b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(),
)
.await;
}
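All the probes in this file are built from the same 5-byte TLS record header: content type `0x16` (handshake), legacy version `0x0303`, then a big-endian body length. A small round-trip sketch of that layout (the helper names here are illustrative, not the file's own):

```rust
// Build the 5-byte TLS record header the truncated probes start with:
// [content_type, version_major, version_minor, len_hi, len_lo].
fn record_header(len: u16) -> [u8; 5] {
    let [hi, lo] = len.to_be_bytes();
    [0x16, 0x03, 0x03, hi, lo]
}

// Recover the declared body length from bytes 3..5.
fn declared_len(header: &[u8; 5]) -> u16 {
    u16::from_be_bytes([header[3], header[4]])
}

fn main() {
    let hdr = record_header(100);
    assert_eq!(hdr, [0x16, 0x03, 0x03, 0x00, 0x64]);
    // A "truncated in-range" probe declares this length but ships fewer
    // body bytes, which is exactly the malformation the tests exercise.
    assert_eq!(declared_len(&hdr), 100);
}
```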

File diff suppressed because it is too large.


@@ -0,0 +1,37 @@
use super::*;
#[test]
fn wrap_tls_application_record_empty_payload_emits_zero_length_record() {
let record = wrap_tls_application_record(&[]);
assert_eq!(record.len(), 5);
assert_eq!(record[0], TLS_RECORD_APPLICATION);
assert_eq!(&record[1..3], &TLS_VERSION);
assert_eq!(&record[3..5], &0u16.to_be_bytes());
}
#[test]
fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation() {
let total = (u16::MAX as usize) + 37;
let payload = vec![0xA5u8; total];
let record = wrap_tls_application_record(&payload);
let mut offset = 0usize;
let mut recovered = Vec::with_capacity(total);
let mut frames = 0usize;
while offset + 5 <= record.len() {
assert_eq!(record[offset], TLS_RECORD_APPLICATION);
assert_eq!(&record[offset + 1..offset + 3], &TLS_VERSION);
let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize;
let body_start = offset + 5;
let body_end = body_start + len;
assert!(body_end <= record.len(), "declared TLS record length must be in-bounds");
recovered.extend_from_slice(&record[body_start..body_end]);
offset = body_end;
frames += 1;
}
assert_eq!(offset, record.len(), "record parser must consume exact output size");
assert_eq!(frames, 2, "oversized payload should split into exactly two records");
assert_eq!(recovered, payload, "chunked records must preserve full payload");
}
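The chunking test above decodes the framed output and expects exactly two records for a `u16::MAX + 37` byte payload. A hedged reimplementation consistent with those assertions; the constant values (`0x17` for application data, TLS 1.2 version bytes) are assumptions about `TLS_RECORD_APPLICATION` and `TLS_VERSION`, not taken from the source:

```rust
// Assumed record-layer constants.
const TLS_RECORD_APPLICATION: u8 = 0x17;
const TLS_VERSION: [u8; 2] = [0x03, 0x03];

// Frame a payload as one or more TLS application-data records. A record
// body length field is u16, so larger payloads are split across records.
fn wrap_tls_application_record(payload: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(payload.len() + 5);
    if payload.is_empty() {
        // Emit a single zero-length record rather than no output at all.
        out.extend_from_slice(&[TLS_RECORD_APPLICATION, 0x03, 0x03, 0, 0]);
        return out;
    }
    for chunk in payload.chunks(u16::MAX as usize) {
        out.push(TLS_RECORD_APPLICATION);
        out.extend_from_slice(&TLS_VERSION);
        out.extend_from_slice(&(chunk.len() as u16).to_be_bytes());
        out.extend_from_slice(chunk);
    }
    out
}

fn main() {
    assert_eq!(wrap_tls_application_record(&[]).len(), 5);
    // u16::MAX + 37 bytes -> two records, i.e. 10 bytes of header overhead.
    let payload = vec![0xA5u8; (u16::MAX as usize) + 37];
    assert_eq!(wrap_tls_application_record(&payload).len(), payload.len() + 10);
}
```

`chunks(u16::MAX as usize)` naturally yields a full 65 535-byte record followed by a 37-byte tail record, matching the "exactly two frames, no truncation" expectation.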


@@ -0,0 +1,56 @@
use super::*;
use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4, TG_DATACENTERS_V6};
use std::net::SocketAddr;
#[test]
fn business_scope_hint_accepts_exact_boundary_length() {
let value = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN));
assert_eq!(
validated_scope_hint(&value),
Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
);
}
#[test]
fn business_scope_hint_rejects_missing_prefix_even_when_charset_is_valid() {
assert_eq!(validated_scope_hint("alpha-01"), None);
}
#[test]
fn business_known_dc_uses_ipv4_table_by_default() {
let cfg = ProxyConfig::default();
let resolved = get_dc_addr_static(2, &cfg).expect("known dc must resolve");
let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT);
assert_eq!(resolved, expected);
}
#[test]
fn business_negative_dc_maps_by_absolute_value() {
let cfg = ProxyConfig::default();
let resolved =
get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value");
let expected = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT);
assert_eq!(resolved, expected);
}
#[test]
fn business_known_dc_uses_ipv6_table_when_preferred_and_enabled() {
let mut cfg = ProxyConfig::default();
cfg.network.prefer = 6;
cfg.network.ipv6 = Some(true);
let resolved = get_dc_addr_static(1, &cfg).expect("known dc must resolve on ipv6 path");
let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT);
assert_eq!(resolved, expected);
}
#[test]
fn business_unknown_dc_uses_configured_default_dc_when_in_range() {
let mut cfg = ProxyConfig::default();
cfg.default_dc = Some(4);
let resolved =
get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default");
let expected = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT);
assert_eq!(resolved, expected);
}
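The DC-resolution tests above imply a simple index mapping: DC ids are 1-based into the datacenter table, negative ids map by absolute value, and out-of-range ids fall back to the configured default. A sketch of that mapping as a pure function; `dc_table_index` is a hypothetical helper inferred from the assertions, not the actual `get_dc_addr_static` logic:

```rust
// Map a signed DC id onto a 0-based table index, per the behavior the
// business tests assert: 1-based, abs-value for negatives, default fallback.
fn dc_table_index(dc_idx: i16, table_len: usize, default_dc: usize) -> usize {
    let idx = dc_idx.unsigned_abs() as usize;
    if (1..=table_len).contains(&idx) {
        idx - 1
    } else {
        default_dc - 1
    }
}

fn main() {
    // A 5-entry table like TG_DATACENTERS_V4, with default_dc = 4.
    assert_eq!(dc_table_index(2, 5, 4), 1);      // known dc -> table[1]
    assert_eq!(dc_table_index(-3, 5, 4), 2);     // negative -> abs value
    assert_eq!(dc_table_index(29_999, 5, 4), 3); // unknown -> default dc
}
```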


@@ -0,0 +1,100 @@
use super::*;
use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4};
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Mutex;
#[test]
fn common_invalid_override_entries_fallback_to_static_table() {
let mut cfg = ProxyConfig::default();
cfg.dc_overrides.insert(
"2".to_string(),
vec!["bad-address".to_string(), "still-bad".to_string()],
);
let resolved =
get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve");
let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT);
assert_eq!(resolved, expected);
}
#[test]
fn common_prefer_v6_with_only_ipv4_override_uses_override_instead_of_ignoring_it() {
let mut cfg = ProxyConfig::default();
cfg.network.prefer = 6;
cfg.network.ipv6 = Some(true);
cfg.dc_overrides
.insert("3".to_string(), vec!["203.0.113.203:443".to_string()]);
let resolved =
get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists");
assert_eq!(resolved, "203.0.113.203:443".parse::<SocketAddr>().unwrap());
}
#[test]
fn common_scope_hint_rejects_unicode_lookalike_characters() {
assert_eq!(validated_scope_hint("scope_аlpha"), None);
assert_eq!(validated_scope_hint("scope_Αlpha"), None);
}
#[cfg(unix)]
#[test]
fn common_anchored_open_rejects_nul_filename() {
use std::os::unix::ffi::OsStringExt;
let parent = std::env::current_dir()
.expect("cwd must be available")
.join("target")
.join(format!("telemt-direct-relay-nul-{}", std::process::id()));
std::fs::create_dir_all(&parent).expect("parent directory must be creatable");
let path = SanitizedUnknownDcLogPath {
resolved_path: parent.join("placeholder.log"),
allowed_parent: parent,
file_name: std::ffi::OsString::from_vec(vec![b'a', 0, b'b']),
};
let err = open_unknown_dc_log_append_anchored(&path)
.expect_err("anchored open must fail on NUL in filename");
assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
}
#[cfg(unix)]
#[test]
fn common_anchored_open_creates_owner_only_file_permissions() {
use std::os::unix::fs::PermissionsExt;
let parent = std::env::current_dir()
.expect("cwd must be available")
.join("target")
.join(format!("telemt-direct-relay-perm-{}", std::process::id()));
std::fs::create_dir_all(&parent).expect("parent directory must be creatable");
let sanitized = SanitizedUnknownDcLogPath {
resolved_path: parent.join("unknown-dc.log"),
allowed_parent: parent.clone(),
file_name: std::ffi::OsString::from("unknown-dc.log"),
};
let mut file = open_unknown_dc_log_append_anchored(&sanitized)
.expect("anchored open must create regular file");
use std::io::Write;
writeln!(file, "dc_idx=1").expect("write must succeed");
let mode = std::fs::metadata(parent.join("unknown-dc.log"))
.expect("metadata must be readable")
.permissions()
.mode()
& 0o777;
assert_eq!(mode, 0o600);
}
#[test]
fn common_duplicate_dc_attempts_do_not_consume_unique_slots() {
let set = Mutex::new(HashSet::new());
assert!(should_log_unknown_dc_with_set(&set, 100));
assert!(!should_log_unknown_dc_with_set(&set, 100));
assert!(should_log_unknown_dc_with_set(&set, 101));
assert_eq!(set.lock().expect("set lock must be available").len(), 2);
}
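The last test above checks that duplicate unknown-DC ids do not consume extra log slots. A plausible implementation consistent with those assertions, built on the fact that `HashSet::insert` returns `true` only for a value's first occurrence (this body is a sketch, not the file's actual helper):

```rust
use std::collections::HashSet;
use std::sync::Mutex;

// Log-once dedup: returns true exactly once per distinct DC id.
// The Mutex makes the check-and-insert atomic across threads.
fn should_log_unknown_dc_with_set(seen: &Mutex<HashSet<i32>>, dc: i32) -> bool {
    seen.lock().expect("set lock must be available").insert(dc)
}

fn main() {
    let seen = Mutex::new(HashSet::new());
    assert!(should_log_unknown_dc_with_set(&seen, 100));
    assert!(!should_log_unknown_dc_with_set(&seen, 100)); // duplicate: no slot
    assert!(should_log_unknown_dc_with_set(&seen, 101));
    assert_eq!(seen.lock().unwrap().len(), 2);
}
```

Holding the lock across the single `insert` call is what makes the concurrency stress test (128 threads, one winner) pass: the first inserter wins and every other thread observes `false`.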

File diff suppressed because it is too large.


@@ -0,0 +1,200 @@
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
fn nonempty_line_count(text: &str) -> usize {
text.lines().filter(|line| !line.trim().is_empty()).count()
}
#[test]
fn subtle_stress_single_unknown_dc_under_concurrency_logs_once() {
let _guard = unknown_dc_test_lock()
.lock()
.expect("unknown dc test lock must be available");
clear_unknown_dc_log_cache_for_testing();
let winners = Arc::new(AtomicUsize::new(0));
let mut workers = Vec::new();
for _ in 0..128 {
let winners = Arc::clone(&winners);
workers.push(std::thread::spawn(move || {
if should_log_unknown_dc(31_333) {
winners.fetch_add(1, Ordering::Relaxed);
}
}));
}
for worker in workers {
worker.join().expect("worker must not panic");
}
assert_eq!(winners.load(Ordering::Relaxed), 1);
}
#[test]
fn subtle_light_fuzz_scope_hint_matches_oracle() {
fn oracle(input: &str) -> bool {
let Some(rest) = input.strip_prefix("scope_") else {
return false;
};
!rest.is_empty()
&& rest.len() <= MAX_SCOPE_HINT_LEN
&& rest.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
}
let mut state: u64 = 0xC0FF_EE11_D15C_AFE5;
for _ in 0..4_096 {
state ^= state << 7;
state ^= state >> 9;
state ^= state << 8;
let len = (state as usize % 72) + 1;
let mut s = String::with_capacity(len + 6);
if (state & 1) == 0 {
s.push_str("scope_");
} else {
s.push_str("user_");
}
for idx in 0..len {
let v = ((state >> ((idx % 8) * 8)) & 0xff) as u8;
let ch = match v % 6 {
0 => (b'a' + (v % 26)) as char,
1 => (b'A' + (v % 26)) as char,
2 => (b'0' + (v % 10)) as char,
3 => '-',
4 => '_',
_ => '.',
};
s.push(ch);
}
let got = validated_scope_hint(&s).is_some();
assert_eq!(got, oracle(&s), "mismatch for input: {s}");
}
}
#[test]
fn subtle_light_fuzz_dc_resolution_never_panics_and_preserves_port() {
let mut state: u64 = 0x1234_5678_9ABC_DEF0;
for _ in 0..2_048 {
state ^= state << 13;
state ^= state >> 7;
state ^= state << 17;
let mut cfg = ProxyConfig::default();
cfg.network.prefer = if (state & 1) == 0 { 4 } else { 6 };
cfg.network.ipv6 = Some((state & 2) != 0);
cfg.default_dc = Some(((state >> 8) as u8).max(1));
let dc_idx = (state as i16).wrapping_sub(16_384);
let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail");
assert_eq!(
resolved.port(),
crate::protocol::constants::TG_DATACENTER_PORT
);
let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true);
assert_eq!(resolved.is_ipv6(), expect_v6);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn subtle_integration_parallel_same_dc_logs_one_line() {
let _guard = unknown_dc_test_lock()
.lock()
.expect("unknown dc test lock must be available");
clear_unknown_dc_log_cache_for_testing();
let rel_dir = format!("target/telemt-direct-relay-same-{}", std::process::id());
let rel_file = format!("{rel_dir}/unknown-dc.log");
let abs_dir = std::env::current_dir()
.expect("cwd must be available")
.join(&rel_dir);
std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable");
let abs_file = abs_dir.join("unknown-dc.log");
let _ = std::fs::remove_file(&abs_file);
let mut cfg = ProxyConfig::default();
cfg.general.unknown_dc_file_log_enabled = true;
cfg.general.unknown_dc_log_path = Some(rel_file);
let cfg = Arc::new(cfg);
let mut tasks = Vec::new();
for _ in 0..32 {
let cfg = Arc::clone(&cfg);
tasks.push(tokio::spawn(async move {
let _ = get_dc_addr_static(31_777, cfg.as_ref());
}));
}
for task in tasks {
task.await.expect("task must not panic");
}
for _ in 0..60 {
if let Ok(content) = std::fs::read_to_string(&abs_file)
&& nonempty_line_count(&content) == 1
{
return;
}
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
let content = std::fs::read_to_string(&abs_file).unwrap_or_default();
assert_eq!(nonempty_line_count(&content), 1);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn subtle_integration_parallel_unique_dcs_log_unique_lines() {
let _guard = unknown_dc_test_lock()
.lock()
.expect("unknown dc test lock must be available");
clear_unknown_dc_log_cache_for_testing();
let rel_dir = format!("target/telemt-direct-relay-unique-{}", std::process::id());
let rel_file = format!("{rel_dir}/unknown-dc.log");
let abs_dir = std::env::current_dir()
.expect("cwd must be available")
.join(&rel_dir);
std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable");
let abs_file = abs_dir.join("unknown-dc.log");
let _ = std::fs::remove_file(&abs_file);
let mut cfg = ProxyConfig::default();
cfg.general.unknown_dc_file_log_enabled = true;
cfg.general.unknown_dc_log_path = Some(rel_file);
let cfg = Arc::new(cfg);
let dcs = [
31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908,
];
let mut tasks = Vec::new();
for dc in dcs {
let cfg = Arc::clone(&cfg);
tasks.push(tokio::spawn(async move {
let _ = get_dc_addr_static(dc, cfg.as_ref());
}));
}
for task in tasks {
task.await.expect("task must not panic");
}
for _ in 0..80 {
if let Ok(content) = std::fs::read_to_string(&abs_file)
&& nonempty_line_count(&content) >= 8
{
return;
}
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
let content = std::fs::read_to_string(&abs_file).unwrap_or_default();
assert!(
nonempty_line_count(&content) >= 8,
"expected at least one line per unique dc, content: {content}"
);
}
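
The fuzz tests in this file drive input generation with a small inline xorshift-style PRNG (shift constants 7/9/8, as in `subtle_light_fuzz_scope_hint_matches_oracle`). A self-contained sketch of that step function, factored out for illustration (`xorshift_step` is a hypothetical name):

```rust
// One step of the xorshift-style generator used inline by the fuzz tests
// above (shifts 7, 9, 8). Deterministic and cheap; not cryptographic.
fn xorshift_step(mut state: u64) -> u64 {
    state ^= state << 7;
    state ^= state >> 9;
    state ^= state << 8;
    state
}

fn main() {
    // Deterministic: the same seed always yields the same sequence.
    assert_eq!(xorshift_step(5), xorshift_step(5));
    // A zero state stays zero, so fuzz loops must seed with a nonzero value.
    assert_eq!(xorshift_step(0), 0);
    let mut state: u64 = 0xC0FF_EE11_D15C_AFE5; // seed taken from the tests
    for _ in 0..4_096 {
        state = xorshift_step(state);
        assert_ne!(state, 0, "nonzero seed must not collapse to zero");
    }
    println!("ok");
}
```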


@@ -0,0 +1,563 @@
use super::*;
use crate::crypto::sha256;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::time::{Duration, Instant};
fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg
}
// ------------------------------------------------------------------
// Mutational Bit-Flipping Tests (OWASP ASVS 5.1.4)
// ------------------------------------------------------------------
#[tokio::test]
async fn mtproto_handshake_bit_flip_anywhere_rejected() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "11223344556677889900aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
// Baseline check
let res = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
match res {
HandshakeResult::Success(_) => {}
_ => panic!("Baseline failed: expected Success"),
}
// Flip bits in the encrypted part (beyond the key material)
for byte_pos in SKIP_LEN..HANDSHAKE_LEN {
let mut h = base;
h[byte_pos] ^= 0x01; // Flip 1 bit
let res = handle_mtproto_handshake(
&h,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"Flip at byte {byte_pos} bit 0 must be rejected"
);
}
}
// ------------------------------------------------------------------
// Adversarial Probing / Timing Neutrality (OWASP ASVS 5.1.7)
// ------------------------------------------------------------------
#[tokio::test]
async fn mtproto_handshake_timing_neutrality_mocked() {
let secret_hex = "00112233445566778899aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap();
const ITER: usize = 50;
let mut start = Instant::now();
for _ in 0..ITER {
let _ = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
}
let duration_success = start.elapsed();
start = Instant::now();
for i in 0..ITER {
let mut h = base;
h[SKIP_LEN + (i % 48)] ^= 0xFF;
let _ = handle_mtproto_handshake(
&h,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
}
let duration_fail = start.elapsed();
let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64)
.abs()
/ ITER as f64;
// Threshold (loose for CI)
assert!(
avg_diff_ms < 100.0,
"Timing difference too large: {} ms/iter",
avg_diff_ms
);
}
// ------------------------------------------------------------------
// Stress Tests (OWASP ASVS 5.1.6)
// ------------------------------------------------------------------
#[tokio::test]
async fn auth_probe_throttle_saturation_stress() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let now = Instant::now();
// Record enough failures for one IP to trigger backoff
let target_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
auth_probe_record_failure(target_ip, now);
}
assert!(auth_probe_is_throttled(target_ip, now));
// Stress test with many unique IPs
for i in 0..500u32 {
let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, (i % 256) as u8));
auth_probe_record_failure(ip, now);
}
let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0);
assert!(
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}"
);
}
#[tokio::test]
async fn mtproto_handshake_abridged_prefix_rejected() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
handshake[0] = 0xef; // Abridged prefix
let config = ProxyConfig::default();
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap();
let res = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
// MTProxy stops immediately on 0xef
assert!(matches!(res, HandshakeResult::BadClient { .. }));
}
#[tokio::test]
async fn mtproto_handshake_preferred_user_mismatch_continues() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret1_hex = "11111111111111111111111111111111";
let secret2_hex = "22222222222222222222222222222222";
let base = make_valid_mtproto_handshake(secret2_hex, ProtoTag::Secure, 1);
let mut config = ProxyConfig::default();
config
.access
.users
.insert("user1".to_string(), secret1_hex.to_string());
config
.access
.users
.insert("user2".to_string(), secret2_hex.to_string());
config.access.ignore_time_skew = true;
config.general.modes.secure = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.4:12345".parse().unwrap();
// Even if we prefer user1, if user2 matches, it should succeed.
let res = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
Some("user1"),
)
.await;
if let HandshakeResult::Success((_, _, success)) = res {
assert_eq!(success.user, "user2");
} else {
panic!("Handshake failed even though user2 matched");
}
}
#[tokio::test]
async fn mtproto_handshake_concurrent_flood_stability() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "00112233445566778899aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
let mut config = test_config_with_secret_hex(secret_hex);
config.access.ignore_time_skew = true;
let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60)));
let config = Arc::new(config);
let mut tasks = Vec::new();
for i in 0..50 {
let base = base;
let config = Arc::clone(&config);
let replay_checker = Arc::clone(&replay_checker);
let peer: SocketAddr = format!("192.0.2.{}:12345", (i % 254) + 1).parse().unwrap();
tasks.push(tokio::spawn(async move {
let res = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
matches!(res, HandshakeResult::Success(_))
}));
}
// We don't necessarily care if they all succeed (some might fail due to replay if they hit the same chunk),
// but the system must not panic or hang.
for task in tasks {
let _ = task.await.unwrap();
}
}
#[tokio::test]
async fn mtproto_replay_is_rejected_across_distinct_peers() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "0123456789abcdeffedcba9876543210";
let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let first_peer: SocketAddr = "198.51.100.10:41001".parse().unwrap();
let second_peer: SocketAddr = "198.51.100.11:41002".parse().unwrap();
let first = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
first_peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(first, HandshakeResult::Success(_)));
let replay = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
second_peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(replay, HandshakeResult::BadClient { .. }));
}
#[tokio::test]
async fn mtproto_blackhat_mutation_corpus_never_panics_and_stays_fail_closed() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "89abcdef012345670123456789abcdef";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(8192, Duration::from_secs(60));
for i in 0..512usize {
let mut mutated = base;
let pos = (SKIP_LEN + (i * 31) % (HANDSHAKE_LEN - SKIP_LEN)).min(HANDSHAKE_LEN - 1);
mutated[pos] ^= ((i as u8) | 1).rotate_left((i % 8) as u32);
let peer: SocketAddr = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(198, 18, (i / 254) as u8, (i % 254 + 1) as u8)),
42000 + (i % 1000) as u16,
);
let res = tokio::time::timeout(
Duration::from_millis(250),
handle_mtproto_handshake(
&mutated,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("fuzzed mutation must complete in bounded time");
assert!(
matches!(
res,
HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)
),
"mutation corpus must stay within explicit handshake outcomes"
);
}
}
#[tokio::test]
async fn auth_probe_success_clears_throttled_peer_state() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let target_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 90));
let now = Instant::now();
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
auth_probe_record_failure(target_ip, now);
}
assert!(auth_probe_is_throttled(target_ip, now));
auth_probe_record_success(target_ip);
assert!(
!auth_probe_is_throttled(target_ip, now + Duration::from_millis(1)),
"successful auth must clear per-peer throttle state"
);
}
#[tokio::test]
async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "00112233445566778899aabbccddeeff";
let mut invalid = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
invalid[SKIP_LEN + 3] ^= 0xff;
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(64, Duration::from_secs(60));
for i in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 512) {
let peer: SocketAddr = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(
10,
(i / 65535) as u8,
((i / 255) % 255) as u8,
(i % 255 + 1) as u8,
)),
43000 + (i % 20000) as u16,
);
let res = handle_mtproto_handshake(
&invalid,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }));
}
let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0);
assert!(
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"probe map must remain bounded under invalid storm: {tracked}"
);
}
#[tokio::test]
async fn mtproto_property_style_multi_bit_mutations_fail_closed_or_auth_only() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "f0e1d2c3b4a5968778695a4b3c2d1e0f";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(10_000, Duration::from_secs(60));
let mut seed: u64 = 0xC0FF_EE12_3456_789A;
for i in 0..2_048usize {
let mut mutated = base;
for _ in 0..4 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let idx = SKIP_LEN + (seed as usize % (HANDSHAKE_LEN - SKIP_LEN));
mutated[idx] ^= ((seed >> 11) as u8).wrapping_add(1);
}
let peer: SocketAddr = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(10, 123, (i / 254) as u8, (i % 254 + 1) as u8)),
45000 + (i % 2000) as u16,
);
let outcome = tokio::time::timeout(
Duration::from_millis(250),
handle_mtproto_handshake(
&mutated,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("mutation iteration must complete in bounded time");
assert!(
matches!(
outcome,
HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)
),
"mutations must remain fail-closed/auth-only"
);
}
}
#[tokio::test]
#[ignore = "heavy soak; run manually"]
async fn mtproto_blackhat_20k_mutation_soak_never_panics() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(50_000, Duration::from_secs(120));
let mut seed: u64 = 0xA5A5_5A5A_DEAD_BEEF;
for i in 0..20_000usize {
let mut mutated = base;
for _ in 0..3 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let idx = SKIP_LEN + (seed as usize % (HANDSHAKE_LEN - SKIP_LEN));
mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1);
}
let peer: SocketAddr = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(172, 31, (i / 254) as u8, (i % 254 + 1) as u8)),
47000 + (i % 15000) as u16,
);
let _ = tokio::time::timeout(
Duration::from_millis(250),
handle_mtproto_handshake(
&mutated,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("soak mutation must complete in bounded time");
}
}


@@ -0,0 +1,187 @@
use super::*;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn positive_preauth_throttle_activates_after_failure_threshold() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 20));
let now = Instant::now();
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
auth_probe_record_failure(ip, now);
}
assert!(
auth_probe_is_throttled(ip, now),
"peer must be throttled once fail streak reaches threshold"
);
}
#[test]
fn negative_unrelated_peer_remains_unthrottled() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let attacker = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 12));
let benign = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 13));
let now = Instant::now();
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
auth_probe_record_failure(attacker, now);
}
assert!(auth_probe_is_throttled(attacker, now));
assert!(
!auth_probe_is_throttled(benign, now),
"throttle state must stay scoped to normalized peer key"
);
}
#[test]
fn edge_expired_entry_is_pruned_and_no_longer_throttled() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 41));
let base = Instant::now();
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
auth_probe_record_failure(ip, base);
}
let expired_at = base + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1);
assert!(
!auth_probe_is_throttled(ip, expired_at),
"expired entries must not keep throttling peers"
);
let state = auth_probe_state_map();
assert!(
state.get(&normalize_auth_probe_ip(ip)).is_none(),
"expired lookup should prune stale state"
);
}
#[test]
fn adversarial_saturation_grace_requires_extra_failures_before_preauth_throttle() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let ip = IpAddr::V4(Ipv4Addr::new(198, 18, 0, 7));
let now = Instant::now();
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
auth_probe_record_failure(ip, now);
}
auth_probe_note_saturation(now);
assert!(
!auth_probe_should_apply_preauth_throttle(ip, now),
"during global saturation, peer must receive configured grace window"
);
for _ in 0..AUTH_PROBE_SATURATION_GRACE_FAILS {
auth_probe_record_failure(ip, now + Duration::from_millis(1));
}
assert!(
auth_probe_should_apply_preauth_throttle(ip, now + Duration::from_millis(1)),
"after grace failures are exhausted, preauth throttle must activate"
);
}
#[test]
fn integration_over_cap_insertion_keeps_probe_map_bounded() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let now = Instant::now();
for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 1024) {
let ip = IpAddr::V4(Ipv4Addr::new(
10,
((idx / 65_536) % 256) as u8,
((idx / 256) % 256) as u8,
(idx % 256) as u8,
));
auth_probe_record_failure(ip, now);
}
let tracked = auth_probe_state_map().len();
assert!(
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"probe map must remain hard bounded under insertion storm"
);
}
#[test]
fn light_fuzz_randomized_failures_preserve_cap_and_nonzero_streaks() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let mut seed = 0x4D53_5854_6F66_6175u64;
let now = Instant::now();
for _ in 0..8192 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
auth_probe_record_failure(ip, now + Duration::from_millis((seed & 0x3f) as u64));
}
let state = auth_probe_state_map();
assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES);
for entry in state.iter() {
assert!(entry.value().fail_streak > 0);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_failure_flood_keeps_state_hard_capped() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let start = Instant::now();
let mut tasks = Vec::new();
for worker in 0..8u8 {
tasks.push(tokio::spawn(async move {
for i in 0..4096u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
172,
worker,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
));
auth_probe_record_failure(ip, start + Duration::from_millis((i % 4) as u64));
}
}));
}
for task in tasks {
task.await.expect("stress worker must not panic");
}
let tracked = auth_probe_state_map().len();
assert!(
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"parallel failure flood must not exceed cap"
);
let probe = IpAddr::V4(Ipv4Addr::new(172, 3, 4, 5));
let _ = auth_probe_is_throttled(probe, start + Duration::from_millis(2));
}
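
Several tests above assert that the probe map never grows past `AUTH_PROBE_TRACK_MAX_ENTRIES` under insertion storms. A minimal sketch of that fail-closed, hard-capped insert policy (hypothetical `TRACK_MAX_ENTRIES` and `record_failure` standing in for the real constants and state):

```rust
use std::collections::HashMap;

// Hypothetical cap mirroring AUTH_PROBE_TRACK_MAX_ENTRIES; the real constant
// lives in the module under test.
const TRACK_MAX_ENTRIES: usize = 4;

// Existing keys may always update their streak, but new keys are dropped
// once the map is full: the map stays hard-bounded under a flood.
fn record_failure(map: &mut HashMap<u32, u32>, key: u32) {
    if let Some(streak) = map.get_mut(&key) {
        *streak += 1;
    } else if map.len() < TRACK_MAX_ENTRIES {
        map.insert(key, 1);
    }
    // else: refuse the new entry rather than grow past the cap
}

fn main() {
    let mut map = HashMap::new();
    for key in 0..100u32 {
        record_failure(&mut map, key); // insertion storm
    }
    assert!(map.len() <= TRACK_MAX_ENTRIES);
    record_failure(&mut map, 0); // tracked key still accumulates
    assert_eq!(map[&0], 2);
    println!("ok");
}
```

The real implementation may evict or prune instead of refusing; the invariant the tests check is only the upper bound on tracked entries.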


@@ -0,0 +1,278 @@
use super::*;
use crate::config::ProxyConfig;
use crate::crypto::AesCtr;
use crate::crypto::sha256;
use crate::protocol::constants::ProtoTag;
use crate::stats::ReplayChecker;
use std::net::SocketAddr;
use std::sync::MutexGuard;
use tokio::time::{Duration as TokioDuration, timeout};
fn make_mtproto_handshake_with_proto_bytes(
secret_hex: &str,
proto_bytes: [u8; 4],
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_bytes);
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx)
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg
}
fn auth_probe_test_guard() -> MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[tokio::test]
async fn mtproto_handshake_duplicate_digest_is_replayed_on_second_attempt() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "11223344556677889900aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
let first = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(first, HandshakeResult::Success(_)));
let second = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(second, HandshakeResult::BadClient { .. }));
clear_auth_probe_state_for_testing();
}
#[tokio::test]
async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "00112233445566778899aabbccddeeff";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap();
let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new();
corpus.push(make_mtproto_handshake_with_proto_bytes(
secret_hex,
[0x00, 0x00, 0x00, 0x00],
1,
));
corpus.push(make_mtproto_handshake_with_proto_bytes(
secret_hex,
[0xff, 0xff, 0xff, 0xff],
1,
));
corpus.push(make_valid_mtproto_handshake(
"ffeeddccbbaa99887766554433221100",
ProtoTag::Secure,
1,
));
let mut seed = 0xF0F0_F00D_BAAD_CAFEu64;
for _ in 0..32 {
let mut mutated = base;
for _ in 0..4 {
seed = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN));
mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, input) in corpus.into_iter().enumerate() {
let result = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&input,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("fuzzed handshake must complete in time");
assert!(
matches!(result, HandshakeResult::BadClient { .. }),
"corpus item {idx} must fail closed"
);
}
clear_auth_probe_state_for_testing();
}
#[tokio::test]
async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_rejected() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "99887766554433221100ffeeddccbbaa";
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 4);
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(256, TokioDuration::from_secs(60));
let peer: SocketAddr = "192.0.2.44:45444".parse().unwrap();
let first = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("base handshake must not hang");
assert!(matches!(first, HandshakeResult::Success(_)));
let replay = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("duplicate handshake must not hang");
assert!(matches!(replay, HandshakeResult::BadClient { .. }));
let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new();
let mut prekey_flip = base;
prekey_flip[SKIP_LEN] ^= 0x80;
corpus.push(prekey_flip);
let mut iv_flip = base;
iv_flip[SKIP_LEN + PREKEY_LEN] ^= 0x01;
corpus.push(iv_flip);
let mut tail_flip = base;
tail_flip[SKIP_LEN + PREKEY_LEN + IV_LEN - 1] ^= 0x40;
corpus.push(tail_flip);
let mut seed = 0xBADC_0FFE_EE11_4242u64;
for _ in 0..24 {
let mut mutated = base;
for _ in 0..3 {
seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN));
mutated[idx] ^= ((seed >> 16) as u8).wrapping_add(1);
}
corpus.push(mutated);
}
for (idx, input) in corpus.iter().enumerate() {
let result = timeout(
TokioDuration::from_secs(1),
handle_mtproto_handshake(
input,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
),
)
.await
.expect("fuzzed handshake must complete in time");
assert!(
matches!(result, HandshakeResult::BadClient { .. }),
"mixed corpus item {idx} must fail closed"
);
}
clear_auth_probe_state_for_testing();
}


@@ -0,0 +1,71 @@
use super::*;
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn poison_saturation_mutex() {
let saturation = auth_probe_saturation_state();
let poison_thread = std::thread::spawn(move || {
let _guard = saturation
.lock()
.expect("saturation mutex must be lockable for poison setup");
panic!("intentional poison for saturation mutex resilience test");
});
let _ = poison_thread.join();
}
#[test]
fn auth_probe_saturation_note_recovers_after_mutex_poison() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
poison_saturation_mutex();
let now = Instant::now();
auth_probe_note_saturation(now);
assert!(
auth_probe_saturation_is_throttled_at_for_testing(now),
"poisoned saturation mutex must not disable saturation throttling"
);
}
#[test]
fn auth_probe_saturation_check_recovers_after_mutex_poison() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
poison_saturation_mutex();
{
let mut guard = auth_probe_saturation_state_lock();
*guard = Some(AuthProbeSaturationState {
fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
blocked_until: Instant::now() + Duration::from_millis(10),
last_seen: Instant::now(),
});
}
assert!(
auth_probe_saturation_is_throttled_for_testing(),
"throttle check must recover poisoned saturation mutex and stay fail-closed"
);
}
#[test]
fn clear_auth_probe_state_clears_saturation_even_if_poisoned() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
poison_saturation_mutex();
auth_probe_note_saturation(Instant::now());
assert!(auth_probe_saturation_is_throttled_for_testing());
clear_auth_probe_state_for_testing();
assert!(
!auth_probe_saturation_is_throttled_for_testing(),
"clear helper must clear saturation state even after poison"
);
}
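
The poison-resilience tests above hinge on one standard-library pattern: `lock().unwrap_or_else(|p| p.into_inner())`, which recovers the guard from a poisoned `Mutex` instead of propagating the panic. A self-contained sketch (hypothetical `lock_resilient` helper):

```rust
use std::sync::{Mutex, MutexGuard};
use std::thread;

// Take the lock even if a previous holder panicked: a PoisonError still
// carries the guard, and into_inner hands it back.
fn lock_resilient(m: &Mutex<u32>) -> MutexGuard<'_, u32> {
    m.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
}

static STATE: Mutex<u32> = Mutex::new(0);

fn main() {
    // Poison the mutex: a thread panics while holding the guard.
    let _ = thread::spawn(|| {
        let _guard = STATE.lock().expect("lockable before poison");
        panic!("intentional poison");
    })
    .join();
    assert!(STATE.lock().is_err(), "plain lock() now reports poison");
    // The resilient helper still yields a usable guard.
    let mut guard = lock_resilient(&STATE);
    *guard += 1;
    assert_eq!(*guard, 1);
    println!("ok");
}
```

Recovering the guard is appropriate here because the protected state (a fail counter) stays valid even if an update was interrupted; fail-closed throttling must keep working after a panic.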

File diff suppressed because it is too large


@@ -0,0 +1,577 @@
use super::*;
use std::collections::BTreeSet;
use tokio::io::duplex;
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};
#[derive(Clone, Copy)]
enum PathClass {
ConnectFail,
ConnectSuccess,
SlowBackend,
}
fn mean_ms(samples: &[u128]) -> f64 {
if samples.is_empty() {
return 0.0;
}
let sum: u128 = samples.iter().copied().sum();
sum as f64 / samples.len() as f64
}
fn percentile_ms(mut values: Vec<u128>, p_num: usize, p_den: usize) -> u128 {
values.sort_unstable();
if values.is_empty() {
return 0;
}
let idx = ((values.len() - 1) * p_num) / p_den;
values[idx]
}
fn bucketize_ms(values: &[u128], bucket_ms: u128) -> Vec<u128> {
values.iter().map(|v| *v / bucket_ms).collect()
}
fn best_threshold_accuracy_u128(a: &[u128], b: &[u128]) -> f64 {
let min_v = *a.iter().chain(b.iter()).min().unwrap();
let max_v = *a.iter().chain(b.iter()).max().unwrap();
let mut best = 0.0f64;
for t in min_v..=max_v {
let correct_a = a.iter().filter(|&&x| x <= t).count();
let correct_b = b.iter().filter(|&&x| x > t).count();
let acc = (correct_a + correct_b) as f64 / (a.len() + b.len()) as f64;
if acc > best {
best = acc;
}
}
best
}
fn spread_u128(values: &[u128]) -> u128 {
if values.is_empty() {
return 0;
}
let min_v = *values.iter().min().unwrap();
let max_v = *values.iter().max().unwrap();
max_v - min_v
}
fn interval_gap_usize(a: &BTreeSet<usize>, b: &BTreeSet<usize>) -> usize {
if a.is_empty() || b.is_empty() {
return 0;
}
let a_min = *a.iter().next().unwrap();
let a_max = *a.iter().next_back().unwrap();
let b_min = *b.iter().next().unwrap();
let b_max = *b.iter().next_back().unwrap();
if a_max < b_min {
b_min - a_max
} else if b_max < a_min {
a_min - b_max
} else {
0
}
}
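The helpers above implement a simple one-dimensional classifier evaluation: `bucketize_ms` quantizes samples into fixed-width buckets, and `best_threshold_accuracy_u128` sweeps every candidate threshold and reports the best achievable class-separation accuracy. A standalone sketch of that pipeline (illustrative only; the function bodies mirror the helpers above):

```rust
fn bucketize(values: &[u128], bucket: u128) -> Vec<u128> {
    values.iter().map(|v| *v / bucket).collect()
}

fn best_threshold_accuracy(a: &[u128], b: &[u128]) -> f64 {
    let min_v = *a.iter().chain(b.iter()).min().unwrap();
    let max_v = *a.iter().chain(b.iter()).max().unwrap();
    let mut best = 0.0f64;
    for t in min_v..=max_v {
        // Classify x <= t as class A, x > t as class B.
        let correct = a.iter().filter(|&&x| x <= t).count()
            + b.iter().filter(|&&x| x > t).count();
        let acc = correct as f64 / (a.len() + b.len()) as f64;
        if acc > best {
            best = acc;
        }
    }
    best
}

fn main() {
    // Raw samples (ms): two classes 10 ms apart are perfectly separable.
    let fail = [221u128, 223, 225, 227, 229];
    let slow = [231u128, 233, 235, 237, 239];
    assert!(best_threshold_accuracy(&fail, &slow) >= 0.99);
    // With 20 ms buckets both classes collapse into bucket 11,
    // so no threshold beats coin-flip accuracy.
    let acc = best_threshold_accuracy(&bucketize(&fail, 20), &bucketize(&slow, 20));
    assert!(acc <= 0.6);
}
```

This is why the tests compare bucketed accuracies: bucketing models a network observer with limited timing resolution, and normalization should push both classes into the same buckets.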
async fn collect_timing_samples(path: PathClass, timing_norm_enabled: bool, n: usize) -> Vec<u128> {
let mut out = Vec::with_capacity(n);
for _ in 0..n {
out.push(measure_masking_duration_ms(path, timing_norm_enabled).await);
}
out
}
async fn measure_masking_duration_ms(path: PathClass, timing_norm_enabled: bool) -> u128 {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_timing_normalization_enabled = timing_norm_enabled;
config.censorship.mask_timing_normalization_floor_ms = 220;
config.censorship.mask_timing_normalization_ceiling_ms = 260;
let accept_task = match path {
PathClass::ConnectFail => {
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 1;
None
}
PathClass::ConnectSuccess => {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
Some(tokio::spawn(async move {
let (_stream, _) = listener.accept().await.unwrap();
}))
}
PathClass::SlowBackend => {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
Some(tokio::spawn(async move {
let (_stream, _) = listener.accept().await.unwrap();
tokio::time::sleep(Duration::from_millis(320)).await;
}))
}
};
let (client_reader, _client_writer) = duplex(1024);
let (_client_visible_reader, client_visible_writer) = duplex(1024);
let peer: SocketAddr = "198.51.100.230:57230".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
b"GET /ab-harness HTTP/1.1\r\nHost: x\r\n\r\n",
peer,
local,
&config,
&beobachten,
)
.await;
if let Some(task) = accept_task {
let _ = tokio::time::timeout(Duration::from_secs(2), task).await;
}
started.elapsed().as_millis()
}
async fn capture_above_cap_forwarded_len(
body_sent: usize,
above_cap_blur_enabled: bool,
above_cap_blur_max_bytes: usize,
) -> usize {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_shape_hardening = true;
config.censorship.mask_shape_bucket_floor_bytes = 512;
config.censorship.mask_shape_bucket_cap_bytes = 4096;
config.censorship.mask_shape_above_cap_blur = above_cap_blur_enabled;
config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got.len()
});
let (client_reader, mut client_writer) = duplex(64 * 1024);
let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024);
let peer: SocketAddr = "198.51.100.231:57231".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let mut initial = vec![0u8; 5 + body_sent];
initial[0] = 0x16;
initial[1] = 0x03;
initial[2] = 0x01;
initial[3..5].copy_from_slice(&7000u16.to_be_bytes());
initial[5..].fill(0x5A);
let fallback_task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
&initial,
peer,
local,
&config,
&beobachten,
)
.await;
});
client_writer.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), fallback_task)
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
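The length-side harness records each class's emitted lengths as a `BTreeSet` and scores separability two ways: direct overlap between the sets, and the gap between their value ranges (`interval_gap_usize` above). A standalone sketch of the gap measure (illustrative only; the body mirrors the helper above):

```rust
use std::collections::BTreeSet;

fn interval_gap(a: &BTreeSet<usize>, b: &BTreeSet<usize>) -> usize {
    if a.is_empty() || b.is_empty() {
        return 0;
    }
    let (a_min, a_max) = (*a.iter().next().unwrap(), *a.iter().next_back().unwrap());
    let (b_min, b_max) = (*b.iter().next().unwrap(), *b.iter().next_back().unwrap());
    if a_max < b_min {
        b_min - a_max
    } else if b_max < a_min {
        a_min - b_max
    } else {
        0 // ranges overlap: no clean threshold between the classes
    }
}

fn main() {
    // Disjoint length classes: a 3-byte gap a classifier can threshold on.
    let a: BTreeSet<usize> = [5120, 5122].into_iter().collect();
    let b: BTreeSet<usize> = [5125, 5130].into_iter().collect();
    assert_eq!(a.intersection(&b).count(), 0);
    assert_eq!(interval_gap(&a, &b), 3);
    // Blur widens each class until the ranges overlap: the gap collapses.
    let a2: BTreeSet<usize> = [5120, 5127].into_iter().collect();
    assert_eq!(interval_gap(&a2, &b), 0);
}
```

The A/B assertion below accepts either signal of reduced separability, since above-cap blur can close the gap without producing identical lengths.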
#[tokio::test]
async fn integration_ab_harness_envelope_and_blur_improve_obfuscation_vs_baseline() {
const ITER: usize = 8;
let mut baseline_fail = Vec::with_capacity(ITER);
let mut baseline_success = Vec::with_capacity(ITER);
let mut baseline_slow = Vec::with_capacity(ITER);
let mut hardened_fail = Vec::with_capacity(ITER);
let mut hardened_success = Vec::with_capacity(ITER);
let mut hardened_slow = Vec::with_capacity(ITER);
for _ in 0..ITER {
baseline_fail.push(measure_masking_duration_ms(PathClass::ConnectFail, false).await);
baseline_success.push(measure_masking_duration_ms(PathClass::ConnectSuccess, false).await);
baseline_slow.push(measure_masking_duration_ms(PathClass::SlowBackend, false).await);
hardened_fail.push(measure_masking_duration_ms(PathClass::ConnectFail, true).await);
hardened_success.push(measure_masking_duration_ms(PathClass::ConnectSuccess, true).await);
hardened_slow.push(measure_masking_duration_ms(PathClass::SlowBackend, true).await);
}
let baseline_means = [
mean_ms(&baseline_fail),
mean_ms(&baseline_success),
mean_ms(&baseline_slow),
];
let hardened_means = [
mean_ms(&hardened_fail),
mean_ms(&hardened_success),
mean_ms(&hardened_slow),
];
let baseline_range = baseline_means
.iter()
.copied()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(mn, mx), v| {
(mn.min(v), mx.max(v))
});
let hardened_range = hardened_means
.iter()
.copied()
.fold((f64::INFINITY, f64::NEG_INFINITY), |(mn, mx), v| {
(mn.min(v), mx.max(v))
});
let baseline_spread = baseline_range.1 - baseline_range.0;
let hardened_spread = hardened_range.1 - hardened_range.0;
println!(
"ab_harness_timing baseline_means={:?} hardened_means={:?} baseline_spread={:.2} hardened_spread={:.2}",
baseline_means, hardened_means, baseline_spread, hardened_spread
);
assert!(
hardened_spread < baseline_spread,
"timing envelope should reduce cross-path mean spread: baseline={baseline_spread:.2} hardened={hardened_spread:.2}"
);
let mut baseline_a = BTreeSet::new();
let mut baseline_b = BTreeSet::new();
let mut hardened_a = BTreeSet::new();
let mut hardened_b = BTreeSet::new();
for _ in 0..24 {
baseline_a.insert(capture_above_cap_forwarded_len(5000, false, 0).await);
baseline_b.insert(capture_above_cap_forwarded_len(5040, false, 0).await);
hardened_a.insert(capture_above_cap_forwarded_len(5000, true, 96).await);
hardened_b.insert(capture_above_cap_forwarded_len(5040, true, 96).await);
}
let baseline_overlap = baseline_a.intersection(&baseline_b).count();
let hardened_overlap = hardened_a.intersection(&hardened_b).count();
let baseline_gap = interval_gap_usize(&baseline_a, &baseline_b);
let hardened_gap = interval_gap_usize(&hardened_a, &hardened_b);
println!(
"ab_harness_length baseline_overlap={} hardened_overlap={} baseline_gap={} hardened_gap={} baseline_a={} baseline_b={} hardened_a={} hardened_b={}",
baseline_overlap,
hardened_overlap,
baseline_gap,
hardened_gap,
baseline_a.len(),
baseline_b.len(),
hardened_a.len(),
hardened_b.len()
);
assert_eq!(
baseline_overlap, 0,
"baseline above-cap classes should be disjoint"
);
assert!(
hardened_a.len() > baseline_a.len() && hardened_b.len() > baseline_b.len(),
"above-cap blur should widen per-class emitted lengths: baseline_a={} baseline_b={} hardened_a={} hardened_b={}",
baseline_a.len(),
baseline_b.len(),
hardened_a.len(),
hardened_b.len()
);
assert!(
hardened_overlap > baseline_overlap || hardened_gap < baseline_gap,
"above-cap blur should reduce class separability via direct overlap or tighter interval gap: baseline_overlap={} hardened_overlap={} baseline_gap={} hardened_gap={}",
baseline_overlap,
hardened_overlap,
baseline_gap,
hardened_gap
);
}
#[test]
fn timing_classifier_helper_bucketize_is_stable() {
let values = vec![219u128, 220, 239, 240, 259, 260];
let got = bucketize_ms(&values, 20);
assert_eq!(got, vec![10, 11, 11, 12, 12, 13]);
}
#[test]
fn timing_classifier_helper_percentile_is_monotonic() {
let samples = vec![210u128, 220, 230, 240, 250, 260, 270, 280];
let p50 = percentile_ms(samples.clone(), 50, 100);
let p95 = percentile_ms(samples.clone(), 95, 100);
assert!(p95 >= p50);
}
#[test]
fn timing_classifier_helper_threshold_accuracy_is_perfect_for_disjoint_sets() {
let a = vec![10u128, 11, 12, 13, 14];
let b = vec![20u128, 21, 22, 23, 24];
let acc = best_threshold_accuracy_u128(&a, &b);
assert!(acc >= 0.99);
}
#[test]
fn timing_classifier_helper_threshold_accuracy_drops_for_identical_sets() {
let a = vec![10u128, 11, 12, 13, 14];
let b = vec![10u128, 11, 12, 13, 14];
let acc = best_threshold_accuracy_u128(&a, &b);
assert!(
acc <= 0.6,
"identical sets should not be strongly separable"
);
}
#[test]
fn timing_classifier_helper_bucketed_threshold_reduces_resolution() {
let raw_a = vec![221u128, 223, 225, 227, 229];
let raw_b = vec![231u128, 233, 235, 237, 239];
let raw_acc = best_threshold_accuracy_u128(&raw_a, &raw_b);
let bucketed_a = bucketize_ms(&raw_a, 20);
let bucketed_b = bucketize_ms(&raw_b, 20);
let bucketed_acc = best_threshold_accuracy_u128(&bucketed_a, &bucketed_b);
assert!(raw_acc >= bucketed_acc);
}
#[tokio::test]
async fn timing_classifier_baseline_connect_fail_vs_slow_backend_is_highly_separable() {
let fail = collect_timing_samples(PathClass::ConnectFail, false, 8).await;
let slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await;
let acc = best_threshold_accuracy_u128(&fail, &slow);
assert!(
acc >= 0.80,
"baseline timing classes should be separable enough"
);
}
#[tokio::test]
async fn timing_classifier_normalized_connect_fail_vs_slow_backend_reduces_separability() {
let baseline_fail = collect_timing_samples(PathClass::ConnectFail, false, 8).await;
let baseline_slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await;
let hardened_fail = collect_timing_samples(PathClass::ConnectFail, true, 8).await;
let hardened_slow = collect_timing_samples(PathClass::SlowBackend, true, 8).await;
let baseline_acc = best_threshold_accuracy_u128(&baseline_fail, &baseline_slow);
let hardened_acc = best_threshold_accuracy_u128(&hardened_fail, &hardened_slow);
assert!(
hardened_acc <= baseline_acc,
"normalization should not increase timing separability"
);
}
#[tokio::test]
async fn timing_classifier_bucketed_normalized_connect_fail_vs_slow_backend_is_bounded() {
let baseline_fail = collect_timing_samples(PathClass::ConnectFail, false, 10).await;
let baseline_slow = collect_timing_samples(PathClass::SlowBackend, false, 10).await;
let hardened_fail = collect_timing_samples(PathClass::ConnectFail, true, 10).await;
let hardened_slow = collect_timing_samples(PathClass::SlowBackend, true, 10).await;
let baseline_acc = best_threshold_accuracy_u128(
&bucketize_ms(&baseline_fail, 20),
&bucketize_ms(&baseline_slow, 20),
);
let hardened_acc = best_threshold_accuracy_u128(
&bucketize_ms(&hardened_fail, 20),
&bucketize_ms(&hardened_slow, 20),
);
assert!(
hardened_acc <= baseline_acc,
"normalized bucketed classifier should not outperform baseline: baseline={baseline_acc:.3} hardened={hardened_acc:.3}"
);
}
#[tokio::test]
async fn timing_classifier_normalized_connect_fail_samples_stay_in_sane_bounds() {
let samples = collect_timing_samples(PathClass::ConnectFail, true, 6).await;
for s in samples {
assert!((150..=1200).contains(&s), "sample out of sane bounds: {s}");
}
}
#[tokio::test]
async fn timing_classifier_normalized_connect_success_samples_stay_in_sane_bounds() {
let samples = collect_timing_samples(PathClass::ConnectSuccess, true, 6).await;
for s in samples {
assert!((150..=1200).contains(&s), "sample out of sane bounds: {s}");
}
}
#[tokio::test]
async fn timing_classifier_normalized_slow_backend_samples_stay_in_sane_bounds() {
let samples = collect_timing_samples(PathClass::SlowBackend, true, 6).await;
for s in samples {
assert!((150..=1400).contains(&s), "sample out of sane bounds: {s}");
}
}
#[tokio::test]
async fn timing_classifier_normalized_mean_bucket_delta_connect_fail_vs_connect_success_is_small() {
let fail = collect_timing_samples(PathClass::ConnectFail, true, 8).await;
let success = collect_timing_samples(PathClass::ConnectSuccess, true, 8).await;
let fail_mean = mean_ms(&fail);
let success_mean = mean_ms(&success);
let delta_bucket = ((fail_mean as i128 - success_mean as i128).abs()) / 20;
assert!(
delta_bucket <= 3,
"mean bucket delta too large: {delta_bucket}"
);
}
#[tokio::test]
async fn timing_classifier_normalized_p95_bucket_delta_connect_success_vs_slow_is_small() {
let success = collect_timing_samples(PathClass::ConnectSuccess, true, 10).await;
let slow = collect_timing_samples(PathClass::SlowBackend, true, 10).await;
let p95_success = percentile_ms(success, 95, 100);
let p95_slow = percentile_ms(slow, 95, 100);
let delta_bucket = ((p95_success as i128 - p95_slow as i128).abs()) / 20;
assert!(
delta_bucket <= 4,
"p95 bucket delta too large: {delta_bucket}"
);
}
#[tokio::test]
async fn timing_classifier_normalized_spread_is_not_worse_than_baseline_for_connect_fail() {
let baseline = collect_timing_samples(PathClass::ConnectFail, false, 8).await;
let hardened = collect_timing_samples(PathClass::ConnectFail, true, 8).await;
let baseline_spread = spread_u128(&baseline);
let hardened_spread = spread_u128(&hardened);
assert!(
hardened_spread <= baseline_spread.saturating_add(600),
"normalized spread exploded unexpectedly: baseline={baseline_spread} hardened={hardened_spread}"
);
}
#[tokio::test]
async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_under_normalization()
{
const SAMPLE_COUNT: usize = 6;
let pairs = [
(PathClass::ConnectFail, PathClass::ConnectSuccess),
(PathClass::ConnectFail, PathClass::SlowBackend),
(PathClass::ConnectSuccess, PathClass::SlowBackend),
];
let mut meaningful_improvement_seen = false;
let mut baseline_sum = 0.0f64;
let mut hardened_sum = 0.0f64;
let mut pair_count = 0usize;
let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64;
let tolerated_pair_regression = acc_quant_step + 0.03;
for (a, b) in pairs {
let baseline_a = collect_timing_samples(a, false, SAMPLE_COUNT).await;
let baseline_b = collect_timing_samples(b, false, SAMPLE_COUNT).await;
let hardened_a = collect_timing_samples(a, true, SAMPLE_COUNT).await;
let hardened_b = collect_timing_samples(b, true, SAMPLE_COUNT).await;
let baseline_acc = best_threshold_accuracy_u128(
&bucketize_ms(&baseline_a, 20),
&bucketize_ms(&baseline_b, 20),
);
let hardened_acc = best_threshold_accuracy_u128(
&bucketize_ms(&hardened_a, 20),
&bucketize_ms(&hardened_b, 20),
);
// When baseline separability is near-random, tiny sample jitter can make
// hardened appear "worse" without indicating a real side-channel regression.
// Guard hard only on informative baseline pairs.
if baseline_acc >= 0.75 {
assert!(
hardened_acc <= baseline_acc + tolerated_pair_regression,
"normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}"
);
}
println!(
"timing_classifier_pair baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated_pair_regression={tolerated_pair_regression:.3}"
);
if hardened_acc + 0.05 <= baseline_acc {
meaningful_improvement_seen = true;
}
baseline_sum += baseline_acc;
hardened_sum += hardened_acc;
pair_count += 1;
}
let baseline_avg = baseline_sum / pair_count as f64;
let hardened_avg = hardened_sum / pair_count as f64;
assert!(
hardened_avg <= baseline_avg + 0.10,
"normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}"
);
// Optional signal only: do not require improvement on every run because
// noisy CI schedulers can flatten pairwise differences at low sample counts.
let _ = meaningful_improvement_seen;
}
#[tokio::test]
async fn timing_classifier_stress_parallel_sampling_finishes_and_stays_bounded() {
let mut tasks = Vec::new();
for i in 0..24usize {
tasks.push(tokio::spawn(async move {
let class = match i % 3 {
0 => PathClass::ConnectFail,
1 => PathClass::ConnectSuccess,
_ => PathClass::SlowBackend,
};
let sample = measure_masking_duration_ms(class, true).await;
assert!(
(100..=1600).contains(&sample),
"stress sample out of bounds: {sample}"
);
}));
}
for task in tasks {
tokio::time::timeout(Duration::from_secs(4), task)
.await
.unwrap()
.unwrap();
}
}

@ -0,0 +1,806 @@
use super::*;
use crate::config::ProxyConfig;
use crate::proxy::relay::relay_bidirectional;
use crate::stats::Stats;
use crate::stats::beobachten::BeobachtenStore;
use crate::stream::BufferPool;
use std::sync::Arc;
use tokio::io::duplex;
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};
// ------------------------------------------------------------------
// Probing Indistinguishability (OWASP ASVS 5.1.7)
// ------------------------------------------------------------------
#[tokio::test]
async fn masking_probes_indistinguishable_timing() {
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 80; // Should timeout/refuse
let peer: SocketAddr = "192.0.2.10:443".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
// Test different probe types
let probes = vec![
(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n".to_vec(), "HTTP"),
(b"SSH-2.0-probe".to_vec(), "SSH"),
(
vec![0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00],
"TLS-scanner",
),
(vec![0x42; 5], "port-scanner"),
];
for (probe, type_name) in probes {
let (client_reader, _client_writer) = duplex(256);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let start = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = start.elapsed();
// We expect any outcome to take roughly MASK_TIMEOUT (50ms in tests)
// to mask whether the backend was reachable or refused.
assert!(
elapsed >= Duration::from_millis(30),
"Probe {type_name} finished too fast: {elapsed:?}"
);
}
}
// ------------------------------------------------------------------
// Masking Budget Stress Tests (OWASP ASVS 5.1.6)
// ------------------------------------------------------------------
#[tokio::test]
async fn masking_budget_stress_under_load() {
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 1; // Unlikely port
let peer: SocketAddr = "192.0.2.20:443".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = Arc::new(BeobachtenStore::new());
let mut tasks = Vec::new();
for _ in 0..50 {
let (client_reader, _client_writer) = duplex(256);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let config = config.clone();
let beobachten = Arc::clone(&beobachten);
tasks.push(tokio::spawn(async move {
let start = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
b"probe",
peer,
local_addr,
&config,
&beobachten,
)
.await;
start.elapsed()
}));
}
for task in tasks {
let elapsed = task.await.unwrap();
assert!(
elapsed >= Duration::from_millis(30),
"Stress probe finished too fast: {elapsed:?}"
);
}
}
// ------------------------------------------------------------------
// detect_client_type Fingerprint Check
// ------------------------------------------------------------------
#[test]
fn test_detect_client_type_boundary_cases() {
// 9 bytes = port-scanner
assert_eq!(detect_client_type(&[0x42; 9]), "port-scanner");
// 10 bytes = unknown
assert_eq!(detect_client_type(&[0x42; 10]), "unknown");
// HTTP verbs without trailing space
assert_eq!(detect_client_type(b"GET/"), "port-scanner"); // because len < 10
assert_eq!(detect_client_type(b"GET /path"), "HTTP");
}
// ------------------------------------------------------------------
// Priority 2: Slowloris and Slow Read Attacks (OWASP ASVS 5.1.5)
// ------------------------------------------------------------------
#[tokio::test]
async fn masking_slowloris_client_idle_timeout_rejected() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let initial = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let initial = initial.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut observed = vec![0u8; initial.len()];
stream.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, initial);
let mut drip = [0u8; 1];
let drip_read =
tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip))
.await;
assert!(
drip_read.is_err() || drip_read.unwrap().is_err(),
"backend must not receive post-timeout slowloris drip bytes"
);
}
});
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
let beobachten = BeobachtenStore::new();
let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap();
let local: SocketAddr = "192.0.2.1:443".parse().unwrap();
let (mut client_writer, client_reader) = duplex(1024);
let (_client_visible_reader, client_visible_writer) = duplex(1024);
let handle = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
&initial,
peer,
local,
&config,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(160)).await;
let _ = client_writer.write_all(b"X").await;
handle.await.unwrap();
accept_task.await.unwrap();
}
// ------------------------------------------------------------------
// Priority 2: Fallback Server Down / Fingerprinting (OWASP ASVS 5.1.7)
// ------------------------------------------------------------------
#[tokio::test]
async fn masking_fallback_down_mimics_timeout() {
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 1; // Unlikely port
let (server_reader, server_writer) = duplex(1024);
let beobachten = BeobachtenStore::new();
let peer: SocketAddr = "192.0.2.12:12345".parse().unwrap();
let local: SocketAddr = "192.0.2.1:443".parse().unwrap();
let start = Instant::now();
handle_bad_client(
server_reader,
server_writer,
b"GET / HTTP/1.1\r\n",
peer,
local,
&config,
&beobachten,
)
.await;
let elapsed = start.elapsed();
// It should wait for MASK_TIMEOUT (50ms in tests) even if connection was refused immediately
assert!(
elapsed >= Duration::from_millis(40),
"Must respect connect budget even on failure: {elapsed:?}"
);
}
// ------------------------------------------------------------------
// Priority 2: SSRF Prevention (OWASP ASVS 5.1.2)
// ------------------------------------------------------------------
#[tokio::test]
async fn masking_ssrf_resolve_internal_ranges_blocked() {
use crate::network::dns_overrides::resolve_socket_addr;
let blocked_ips = [
"127.0.0.1",
"169.254.169.254",
"10.0.0.1",
"192.168.1.1",
"0.0.0.0",
];
for ip in blocked_ips {
assert!(
resolve_socket_addr(ip, 80).is_none(),
"runtime DNS overrides must not resolve unconfigured literal host targets"
);
}
}
#[tokio::test]
async fn masking_unknown_proxy_protocol_version_falls_back_to_v1_unknown_header() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut header = [0u8; 15];
stream.read_exact(&mut header).await.unwrap();
assert_eq!(&header, b"PROXY UNKNOWN\r\n");
let mut payload = [0u8; 5];
stream.read_exact(&mut payload).await.unwrap();
assert_eq!(&payload, b"probe");
});
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 255;
let peer: SocketAddr = "198.51.100.77:50001".parse().unwrap();
let local_addr: SocketAddr = "[2001:db8::10]:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (client_reader, _client_writer) = duplex(128);
let (_client_visible_reader, client_visible_writer) = duplex(128);
handle_bad_client(
client_reader,
client_visible_writer,
b"probe",
peer,
local_addr,
&config,
&beobachten,
)
.await;
accept_task.await.unwrap();
}
#[tokio::test]
async fn masking_zero_length_initial_data_does_not_hang_or_panic() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut one = [0u8; 1];
let n = tokio::time::timeout(Duration::from_millis(150), stream.read(&mut one))
.await
.unwrap()
.unwrap();
assert_eq!(
n, 0,
"backend must observe clean EOF for empty initial payload"
);
});
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
let peer: SocketAddr = "203.0.113.70:50002".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (client_reader, client_writer) = duplex(64);
drop(client_writer);
let (_client_visible_reader, client_visible_writer) = duplex(64);
handle_bad_client(
client_reader,
client_visible_writer,
b"",
peer,
local,
&config,
&beobachten,
)
.await;
accept_task.await.unwrap();
}
#[tokio::test]
async fn masking_oversized_initial_payload_is_forwarded_verbatim() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let payload = vec![0xA5u8; 32 * 1024];
let accept_task = tokio::spawn({
let payload = payload.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut observed = vec![0u8; payload.len()];
stream.read_exact(&mut observed).await.unwrap();
assert_eq!(
observed, payload,
"large initial payload must stay byte-for-byte"
);
}
});
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
let peer: SocketAddr = "203.0.113.71:50003".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (client_reader, _client_writer) = duplex(64);
let (_client_visible_reader, client_visible_writer) = duplex(64);
handle_bad_client(
client_reader,
client_visible_writer,
&payload,
peer,
local,
&config,
&beobachten,
)
.await;
accept_task.await.unwrap();
}
#[tokio::test]
async fn masking_refused_backend_keeps_constantish_timing_floor_under_burst() {
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 1;
let peer: SocketAddr = "203.0.113.72:50004".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
for _ in 0..16 {
let (client_reader, _client_writer) = duplex(128);
let (_client_visible_reader, client_visible_writer) = duplex(128);
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
b"GET / HTTP/1.1\r\n",
peer,
local,
&config,
&beobachten,
)
.await;
assert!(
started.elapsed() >= Duration::from_millis(30),
"refused-backend path must keep timing floor to reduce fingerprinting"
);
}
}
#[tokio::test]
async fn masking_backend_half_close_then_client_half_close_completes_without_hang() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut pre = [0u8; 4];
stream.read_exact(&mut pre).await.unwrap();
assert_eq!(&pre, b"PING");
stream.write_all(b"PONG").await.unwrap();
stream.shutdown().await.unwrap();
});
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
let peer: SocketAddr = "203.0.113.73:50005".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (mut client_writer, client_reader) = duplex(256);
let (mut client_visible_reader, client_visible_writer) = duplex(256);
let handle = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
b"PING",
peer,
local,
&config,
&beobachten,
)
.await;
});
client_writer.shutdown().await.unwrap();
let mut got = [0u8; 4];
client_visible_reader.read_exact(&mut got).await.unwrap();
assert_eq!(&got, b"PONG");
timeout(Duration::from_secs(2), handle)
.await
.expect("masking task must terminate after bilateral half-close")
.unwrap();
accept_task.await.unwrap();
}
#[tokio::test]
async fn chaos_burst_reconnect_storm_for_masking_and_relay_concurrently() {
const MASKING_SESSIONS: usize = 48;
const RELAY_SESSIONS: usize = 48;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
let backend_task = tokio::spawn({
let backend_reply = backend_reply.clone();
async move {
for _ in 0..MASKING_SESSIONS {
let (mut stream, _) = listener.accept().await.unwrap();
let mut req = [0u8; 32];
stream.read_exact(&mut req).await.unwrap();
assert!(
req.starts_with(b"GET /storm/"),
"masking backend must receive storm reconnect probes"
);
stream.write_all(&backend_reply).await.unwrap();
stream.shutdown().await.unwrap();
}
}
});
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 0;
let config = Arc::new(config);
let beobachten = Arc::new(BeobachtenStore::new());
let peer: SocketAddr = "198.51.100.200:55555".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let mut masking_tasks = Vec::with_capacity(MASKING_SESSIONS);
for i in 0..MASKING_SESSIONS {
let config = Arc::clone(&config);
let beobachten = Arc::clone(&beobachten);
let expected_reply = backend_reply.clone();
masking_tasks.push(tokio::spawn(async move {
let mut probe = [0u8; 32];
let template = format!("GET /storm/{i:04} HTTP/1.1\r\n\r\n");
let bytes = template.as_bytes();
probe[..bytes.len()].copy_from_slice(bytes);
// The probe already carries every client byte; drop the writer so the
// handler sees immediate EOF on the client side.
let (client_reader, client_writer) = duplex(256);
drop(client_writer);
let (mut client_visible_reader, client_visible_writer) = duplex(1024);
let handle = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local,
&config,
&beobachten,
)
.await;
});
let mut observed = vec![0u8; expected_reply.len()];
client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, expected_reply);
timeout(Duration::from_secs(2), handle)
.await
.expect("masking reconnect task must complete")
.unwrap();
}));
}
let mut relay_tasks = Vec::with_capacity(RELAY_SESSIONS);
for i in 0..RELAY_SESSIONS {
relay_tasks.push(tokio::spawn(async move {
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
"chaos-storm-relay",
stats,
None,
Arc::new(BufferPool::new()),
));
let c2s = vec![(i as u8).wrapping_add(1); 64];
client_peer.write_all(&c2s).await.unwrap();
let mut c2s_seen = vec![0u8; c2s.len()];
server_peer.read_exact(&mut c2s_seen).await.unwrap();
assert_eq!(c2s_seen, c2s);
let s2c = vec![(i as u8).wrapping_add(17); 96];
server_peer.write_all(&s2c).await.unwrap();
let mut s2c_seen = vec![0u8; s2c.len()];
client_peer.read_exact(&mut s2c_seen).await.unwrap();
assert_eq!(s2c_seen, s2c);
drop(client_peer);
drop(server_peer);
timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay reconnect task must complete")
.unwrap()
.unwrap();
}));
}
for task in masking_tasks {
timeout(Duration::from_secs(3), task)
.await
.expect("masking storm join must complete")
.unwrap();
}
for task in relay_tasks {
timeout(Duration::from_secs(3), task)
.await
.expect("relay storm join must complete")
.unwrap();
}
timeout(Duration::from_secs(3), backend_task)
.await
.expect("masking backend accept loop must complete")
.unwrap();
}
fn read_env_usize_or_default(name: &str, default: usize) -> usize {
match std::env::var(name) {
Ok(raw) => match raw.parse::<usize>() {
Ok(parsed) if parsed > 0 => parsed,
_ => default,
},
Err(_) => default,
}
}
#[tokio::test]
#[ignore = "heavy soak; run manually"]
async fn chaos_burst_reconnect_storm_for_masking_and_relay_multiwave_soak() {
let waves = read_env_usize_or_default("CHAOS_WAVES", 4);
let masking_per_wave = read_env_usize_or_default("CHAOS_MASKING_PER_WAVE", 160);
let relay_per_wave = read_env_usize_or_default("CHAOS_RELAY_PER_WAVE", 160);
let total_masking = waves * masking_per_wave;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let backend_task = tokio::spawn({
let backend_reply = backend_reply.clone();
async move {
for _ in 0..total_masking {
let (mut stream, _) = listener.accept().await.unwrap();
let mut req = [0u8; 32];
stream.read_exact(&mut req).await.unwrap();
assert!(
req.starts_with(b"GET /storm/"),
"mask backend must only receive storm probes"
);
stream.write_all(&backend_reply).await.unwrap();
stream.shutdown().await.unwrap();
}
}
});
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 0;
let config = Arc::new(config);
let beobachten = Arc::new(BeobachtenStore::new());
let peer: SocketAddr = "198.51.100.201:56565".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
for wave in 0..waves {
let mut masking_tasks = Vec::with_capacity(masking_per_wave);
for i in 0..masking_per_wave {
let config = Arc::clone(&config);
let beobachten = Arc::clone(&beobachten);
let expected_reply = backend_reply.clone();
masking_tasks.push(tokio::spawn(async move {
let mut probe = [0u8; 32];
let template = format!("GET /storm/{wave:02}-{i:03}\r\n\r\n");
let bytes = template.as_bytes();
probe[..bytes.len()].copy_from_slice(bytes);
// As in the single-wave storm: the probe carries all client bytes, so
// drop the writer to signal immediate client-side EOF.
let (client_reader, client_writer) = duplex(256);
drop(client_writer);
let (mut client_visible_reader, client_visible_writer) = duplex(1024);
let handle = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local,
&config,
&beobachten,
)
.await;
});
let mut observed = vec![0u8; expected_reply.len()];
client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, expected_reply);
timeout(Duration::from_secs(3), handle)
.await
.expect("masking storm task must complete")
.unwrap();
}));
}
let mut relay_tasks = Vec::with_capacity(relay_per_wave);
for i in 0..relay_per_wave {
relay_tasks.push(tokio::spawn(async move {
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
"chaos-multiwave-relay",
stats,
None,
Arc::new(BufferPool::new()),
));
let c2s = vec![(wave as u8).wrapping_add(i as u8).wrapping_add(1); 32];
client_peer.write_all(&c2s).await.unwrap();
let mut c2s_seen = vec![0u8; c2s.len()];
server_peer.read_exact(&mut c2s_seen).await.unwrap();
assert_eq!(c2s_seen, c2s);
let s2c = vec![(wave as u8).wrapping_add(i as u8).wrapping_add(17); 48];
server_peer.write_all(&s2c).await.unwrap();
let mut s2c_seen = vec![0u8; s2c.len()];
client_peer.read_exact(&mut s2c_seen).await.unwrap();
assert_eq!(s2c_seen, s2c);
drop(client_peer);
drop(server_peer);
timeout(Duration::from_secs(3), relay_task)
.await
.expect("relay storm task must complete")
.unwrap()
.unwrap();
}));
}
for task in masking_tasks {
timeout(Duration::from_secs(6), task)
.await
.expect("masking wave task join must complete")
.unwrap();
}
for task in relay_tasks {
timeout(Duration::from_secs(6), task)
.await
.expect("relay wave task join must complete")
.unwrap();
}
}
timeout(Duration::from_secs(8), backend_task)
.await
.expect("mask backend must complete all accepted storm sessions")
.unwrap();
}
#[tokio::test]
#[ignore = "heavy soak; run manually"]
async fn masking_timing_bucket_soak_refused_backend_stays_within_narrow_band() {
let mut config = ProxyConfig::default();
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 1; // port 1 is expected unbound, so the backend connection is refused
let peer: SocketAddr = "203.0.113.74:50006".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let mut samples = Vec::with_capacity(128);
for _ in 0..128 {
// Keep both unused halves alive in scope so neither side signals EOF;
// the handler must take its own refused-backend path on every iteration.
let (client_reader, _client_writer) = duplex(128);
let (_client_visible_reader, client_visible_writer) = duplex(128);
let started = Instant::now();
handle_bad_client(
client_reader,
client_visible_writer,
b"GET / HTTP/1.1\r\n",
peer,
local,
&config,
&beobachten,
)
.await;
samples.push(started.elapsed().as_millis());
}
samples.sort_unstable();
// Rough percentile picks over 128 sorted samples: index 12 for p10, index 115 for p90.
let p10 = samples[samples.len() / 10];
let p90 = samples[(samples.len() * 9) / 10];
assert!(
p90.saturating_sub(p10) <= 40,
"timing spread too wide for refused-backend masking path: p10={p10}ms p90={p90}ms"
);
}
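The `CHAOS_*` soak knobs above all follow one parse-or-default rule: only a positive integer overrides the built-in default, while an unset variable, zero, or junk falls back. A minimal standalone sketch of that rule (the helper name `parse_knob_or` is hypothetical, mirroring `read_env_usize_or_default` but taking the raw value directly so it is testable without touching process environment):

```rust
// Mirrors the knob-parsing rule used by the soak tests: only a positive
// integer overrides the default; None, zero, or unparsable input falls back.
fn parse_knob_or(raw: Option<&str>, default: usize) -> usize {
    match raw.and_then(|r| r.parse::<usize>().ok()) {
        Some(parsed) if parsed > 0 => parsed,
        _ => default,
    }
}
```

Taking the value as `Option<&str>` rather than reading the environment inside the helper keeps the sketch deterministic under a parallel test runner, where `std::env::set_var` races between tests.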


@@ -0,0 +1,107 @@
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
async fn capture_forwarded_len_with_mode(
body_sent: usize,
close_client_after_write: bool,
aggressive_mode: bool,
above_cap_blur: bool,
above_cap_blur_max_bytes: usize,
) -> usize {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_shape_hardening = true;
config.censorship.mask_shape_hardening_aggressive_mode = aggressive_mode;
config.censorship.mask_shape_bucket_floor_bytes = 512;
config.censorship.mask_shape_bucket_cap_bytes = 4096;
config.censorship.mask_shape_above_cap_blur = above_cap_blur;
config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got.len()
});
let (server_reader, mut client_writer) = duplex(64 * 1024);
let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024);
let peer: SocketAddr = "198.51.100.248:57248".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
// Fake TLS record: handshake content type (0x16), legacy version 3.1,
// a declared length of 7000, then `body_sent` filler bytes of 0x31.
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&7000u16.to_be_bytes());
probe[5..].fill(0x31);
let fallback = tokio::spawn(async move {
handle_bad_client(
server_reader,
client_visible_writer,
&probe,
peer,
local,
&config,
&beobachten,
)
.await;
});
if close_client_after_write {
    // Half-close immediately: the handler observes client EOF right away.
    client_writer.shutdown().await.unwrap();
} else {
    // Keep the client side open briefly so the handler reaches its
    // backend-silent non-EOF path before the stream is dropped.
    client_writer.write_all(b"keepalive").await.unwrap();
    tokio::time::sleep(Duration::from_millis(170)).await;
    drop(client_writer);
}
let _ = tokio::time::timeout(Duration::from_secs(4), fallback)
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
#[tokio::test]
async fn aggressive_mode_shapes_backend_silent_non_eof_path() {
let body_sent = 17usize;
let floor = 512usize;
let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await;
let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await;
assert!(legacy < floor, "legacy mode should keep timeout path unshaped");
assert!(
aggressive >= floor,
"aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})"
);
}
#[tokio::test]
async fn aggressive_mode_enforces_positive_above_cap_blur() {
let body_sent = 5000usize;
let base = 5 + body_sent;
for _ in 0..48 {
let observed = capture_forwarded_len_with_mode(body_sent, true, true, true, 1).await;
assert!(
observed > base,
"aggressive mode must not emit exact base length when blur is enabled (observed={observed}, base={base})"
);
}
}
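The shaping tests above hand-build their probe inline. Factored out as a standalone sketch (the function name `fake_tls_record` is hypothetical, not part of the source), the probe is a 5-byte TLS record header whose declared length is independent of the bytes actually sent, which is exactly what lets `body_sent` undershoot or overshoot the declared 7000:

```rust
// Builds the same fake TLS record the shaping probes use: handshake
// content type (0x16), legacy version 3.1, a big-endian declared length,
// then `body` filler bytes of 0x31.
fn fake_tls_record(declared_len: u16, body: usize) -> Vec<u8> {
    let mut probe = vec![0u8; 5 + body];
    probe[0] = 0x16; // ContentType: handshake
    probe[1] = 0x03; // legacy record version, major
    probe[2] = 0x01; // legacy record version, minor (3.1 on the wire)
    probe[3..5].copy_from_slice(&declared_len.to_be_bytes());
    probe[5..].fill(0x31);
    probe
}
```

With `fake_tls_record(7000, 17)` the record header promises 7000 payload bytes while only 17 follow, matching the `body_sent = 17` case that exercises the backend-silent path.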
