Compare commits

...

8 Commits

Author SHA1 Message Date
Alexey
3011a9ef6d Merge branch 'flow' into flow 2026-03-22 15:50:21 +03:00
Alexey
7b570be5b3 DC -> Client Runtime in Metrics and API 2026-03-22 15:28:55 +03:00
Alexey
0461bc65c6 DC -> Client Optimizations 2026-03-22 15:00:15 +03:00
David Osipov
ead23608f0 Add stress and manual benchmark tests for handshake protocols
- Introduced `handshake_real_bug_stress_tests.rs` to validate TLS and MTProto handshake behaviors under various conditions, including ALPN rejection and session ID handling.
- Implemented tests to ensure replay cache integrity and proper handling of malicious input without panicking.
- Added `handshake_timing_manual_bench_tests.rs` for performance benchmarking of user authentication paths, comparing preferred user handling against full user scans in both MTProto and TLS contexts.
- Included timing-sensitive tests to measure the impact of SNI on handshake performance.
2026-03-22 15:39:57 +04:00
Alexey
cf82b637d2 Merge branch 'main' into flow 2026-03-22 12:38:37 +03:00
Alexey
2e8bfa1101 Update codeql-config.yml 2026-03-22 12:38:15 +03:00
Alexey
d091b0b251 Update CODE_OF_CONDUCT.md 2026-03-22 11:48:06 +03:00
Alexey
56fc6c4896 Update Dockerfile 2026-03-22 11:16:09 +03:00
25 changed files with 4461 additions and 122 deletions

View File

@@ -7,7 +7,16 @@ queries:
- uses: security-and-quality
- uses: ./.github/codeql/queries
paths-ignore:
- "**/tests/**"
- "**/test/**"
- "**/*_test.rs"
- "**/*/tests.rs"
query-filters:
- exclude:
tags:
- test
- exclude:
id:
- rust/unwrap-on-option

View File

@@ -1,8 +1,8 @@
# Code of Conduct
## 1. Purpose
## Purpose
Telemt exists to solve technical problems.
**Telemt exists to solve technical problems.**
Telemt is open to contributors who want to learn, improve and build meaningful systems together.
@@ -18,27 +18,34 @@ Technology has consequences. Responsibility is inherent.
---
## 2. Principles
## Principles
* **Technical over emotional**
Arguments are grounded in data, logs, reproducible cases, or clear reasoning.
* **Clarity over noise**
Communication is structured, concise, and relevant.
* **Openness with standards**
Participation is open. The work remains disciplined.
* **Independence of judgment**
Claims are evaluated on technical merit, not affiliation or posture.
* **Responsibility over capability**
Capability does not justify careless use.
* **Cooperation over friction**
Progress depends on coordination, mutual support, and honest review.
* **Good intent, rigorous method**
Assume good intent, but require rigor.
> **Aussagen gelten nach ihrer Begründung.**
@@ -47,7 +54,7 @@ Technology has consequences. Responsibility is inherent.
---
## 3. Expected Behavior
## Expected Behavior
Participants are expected to:
@@ -69,7 +76,7 @@ New contributors are welcome. They are expected to grow into these standards. Ex
---
## 4. Unacceptable Behavior
## Unacceptable Behavior
The following is not allowed:
@@ -89,7 +96,7 @@ Such discussions may be closed, removed, or redirected.
---
## 5. Security and Misuse
## Security and Misuse
Telemt is intended for responsible use.
@@ -109,15 +116,13 @@ Security is both technical and behavioral.
Telemt is open to contributors of different backgrounds, experience levels, and working styles.
Standards are public, legible, and applied to the work itself.
Questions are welcome. Careful disagreement is welcome. Honest correction is welcome.
Gatekeeping by obscurity, status signaling, or hostility is not.
- Standards are public, legible, and applied to the work itself.
- Questions are welcome. Careful disagreement is welcome. Honest correction is welcome.
- Gatekeeping by obscurity, status signaling, or hostility is not.
---
## 7. Scope
## Scope
This Code of Conduct applies to all official spaces:
@@ -127,16 +132,19 @@ This Code of Conduct applies to all official spaces:
---
## 8. Maintainer Stewardship
## Maintainer Stewardship
Maintainers are responsible for final decisions in matters of conduct, scope, and direction.
This responsibility is stewardship: preserving continuity, protecting signal, maintaining standards, and keeping Telemt workable for others.
This responsibility is stewardship:
- preserving continuity,
- protecting signal,
- maintaining standards,
- keeping Telemt workable for others.
Judgment should be exercised with restraint, consistency, and institutional responsibility.
Not every decision requires extended debate.
Not every intervention requires public explanation.
- Not every decision requires extended debate.
- Not every intervention requires public explanation.
All decisions are expected to serve the durability, clarity, and integrity of Telemt.
@@ -146,7 +154,7 @@ All decisions are expected to serve the durability, clarity, and integrity of Te
---
## 9. Enforcement
## Enforcement
Maintainers may act to preserve the integrity of Telemt, including by:
@@ -156,44 +164,40 @@ Maintainers may act to preserve the integrity of Telemt, including by:
* Restricting or banning participants
Actions are taken to maintain function, continuity, and signal quality.
Where possible, correction is preferred to exclusion.
Where necessary, exclusion is preferred to decay.
- Where possible, correction is preferred to exclusion.
- Where necessary, exclusion is preferred to decay.
---
## 10. Final
## Final
Telemt is built on discipline, structure, and shared intent.
- Signal over noise.
- Facts over opinion.
- Systems over rhetoric.
Signal over noise.
Facts over opinion.
Systems over rhetoric.
- Work is collective.
- Outcomes are shared.
- Responsibility is distributed.
Work is collective.
Outcomes are shared.
Responsibility is distributed.
Precision is learned.
Rigor is expected.
Help is part of the work.
- Precision is learned.
- Rigor is expected.
- Help is part of the work.
> **Ordnung ist Voraussetzung der Freiheit.**
If you contribute — contribute with care.
If you speak — speak with substance.
If you engage — engage constructively.
- If you contribute — contribute with care.
- If you speak — speak with substance.
- If you engage — engage constructively.
---
## 11. After All
## After All
Systems outlive intentions.
What is built will be used.
What is released will propagate.
What is maintained will define the future state.
- What is built will be used.
- What is released will propagate.
- What is maintained will define the future state.
There is no neutral infrastructure, only infrastructure shaped well or poorly.
@@ -201,8 +205,8 @@ There is no neutral infrastructure, only infrastructure shaped well or poorly.
> Every system carries responsibility.
Stability requires discipline.
Freedom requires structure.
Trust requires honesty.
- Stability requires discipline.
- Freedom requires structure.
- Trust requires honesty.
In the end, the system reflects its contributors.
In the end: the system reflects its contributors.

View File

@@ -28,9 +28,23 @@ RUN cargo build --release && strip target/release/telemt
FROM debian:12-slim AS minimal
RUN apt-get update && apt-get install -y --no-install-recommends \
upx \
binutils \
&& rm -rf /var/lib/apt/lists/*
curl \
ca-certificates \
&& rm -rf /var/lib/apt/lists/* \
\
# install UPX from Telemt releases
&& curl -fL \
--retry 5 \
--retry-delay 3 \
--connect-timeout 10 \
--max-time 120 \
-o /tmp/upx.tar.xz \
https://github.com/telemt/telemt/releases/download/toolchains/upx-amd64_linux.tar.xz \
&& tar -xf /tmp/upx.tar.xz -C /tmp \
&& mv /tmp/upx*/upx /usr/local/bin/upx \
&& chmod +x /usr/local/bin/upx \
&& rm -rf /tmp/upx*
COPY --from=builder /build/target/release/telemt /telemt

View File

@@ -198,11 +198,14 @@ This document lists all configuration keys accepted by `config.toml`.
| listen_tcp | `bool \| null` | `null` (auto) | — | Explicit TCP listener enable/disable override. |
| proxy_protocol | `bool` | `false` | — | Enables HAProxy PROXY protocol parsing on incoming client connections. |
| proxy_protocol_header_timeout_ms | `u64` | `500` | Must be `> 0`. | Timeout for PROXY protocol header read/parse (ms). |
| proxy_protocol_trusted_cidrs | `IpNetwork[]` | `[]` | — | When non-empty, only connections from these proxy source CIDRs are allowed to provide PROXY protocol headers. If empty, PROXY headers are rejected by default (security hardening). |
| metrics_port | `u16 \| null` | `null` | — | Metrics endpoint port (enables metrics listener). |
| metrics_listen | `String \| null` | `null` | — | Full metrics bind address (`IP:PORT`), overrides `metrics_port`. |
| metrics_whitelist | `IpNetwork[]` | `["127.0.0.1/32", "::1/128"]` | — | CIDR whitelist for metrics endpoint access. |
| max_connections | `u32` | `10000` | — | Max concurrent client connections (`0` = unlimited). |
Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers are parsed from the first bytes of the connection and the client source address is replaced with `src_addr` from the header. For security, the peer source IP (the direct connection address) is verified against `server.proxy_protocol_trusted_cidrs`; if this list is empty, PROXY headers are rejected and the connection is considered untrusted.
## [server.api]
| Parameter | Type | Default | Constraints / validation | Description |

View File

@@ -174,6 +174,24 @@ pub(super) struct ZeroMiddleProxyData {
pub(super) route_drop_queue_full_total: u64,
pub(super) route_drop_queue_full_base_total: u64,
pub(super) route_drop_queue_full_high_total: u64,
pub(super) d2c_batches_total: u64,
pub(super) d2c_batch_frames_total: u64,
pub(super) d2c_batch_bytes_total: u64,
pub(super) d2c_flush_reason_queue_drain_total: u64,
pub(super) d2c_flush_reason_batch_frames_total: u64,
pub(super) d2c_flush_reason_batch_bytes_total: u64,
pub(super) d2c_flush_reason_max_delay_total: u64,
pub(super) d2c_flush_reason_ack_immediate_total: u64,
pub(super) d2c_flush_reason_close_total: u64,
pub(super) d2c_data_frames_total: u64,
pub(super) d2c_ack_frames_total: u64,
pub(super) d2c_payload_bytes_total: u64,
pub(super) d2c_write_mode_coalesced_total: u64,
pub(super) d2c_write_mode_split_total: u64,
pub(super) d2c_quota_reject_pre_write_total: u64,
pub(super) d2c_quota_reject_post_write_total: u64,
pub(super) d2c_frame_buf_shrink_total: u64,
pub(super) d2c_frame_buf_shrink_bytes_total: u64,
pub(super) socks_kdf_strict_reject_total: u64,
pub(super) socks_kdf_compat_fallback_total: u64,
pub(super) endpoint_quarantine_total: u64,

View File

@@ -68,6 +68,25 @@ pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer
route_drop_queue_full_total: stats.get_me_route_drop_queue_full(),
route_drop_queue_full_base_total: stats.get_me_route_drop_queue_full_base(),
route_drop_queue_full_high_total: stats.get_me_route_drop_queue_full_high(),
d2c_batches_total: stats.get_me_d2c_batches_total(),
d2c_batch_frames_total: stats.get_me_d2c_batch_frames_total(),
d2c_batch_bytes_total: stats.get_me_d2c_batch_bytes_total(),
d2c_flush_reason_queue_drain_total: stats.get_me_d2c_flush_reason_queue_drain_total(),
d2c_flush_reason_batch_frames_total: stats.get_me_d2c_flush_reason_batch_frames_total(),
d2c_flush_reason_batch_bytes_total: stats.get_me_d2c_flush_reason_batch_bytes_total(),
d2c_flush_reason_max_delay_total: stats.get_me_d2c_flush_reason_max_delay_total(),
d2c_flush_reason_ack_immediate_total: stats
.get_me_d2c_flush_reason_ack_immediate_total(),
d2c_flush_reason_close_total: stats.get_me_d2c_flush_reason_close_total(),
d2c_data_frames_total: stats.get_me_d2c_data_frames_total(),
d2c_ack_frames_total: stats.get_me_d2c_ack_frames_total(),
d2c_payload_bytes_total: stats.get_me_d2c_payload_bytes_total(),
d2c_write_mode_coalesced_total: stats.get_me_d2c_write_mode_coalesced_total(),
d2c_write_mode_split_total: stats.get_me_d2c_write_mode_split_total(),
d2c_quota_reject_pre_write_total: stats.get_me_d2c_quota_reject_pre_write_total(),
d2c_quota_reject_post_write_total: stats.get_me_d2c_quota_reject_post_write_total(),
d2c_frame_buf_shrink_total: stats.get_me_d2c_frame_buf_shrink_total(),
d2c_frame_buf_shrink_bytes_total: stats.get_me_d2c_frame_buf_shrink_bytes_total(),
socks_kdf_strict_reject_total: stats.get_me_socks_kdf_strict_reject(),
socks_kdf_compat_fallback_total: stats.get_me_socks_kdf_compat_fallback(),
endpoint_quarantine_total: stats.get_me_endpoint_quarantine_total(),

View File

@@ -29,6 +29,8 @@ const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_FRAMES: usize = 32;
const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_BYTES: usize = 128 * 1024;
const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_DELAY_US: u64 = 500;
const DEFAULT_ME_D2C_ACK_FLUSH_IMMEDIATE: bool = true;
const DEFAULT_ME_QUOTA_SOFT_OVERSHOOT_BYTES: u64 = 64 * 1024;
const DEFAULT_ME_D2C_FRAME_BUF_SHRINK_THRESHOLD_BYTES: usize = 256 * 1024;
const DEFAULT_DIRECT_RELAY_COPY_BUF_C2S_BYTES: usize = 64 * 1024;
const DEFAULT_DIRECT_RELAY_COPY_BUF_S2C_BYTES: usize = 256 * 1024;
const DEFAULT_ME_WRITER_PICK_SAMPLE_SIZE: u8 = 3;
@@ -387,6 +389,14 @@ pub(crate) fn default_me_d2c_ack_flush_immediate() -> bool {
DEFAULT_ME_D2C_ACK_FLUSH_IMMEDIATE
}
pub(crate) fn default_me_quota_soft_overshoot_bytes() -> u64 {
DEFAULT_ME_QUOTA_SOFT_OVERSHOOT_BYTES
}
pub(crate) fn default_me_d2c_frame_buf_shrink_threshold_bytes() -> usize {
DEFAULT_ME_D2C_FRAME_BUF_SHRINK_THRESHOLD_BYTES
}
pub(crate) fn default_direct_relay_copy_buf_c2s_bytes() -> usize {
DEFAULT_DIRECT_RELAY_COPY_BUF_C2S_BYTES
}

View File

@@ -106,6 +106,8 @@ pub struct HotFields {
pub me_d2c_flush_batch_max_bytes: usize,
pub me_d2c_flush_batch_max_delay_us: u64,
pub me_d2c_ack_flush_immediate: bool,
pub me_quota_soft_overshoot_bytes: u64,
pub me_d2c_frame_buf_shrink_threshold_bytes: usize,
pub direct_relay_copy_buf_c2s_bytes: usize,
pub direct_relay_copy_buf_s2c_bytes: usize,
pub me_health_interval_ms_unhealthy: u64,
@@ -225,6 +227,8 @@ impl HotFields {
me_d2c_flush_batch_max_bytes: cfg.general.me_d2c_flush_batch_max_bytes,
me_d2c_flush_batch_max_delay_us: cfg.general.me_d2c_flush_batch_max_delay_us,
me_d2c_ack_flush_immediate: cfg.general.me_d2c_ack_flush_immediate,
me_quota_soft_overshoot_bytes: cfg.general.me_quota_soft_overshoot_bytes,
me_d2c_frame_buf_shrink_threshold_bytes: cfg.general.me_d2c_frame_buf_shrink_threshold_bytes,
direct_relay_copy_buf_c2s_bytes: cfg.general.direct_relay_copy_buf_c2s_bytes,
direct_relay_copy_buf_s2c_bytes: cfg.general.direct_relay_copy_buf_s2c_bytes,
me_health_interval_ms_unhealthy: cfg.general.me_health_interval_ms_unhealthy,
@@ -511,6 +515,9 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
cfg.general.me_d2c_flush_batch_max_bytes = new.general.me_d2c_flush_batch_max_bytes;
cfg.general.me_d2c_flush_batch_max_delay_us = new.general.me_d2c_flush_batch_max_delay_us;
cfg.general.me_d2c_ack_flush_immediate = new.general.me_d2c_ack_flush_immediate;
cfg.general.me_quota_soft_overshoot_bytes = new.general.me_quota_soft_overshoot_bytes;
cfg.general.me_d2c_frame_buf_shrink_threshold_bytes =
new.general.me_d2c_frame_buf_shrink_threshold_bytes;
cfg.general.direct_relay_copy_buf_c2s_bytes = new.general.direct_relay_copy_buf_c2s_bytes;
cfg.general.direct_relay_copy_buf_s2c_bytes = new.general.direct_relay_copy_buf_s2c_bytes;
cfg.general.me_health_interval_ms_unhealthy = new.general.me_health_interval_ms_unhealthy;
@@ -1030,15 +1037,20 @@ fn log_changes(
|| old_hot.me_d2c_flush_batch_max_bytes != new_hot.me_d2c_flush_batch_max_bytes
|| old_hot.me_d2c_flush_batch_max_delay_us != new_hot.me_d2c_flush_batch_max_delay_us
|| old_hot.me_d2c_ack_flush_immediate != new_hot.me_d2c_ack_flush_immediate
|| old_hot.me_quota_soft_overshoot_bytes != new_hot.me_quota_soft_overshoot_bytes
|| old_hot.me_d2c_frame_buf_shrink_threshold_bytes
!= new_hot.me_d2c_frame_buf_shrink_threshold_bytes
|| old_hot.direct_relay_copy_buf_c2s_bytes != new_hot.direct_relay_copy_buf_c2s_bytes
|| old_hot.direct_relay_copy_buf_s2c_bytes != new_hot.direct_relay_copy_buf_s2c_bytes
{
info!(
"config reload: relay_tuning: me_d2c_frames={} me_d2c_bytes={} me_d2c_delay_us={} me_ack_flush_immediate={} direct_buf_c2s={} direct_buf_s2c={}",
"config reload: relay_tuning: me_d2c_frames={} me_d2c_bytes={} me_d2c_delay_us={} me_ack_flush_immediate={} me_quota_soft_overshoot_bytes={} me_d2c_frame_buf_shrink_threshold_bytes={} direct_buf_c2s={} direct_buf_s2c={}",
new_hot.me_d2c_flush_batch_max_frames,
new_hot.me_d2c_flush_batch_max_bytes,
new_hot.me_d2c_flush_batch_max_delay_us,
new_hot.me_d2c_ack_flush_immediate,
new_hot.me_quota_soft_overshoot_bytes,
new_hot.me_d2c_frame_buf_shrink_threshold_bytes,
new_hot.direct_relay_copy_buf_c2s_bytes,
new_hot.direct_relay_copy_buf_s2c_bytes,
);

View File

@@ -533,6 +533,19 @@ impl ProxyConfig {
));
}
if config.general.me_quota_soft_overshoot_bytes > 16 * 1024 * 1024 {
return Err(ProxyError::Config(
"general.me_quota_soft_overshoot_bytes must be within [0, 16777216]".to_string(),
));
}
if !(4096..=16 * 1024 * 1024).contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) {
return Err(ProxyError::Config(
"general.me_d2c_frame_buf_shrink_threshold_bytes must be within [4096, 16777216]"
.to_string(),
));
}
if !(4096..=1024 * 1024).contains(&config.general.direct_relay_copy_buf_c2s_bytes) {
return Err(ProxyError::Config(
"general.direct_relay_copy_buf_c2s_bytes must be within [4096, 1048576]"

View File

@@ -468,7 +468,7 @@ pub struct GeneralConfig {
pub me_c2me_send_timeout_ms: u64,
/// Bounded wait in milliseconds for routing ME DATA to per-connection queue.
/// `0` keeps legacy no-wait behavior.
/// `0` keeps non-blocking routing; values >0 enable bounded wait for compatibility.
#[serde(default = "default_me_reader_route_data_wait_ms")]
pub me_reader_route_data_wait_ms: u64,
@@ -489,6 +489,14 @@ pub struct GeneralConfig {
#[serde(default = "default_me_d2c_ack_flush_immediate")]
pub me_d2c_ack_flush_immediate: bool,
/// Additional bytes above strict per-user quota allowed in hot-path soft mode.
#[serde(default = "default_me_quota_soft_overshoot_bytes")]
pub me_quota_soft_overshoot_bytes: u64,
/// Shrink threshold for reusable ME->Client frame assembly buffer.
#[serde(default = "default_me_d2c_frame_buf_shrink_threshold_bytes")]
pub me_d2c_frame_buf_shrink_threshold_bytes: usize,
/// Copy buffer size for client->DC direction in direct relay.
#[serde(default = "default_direct_relay_copy_buf_c2s_bytes")]
pub direct_relay_copy_buf_c2s_bytes: usize,
@@ -945,6 +953,8 @@ impl Default for GeneralConfig {
me_d2c_flush_batch_max_bytes: default_me_d2c_flush_batch_max_bytes(),
me_d2c_flush_batch_max_delay_us: default_me_d2c_flush_batch_max_delay_us(),
me_d2c_ack_flush_immediate: default_me_d2c_ack_flush_immediate(),
me_quota_soft_overshoot_bytes: default_me_quota_soft_overshoot_bytes(),
me_d2c_frame_buf_shrink_threshold_bytes: default_me_d2c_frame_buf_shrink_threshold_bytes(),
direct_relay_copy_buf_c2s_bytes: default_direct_relay_copy_buf_c2s_bytes(),
direct_relay_copy_buf_s2c_bytes: default_direct_relay_copy_buf_s2c_bytes(),
me_warmup_stagger_enabled: default_true(),

View File

@@ -935,6 +935,462 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_batches_total Total DC->Client flush batches"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_batches_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_batches_total {}",
if me_allows_normal {
stats.get_me_d2c_batches_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_batch_frames_total Total DC->Client frames flushed in batches"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_frames_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_batch_frames_total {}",
if me_allows_normal {
stats.get_me_d2c_batch_frames_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_batch_bytes_total Total DC->Client bytes flushed in batches"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_bytes_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_batch_bytes_total {}",
if me_allows_normal {
stats.get_me_d2c_batch_bytes_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_flush_reason_total DC->Client flush reasons"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_flush_reason_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_flush_reason_total{{reason=\"queue_drain\"}} {}",
if me_allows_normal {
stats.get_me_d2c_flush_reason_queue_drain_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_reason_total{{reason=\"batch_frames\"}} {}",
if me_allows_normal {
stats.get_me_d2c_flush_reason_batch_frames_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_reason_total{{reason=\"batch_bytes\"}} {}",
if me_allows_normal {
stats.get_me_d2c_flush_reason_batch_bytes_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_reason_total{{reason=\"max_delay\"}} {}",
if me_allows_normal {
stats.get_me_d2c_flush_reason_max_delay_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_reason_total{{reason=\"ack_immediate\"}} {}",
if me_allows_normal {
stats.get_me_d2c_flush_reason_ack_immediate_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_reason_total{{reason=\"close\"}} {}",
if me_allows_normal {
stats.get_me_d2c_flush_reason_close_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_data_frames_total DC->Client data frames"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_data_frames_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_data_frames_total {}",
if me_allows_normal {
stats.get_me_d2c_data_frames_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_ack_frames_total DC->Client quick-ack frames"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_ack_frames_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_ack_frames_total {}",
if me_allows_normal {
stats.get_me_d2c_ack_frames_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_payload_bytes_total DC->Client payload bytes before transport framing"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_payload_bytes_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_payload_bytes_total {}",
if me_allows_normal {
stats.get_me_d2c_payload_bytes_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_write_mode_total DC->Client writer mode selection"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_write_mode_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_write_mode_total{{mode=\"coalesced\"}} {}",
if me_allows_normal {
stats.get_me_d2c_write_mode_coalesced_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_write_mode_total{{mode=\"split\"}} {}",
if me_allows_normal {
stats.get_me_d2c_write_mode_split_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_quota_reject_total DC->Client quota rejects"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_quota_reject_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_quota_reject_total{{stage=\"pre_write\"}} {}",
if me_allows_normal {
stats.get_me_d2c_quota_reject_pre_write_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_quota_reject_total{{stage=\"post_write\"}} {}",
if me_allows_normal {
stats.get_me_d2c_quota_reject_post_write_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_frame_buf_shrink_total DC->Client reusable frame buffer shrink events"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_frame_buf_shrink_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_frame_buf_shrink_total {}",
if me_allows_normal {
stats.get_me_d2c_frame_buf_shrink_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_frame_buf_shrink_bytes_total DC->Client reusable frame buffer bytes released"
);
let _ = writeln!(
out,
"# TYPE telemt_me_d2c_frame_buf_shrink_bytes_total counter"
);
let _ = writeln!(
out,
"telemt_me_d2c_frame_buf_shrink_bytes_total {}",
if me_allows_normal {
stats.get_me_d2c_frame_buf_shrink_bytes_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_batch_frames_bucket_total DC->Client batch frame count buckets"
);
let _ = writeln!(
out,
"# TYPE telemt_me_d2c_batch_frames_bucket_total counter"
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_frames_bucket_total{{bucket=\"1\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_frames_bucket_1()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_frames_bucket_total{{bucket=\"2_4\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_frames_bucket_2_4()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_frames_bucket_total{{bucket=\"5_8\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_frames_bucket_5_8()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_frames_bucket_total{{bucket=\"9_16\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_frames_bucket_9_16()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_frames_bucket_total{{bucket=\"17_32\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_frames_bucket_17_32()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_frames_bucket_total{{bucket=\"gt_32\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_frames_bucket_gt_32()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_batch_bytes_bucket_total DC->Client batch byte size buckets"
);
let _ = writeln!(
out,
"# TYPE telemt_me_d2c_batch_bytes_bucket_total counter"
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"0_1k\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_bytes_bucket_0_1k()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"1k_4k\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_bytes_bucket_1k_4k()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"4k_16k\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_bytes_bucket_4k_16k()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"16k_64k\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_bytes_bucket_16k_64k()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"64k_128k\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_bytes_bucket_64k_128k()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"gt_128k\"}} {}",
if me_allows_debug {
stats.get_me_d2c_batch_bytes_bucket_gt_128k()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_flush_duration_us_bucket_total DC->Client flush duration buckets"
);
let _ = writeln!(
out,
"# TYPE telemt_me_d2c_flush_duration_us_bucket_total counter"
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"0_50\"}} {}",
if me_allows_debug {
stats.get_me_d2c_flush_duration_us_bucket_0_50()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"51_200\"}} {}",
if me_allows_debug {
stats.get_me_d2c_flush_duration_us_bucket_51_200()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"201_1000\"}} {}",
if me_allows_debug {
stats.get_me_d2c_flush_duration_us_bucket_201_1000()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"1001_5000\"}} {}",
if me_allows_debug {
stats.get_me_d2c_flush_duration_us_bucket_1001_5000()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"5001_20000\"}} {}",
if me_allows_debug {
stats.get_me_d2c_flush_duration_us_bucket_5001_20000()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"gt_20000\"}} {}",
if me_allows_debug {
stats.get_me_d2c_flush_duration_us_bucket_gt_20000()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_batch_timeout_armed_total DC->Client max-delay timer armed events"
);
let _ = writeln!(
out,
"# TYPE telemt_me_d2c_batch_timeout_armed_total counter"
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_timeout_armed_total {}",
if me_allows_debug {
stats.get_me_d2c_batch_timeout_armed_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_d2c_batch_timeout_fired_total DC->Client max-delay timer fired events"
);
let _ = writeln!(
out,
"# TYPE telemt_me_d2c_batch_timeout_fired_total counter"
);
let _ = writeln!(
out,
"telemt_me_d2c_batch_timeout_fired_total {}",
if me_allows_debug {
stats.get_me_d2c_batch_timeout_fired_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_writer_pick_total ME writer-pick outcomes by mode and result"
@@ -2145,6 +2601,16 @@ mod tests {
stats.increment_relay_idle_hard_close_total();
stats.increment_relay_pressure_evict_total();
stats.increment_relay_protocol_desync_close_total();
stats.increment_me_d2c_batches_total();
stats.add_me_d2c_batch_frames_total(3);
stats.add_me_d2c_batch_bytes_total(2048);
stats.increment_me_d2c_flush_reason(crate::stats::MeD2cFlushReason::AckImmediate);
stats.increment_me_d2c_data_frames_total();
stats.increment_me_d2c_ack_frames_total();
stats.add_me_d2c_payload_bytes_total(1800);
stats.increment_me_d2c_write_mode(crate::stats::MeD2cWriteMode::Coalesced);
stats.increment_me_d2c_quota_reject_total(crate::stats::MeD2cQuotaRejectStage::PostWrite);
stats.observe_me_d2c_frame_buf_shrink(4096);
stats.increment_user_connects("alice");
stats.increment_user_curr_connects("alice");
stats.add_user_octets_from("alice", 1024);
@@ -2184,6 +2650,17 @@ mod tests {
assert!(output.contains("telemt_relay_idle_hard_close_total 1"));
assert!(output.contains("telemt_relay_pressure_evict_total 1"));
assert!(output.contains("telemt_relay_protocol_desync_close_total 1"));
assert!(output.contains("telemt_me_d2c_batches_total 1"));
assert!(output.contains("telemt_me_d2c_batch_frames_total 3"));
assert!(output.contains("telemt_me_d2c_batch_bytes_total 2048"));
assert!(output.contains("telemt_me_d2c_flush_reason_total{reason=\"ack_immediate\"} 1"));
assert!(output.contains("telemt_me_d2c_data_frames_total 1"));
assert!(output.contains("telemt_me_d2c_ack_frames_total 1"));
assert!(output.contains("telemt_me_d2c_payload_bytes_total 1800"));
assert!(output.contains("telemt_me_d2c_write_mode_total{mode=\"coalesced\"} 1"));
assert!(output.contains("telemt_me_d2c_quota_reject_total{stage=\"post_write\"} 1"));
assert!(output.contains("telemt_me_d2c_frame_buf_shrink_total 1"));
assert!(output.contains("telemt_me_d2c_frame_buf_shrink_bytes_total 4096"));
assert!(output.contains("telemt_user_connections_total{user=\"alice\"} 1"));
assert!(output.contains("telemt_user_connections_current{user=\"alice\"} 1"));
assert!(output.contains("telemt_user_octets_from_client{user=\"alice\"} 1024"));
@@ -2245,6 +2722,11 @@ mod tests {
assert!(output.contains("# TYPE telemt_relay_idle_hard_close_total counter"));
assert!(output.contains("# TYPE telemt_relay_pressure_evict_total counter"));
assert!(output.contains("# TYPE telemt_relay_protocol_desync_close_total counter"));
assert!(output.contains("# TYPE telemt_me_d2c_batches_total counter"));
assert!(output.contains("# TYPE telemt_me_d2c_flush_reason_total counter"));
assert!(output.contains("# TYPE telemt_me_d2c_write_mode_total counter"));
assert!(output.contains("# TYPE telemt_me_d2c_batch_frames_bucket_total counter"));
assert!(output.contains("# TYPE telemt_me_d2c_flush_duration_us_bucket_total counter"));
assert!(output.contains("# TYPE telemt_me_writer_removed_total counter"));
assert!(
output

View File

@@ -1328,3 +1328,15 @@ mod beobachten_ttl_bounds_security_tests;
#[cfg(test)]
#[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"]
mod tls_record_wrap_hardening_security_tests;
#[cfg(test)]
#[path = "tests/client_clever_advanced_tests.rs"]
mod client_clever_advanced_tests;
#[cfg(test)]
#[path = "tests/client_more_advanced_tests.rs"]
mod client_more_advanced_tests;
#[cfg(test)]
#[path = "tests/client_deep_invariants_tests.rs"]
mod client_deep_invariants_tests;

View File

@@ -605,16 +605,6 @@ where
}
};
// Replay tracking is applied only after successful authentication to avoid
// letting unauthenticated probes evict valid entries from the replay cache.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => {
@@ -670,6 +660,16 @@ where
None
};
// Replay tracking is applied only after full policy validation (including
// ALPN checks) so rejected handshakes cannot poison replay state.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let response = if let Some((cached_entry, use_full_cert_payload)) = cached {
emulator::build_emulated_server_hello(
secret,
@@ -979,6 +979,22 @@ mod saturation_poison_security_tests;
#[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"]
mod auth_probe_hardening_adversarial_tests;
#[cfg(test)]
#[path = "tests/handshake_advanced_clever_tests.rs"]
mod advanced_clever_tests;
#[cfg(test)]
#[path = "tests/handshake_more_clever_tests.rs"]
mod more_clever_tests;
#[cfg(test)]
#[path = "tests/handshake_real_bug_stress_tests.rs"]
mod real_bug_stress_tests;
#[cfg(test)]
#[path = "tests/handshake_timing_manual_bench_tests.rs"]
mod timing_manual_bench_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee.

View File

@@ -21,7 +21,7 @@ use crate::proxy::route_mode::{
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state,
cutover_stagger_delay,
};
use crate::stats::Stats;
use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
@@ -45,6 +45,8 @@ const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;
const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024;
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
@@ -214,6 +216,8 @@ struct MeD2cFlushPolicy {
max_bytes: usize,
max_delay: Duration,
ack_flush_immediate: bool,
quota_soft_overshoot_bytes: u64,
frame_buf_shrink_threshold_bytes: usize,
}
#[derive(Clone, Copy)]
@@ -284,6 +288,11 @@ impl MeD2cFlushPolicy {
.max(ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN),
max_delay: Duration::from_micros(config.general.me_d2c_flush_batch_max_delay_us),
ack_flush_immediate: config.general.me_d2c_ack_flush_immediate,
quota_soft_overshoot_bytes: config.general.me_quota_soft_overshoot_bytes,
frame_buf_shrink_threshold_bytes: config
.general
.me_d2c_frame_buf_shrink_threshold_bytes
.max(4096),
}
}
}
@@ -538,6 +547,76 @@ fn quota_would_be_exceeded_for_user(
})
}
fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
limit.saturating_add(overshoot)
}
fn quota_exceeded_for_user_soft(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
overshoot: u64,
) -> bool {
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota_soft_cap(quota, overshoot))
}
fn quota_would_be_exceeded_for_user_soft(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
bytes: u64,
overshoot: u64,
) -> bool {
quota_limit.is_some_and(|quota| {
let cap = quota_soft_cap(quota, overshoot);
let used = stats.get_user_total_octets(user);
used >= cap || bytes > cap.saturating_sub(used)
})
}
fn classify_me_d2c_flush_reason(
flush_immediately: bool,
batch_frames: usize,
max_frames: usize,
batch_bytes: usize,
max_bytes: usize,
max_delay_fired: bool,
) -> MeD2cFlushReason {
if flush_immediately {
return MeD2cFlushReason::AckImmediate;
}
if batch_frames >= max_frames {
return MeD2cFlushReason::BatchFrames;
}
if batch_bytes >= max_bytes {
return MeD2cFlushReason::BatchBytes;
}
if max_delay_fired {
return MeD2cFlushReason::MaxDelay;
}
MeD2cFlushReason::QueueDrain
}
fn observe_me_d2c_flush_event(
stats: &Stats,
reason: MeD2cFlushReason,
batch_frames: usize,
batch_bytes: usize,
flush_duration_us: Option<u64>,
) {
stats.increment_me_d2c_flush_reason(reason);
if batch_frames > 0 || batch_bytes > 0 {
stats.increment_me_d2c_batches_total();
stats.add_me_d2c_batch_frames_total(batch_frames as u64);
stats.add_me_d2c_batch_bytes_total(batch_bytes as u64);
stats.observe_me_d2c_batch_frames(batch_frames as u64);
stats.observe_me_d2c_batch_bytes(batch_bytes as u64);
}
if let Some(duration_us) = flush_duration_us {
stats.observe_me_d2c_flush_duration_us(duration_us);
}
}
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
@@ -774,6 +853,7 @@ where
let mut batch_frames = 0usize;
let mut batch_bytes = 0usize;
let mut flush_immediately;
let mut max_delay_fired = false;
let first_is_downstream_activity =
matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_));
@@ -786,6 +866,7 @@ where
stats_clone.as_ref(),
&user_clone,
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -801,7 +882,25 @@ where
flush_immediately = immediate;
}
MeWriterResponseOutcome::Close => {
let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() {
Some(Instant::now())
} else {
None
};
let _ = writer.flush().await;
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
.as_micros()
.min(u128::from(u64::MAX)) as u64
});
observe_me_d2c_flush_event(
stats_clone.as_ref(),
MeD2cFlushReason::Close,
batch_frames,
batch_bytes,
flush_duration_us,
);
return Ok(());
}
}
@@ -825,6 +924,7 @@ where
stats_clone.as_ref(),
&user_clone,
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -840,7 +940,27 @@ where
flush_immediately |= immediate;
}
MeWriterResponseOutcome::Close => {
let flush_started_at =
if stats_clone.telemetry_policy().me_level.allows_debug() {
Some(Instant::now())
} else {
None
};
let _ = writer.flush().await;
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
.as_micros()
.min(u128::from(u64::MAX))
as u64
});
observe_me_d2c_flush_event(
stats_clone.as_ref(),
MeD2cFlushReason::Close,
batch_frames,
batch_bytes,
flush_duration_us,
);
return Ok(());
}
}
@@ -851,6 +971,7 @@ where
&& batch_frames < d2c_flush_policy.max_frames
&& batch_bytes < d2c_flush_policy.max_bytes
{
stats_clone.increment_me_d2c_batch_timeout_armed_total();
match tokio::time::timeout(d2c_flush_policy.max_delay, me_rx_task.recv()).await {
Ok(Some(next)) => {
let next_is_downstream_activity =
@@ -864,6 +985,7 @@ where
stats_clone.as_ref(),
&user_clone,
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -879,7 +1001,30 @@ where
flush_immediately |= immediate;
}
MeWriterResponseOutcome::Close => {
let flush_started_at = if stats_clone
.telemetry_policy()
.me_level
.allows_debug()
{
Some(Instant::now())
} else {
None
};
let _ = writer.flush().await;
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
.as_micros()
.min(u128::from(u64::MAX))
as u64
});
observe_me_d2c_flush_event(
stats_clone.as_ref(),
MeD2cFlushReason::Close,
batch_frames,
batch_bytes,
flush_duration_us,
);
return Ok(());
}
}
@@ -903,6 +1048,7 @@ where
stats_clone.as_ref(),
&user_clone,
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -918,7 +1064,30 @@ where
flush_immediately |= immediate;
}
MeWriterResponseOutcome::Close => {
let flush_started_at = if stats_clone
.telemetry_policy()
.me_level
.allows_debug()
{
Some(Instant::now())
} else {
None
};
let _ = writer.flush().await;
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
.as_micros()
.min(u128::from(u64::MAX))
as u64
});
observe_me_d2c_flush_event(
stats_clone.as_ref(),
MeD2cFlushReason::Close,
batch_frames,
batch_bytes,
flush_duration_us,
);
return Ok(());
}
}
@@ -928,11 +1097,50 @@ where
debug!(conn_id, "ME channel closed");
return Err(ProxyError::Proxy("ME connection lost".into()));
}
Err(_) => {}
Err(_) => {
max_delay_fired = true;
stats_clone.increment_me_d2c_batch_timeout_fired_total();
}
}
}
let flush_reason = classify_me_d2c_flush_reason(
flush_immediately,
batch_frames,
d2c_flush_policy.max_frames,
batch_bytes,
d2c_flush_policy.max_bytes,
max_delay_fired,
);
let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() {
Some(Instant::now())
} else {
None
};
writer.flush().await.map_err(ProxyError::Io)?;
let flush_duration_us = flush_started_at.map(|started| {
started
.elapsed()
.as_micros()
.min(u128::from(u64::MAX)) as u64
});
observe_me_d2c_flush_event(
stats_clone.as_ref(),
flush_reason,
batch_frames,
batch_bytes,
flush_duration_us,
);
let shrink_threshold = d2c_flush_policy.frame_buf_shrink_threshold_bytes;
let shrink_trigger = shrink_threshold
.saturating_mul(ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR);
if frame_buf.capacity() > shrink_trigger {
let cap_before = frame_buf.capacity();
frame_buf.shrink_to(shrink_threshold);
let cap_after = frame_buf.capacity();
let bytes_freed = cap_before.saturating_sub(cap_after) as u64;
stats_clone.observe_me_d2c_frame_buf_shrink(bytes_freed);
}
}
_ = &mut stop_rx => {
debug!(conn_id, "ME writer stop signal");
@@ -1482,6 +1690,7 @@ async fn process_me_writer_response<W>(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
quota_soft_overshoot_bytes: u64,
bytes_me2c: &AtomicU64,
conn_id: u64,
ack_flush_immediate: bool,
@@ -1498,31 +1707,39 @@ where
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
}
let data_len = data.len() as u64;
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(user);
let _quota_guard = quota_lock.lock().await;
if quota_would_be_exceeded_for_user(stats, user, Some(limit), data_len) {
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
if quota_would_be_exceeded_for_user_soft(
stats,
user,
quota_limit,
data_len,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
let write_mode =
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await?;
stats.increment_me_d2c_write_mode(write_mode);
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data.len() as u64);
if quota_exceeded_for_user(stats, user, Some(limit)) {
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
} else {
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
if quota_exceeded_for_user_soft(
stats,
user,
quota_limit,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
Ok(MeWriterResponseOutcome::Continue {
@@ -1538,6 +1755,7 @@ where
trace!(conn_id, confirm, "ME->C quickack");
}
write_client_ack(client_writer, proto_tag, confirm).await?;
stats.increment_me_d2c_ack_frames_total();
Ok(MeWriterResponseOutcome::Continue {
frames: 1,
@@ -1588,13 +1806,13 @@ async fn write_client_payload<W>(
data: &[u8],
rng: &SecureRandom,
frame_buf: &mut Vec<u8>,
) -> Result<()>
) -> Result<MeD2cWriteMode>
where
W: AsyncWrite + Unpin + Send + 'static,
{
let quickack = (flags & RPC_FLAG_QUICKACK) != 0;
match proto_tag {
let write_mode = match proto_tag {
ProtoTag::Abridged => {
if !data.len().is_multiple_of(4) {
return Err(ProxyError::Proxy(format!(
@@ -1609,28 +1827,46 @@ where
if quickack {
first |= 0x80;
}
frame_buf.clear();
frame_buf.reserve(1 + data.len());
frame_buf.push(first);
frame_buf.extend_from_slice(data);
client_writer
.write_all(frame_buf)
.await
.map_err(ProxyError::Io)?;
let wire_len = 1usize.saturating_add(data.len());
if wire_len <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES {
frame_buf.clear();
frame_buf.reserve(wire_len);
frame_buf.push(first);
frame_buf.extend_from_slice(data);
client_writer
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Coalesced
} else {
let header = [first];
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
MeD2cWriteMode::Split
}
} else if len_words < (1 << 24) {
let mut first = 0x7fu8;
if quickack {
first |= 0x80;
}
let lw = (len_words as u32).to_le_bytes();
frame_buf.clear();
frame_buf.reserve(4 + data.len());
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
frame_buf.extend_from_slice(data);
client_writer
.write_all(frame_buf)
.await
.map_err(ProxyError::Io)?;
let wire_len = 4usize.saturating_add(data.len());
if wire_len <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES {
frame_buf.clear();
frame_buf.reserve(wire_len);
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
frame_buf.extend_from_slice(data);
client_writer
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Coalesced
} else {
let header = [first, lw[0], lw[1], lw[2]];
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
MeD2cWriteMode::Split
}
} else {
return Err(ProxyError::Proxy(format!(
"Abridged frame too large: {}",
@@ -1650,25 +1886,46 @@ where
} else {
0
};
let (len_val, total) =
compute_intermediate_secure_wire_len(data.len(), padding_len, quickack)?;
frame_buf.clear();
frame_buf.reserve(total);
frame_buf.extend_from_slice(&len_val.to_le_bytes());
frame_buf.extend_from_slice(data);
if padding_len > 0 {
let start = frame_buf.len();
frame_buf.resize(start + padding_len, 0);
rng.fill(&mut frame_buf[start..]);
if total <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES {
frame_buf.clear();
frame_buf.reserve(total);
frame_buf.extend_from_slice(&len_val.to_le_bytes());
frame_buf.extend_from_slice(data);
if padding_len > 0 {
let start = frame_buf.len();
frame_buf.resize(start + padding_len, 0);
rng.fill(&mut frame_buf[start..]);
}
client_writer
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Coalesced
} else {
let header = len_val.to_le_bytes();
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
if padding_len > 0 {
frame_buf.clear();
if frame_buf.capacity() < padding_len {
frame_buf.reserve(padding_len);
}
frame_buf.resize(padding_len, 0);
rng.fill(frame_buf.as_mut_slice());
client_writer
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
}
MeD2cWriteMode::Split
}
client_writer
.write_all(frame_buf)
.await
.map_err(ProxyError::Io)?;
}
}
};
Ok(())
Ok(write_mode)
}
async fn write_client_ack<W>(

View File

@@ -0,0 +1,409 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType, ProxyConfig};
use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE};
use crate::stats::Stats;
use crate::transport::UpstreamManager;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf, duplex};
use tokio::net::TcpListener;
#[test]
fn edge_mask_reject_delay_min_greater_than_max_does_not_panic() {
let mut config = ProxyConfig::default();
config.censorship.server_hello_delay_min_ms = 5000;
config.censorship.server_hello_delay_max_ms = 1000;
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let start = std::time::Instant::now();
maybe_apply_mask_reject_delay(&config).await;
let elapsed = start.elapsed();
assert!(elapsed >= Duration::from_millis(1000));
assert!(elapsed < Duration::from_millis(1500));
});
}
#[test]
fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() {
let mut config = ProxyConfig::default();
config.timeouts.client_handshake = u64::MAX;
config.censorship.mask = true;
let timeout = handshake_timeout_with_mask_grace(&config);
assert_eq!(timeout.as_secs(), u64::MAX);
}
#[test]
fn edge_tls_clienthello_len_in_bounds_exact_boundaries() {
assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE));
assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1));
assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE));
assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1));
}
#[test]
fn edge_synthetic_local_addr_boundaries() {
assert_eq!(synthetic_local_addr(0).port(), 0);
assert_eq!(synthetic_local_addr(80).port(), 80);
assert_eq!(synthetic_local_addr(u16::MAX).port(), u16::MAX);
}
#[test]
fn edge_beobachten_record_handshake_failure_class_stream_error_eof() {
let beobachten = BeobachtenStore::new();
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
let eof_err = ProxyError::Stream(crate::error::StreamError::UnexpectedEof);
let peer_ip: IpAddr = "198.51.100.100".parse().unwrap();
record_handshake_failure_class(&beobachten, &config, peer_ip, &eof_err);
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[expected_64_got_0]"));
}
#[tokio::test]
async fn adversarial_tls_handshake_timeout_during_masking_delay() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
cfg.censorship.mask = true;
cfg.censorship.server_hello_delay_min_ms = 3000;
cfg.censorship.server_hello_delay_max_ms = 3000;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.1:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]).await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handle)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout)));
assert_eq!(stats.get_handshake_timeouts(), 1);
}
#[tokio::test]
async fn blackhat_proxy_protocol_slowloris_timeout() {
let mut cfg = ProxyConfig::default();
cfg.server.proxy_protocol_header_timeout_ms = 200;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.2:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
true,
));
client_side.write_all(b"PROXY TCP4 192.").await.unwrap();
tokio::time::sleep(Duration::from_millis(300)).await;
let result = tokio::time::timeout(Duration::from_secs(2), handle)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
assert_eq!(stats.get_connects_bad(), 1);
}
#[test]
fn blackhat_ipv4_mapped_ipv6_proxy_source_bypass_attempt() {
let trusted = vec!["192.0.2.0/24".parse().unwrap()];
let peer_ip = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201));
assert!(!is_trusted_proxy_source(peer_ip, &trusted));
}
#[tokio::test]
async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() {
let mut cfg = ProxyConfig::default();
cfg.server.proxy_protocol_header_timeout_ms = 500;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.3:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
true,
));
client_side.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]).await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handle)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn edge_client_stream_exactly_4_bytes_eof() {
let config = Arc::new(ProxyConfig::default());
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.4:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(&[0x16, 0x03, 0x01, 0x00]).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[expected_64_got_0]"));
}
#[tokio::test]
async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() {
let config = Arc::new(ProxyConfig::default());
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.5:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap();
client_side.write_all(&vec![0x41; 99]).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn integration_non_tls_modes_disabled_immediately_masks() {
let mut cfg = ProxyConfig::default();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
cfg.censorship.mask = true;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.6:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side.write_all(b"GET / HTTP/1.1\r\n").await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
assert_eq!(stats.get_connects_bad(), 1);
}
struct YieldingReader {
data: Vec<u8>,
pos: usize,
yields_left: usize,
}
impl AsyncRead for YieldingReader {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
if this.yields_left > 0 {
this.yields_left -= 1;
cx.waker().wake_by_ref();
return Poll::Pending;
}
if this.pos >= this.data.len() {
return Poll::Ready(Ok(()));
}
buf.put_slice(&this.data[this.pos..this.pos + 1]);
this.pos += 1;
this.yields_left = 2;
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn fuzz_read_with_progress_heavy_yielding() {
let expected_data = b"HEAVY_YIELD_TEST_DATA".to_vec();
let mut reader = YieldingReader {
data: expected_data.clone(),
pos: 0,
yields_left: 2,
};
let mut buf = vec![0u8; expected_data.len()];
let read_bytes = read_with_progress(&mut reader, &mut buf).await.unwrap();
assert_eq!(read_bytes, expected_data.len());
assert_eq!(buf, expected_data);
}
#[test]
fn edge_wrap_tls_application_record_exactly_u16_max() {
let payload = vec![0u8; 65535];
let wrapped = wrap_tls_application_record(&payload);
assert_eq!(wrapped.len(), 65540);
assert_eq!(wrapped[0], TLS_RECORD_APPLICATION);
assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes());
}
#[test]
fn fuzz_wrap_tls_application_record_lengths() {
let lengths = [0, 1, 65534, 65535, 65536, 131070, 131071, 131072];
for len in lengths {
let payload = vec![0u8; len];
let wrapped = wrap_tls_application_record(&payload);
let expected_chunks = len.div_ceil(65535).max(1);
assert_eq!(wrapped.len(), len + 5 * expected_chunks);
}
}
#[tokio::test]
async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() {
let user = "stress-same-ip-user";
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 5);
let config = Arc::new(config);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 10).await;
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)), 55000);
let mut tasks = tokio::task::JoinSet::new();
let mut reservations = Vec::new();
for _ in 0..10 {
let config = config.clone();
let stats = stats.clone();
let ip_tracker = ip_tracker.clone();
tasks.spawn(async move {
RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats,
peer,
ip_tracker,
)
.await
});
}
let mut successes = 0;
let mut failures = 0;
while let Some(res) = tasks.join_next().await {
match res.unwrap() {
Ok(r) => {
successes += 1;
reservations.push(r);
}
Err(_) => failures += 1,
}
}
assert_eq!(successes, 5);
assert_eq!(failures, 5);
assert_eq!(stats.get_user_curr_connects(user), 5);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);
for reservation in reservations {
reservation.release().await;
}
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}

View File

@@ -0,0 +1,196 @@
use super::*;
use crate::config::ProxyConfig;
use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE;
use crate::stats::Stats;
use crate::transport::UpstreamManager;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncWriteExt, duplex};
#[test]
fn invariant_wrap_tls_application_record_exact_multiples() {
let chunk_size = u16::MAX as usize;
let payload = vec![0xAA; chunk_size * 2];
let wrapped = wrap_tls_application_record(&payload);
assert_eq!(wrapped.len(), 2 * (5 + chunk_size));
assert_eq!(wrapped[0], TLS_RECORD_APPLICATION);
assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes());
let second_header_idx = 5 + chunk_size;
assert_eq!(wrapped[second_header_idx], TLS_RECORD_APPLICATION);
assert_eq!(
&wrapped[second_header_idx + 3..second_header_idx + 5],
&65535u16.to_be_bytes()
);
}
#[tokio::test]
async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() {
let config = Arc::new(ProxyConfig::default());
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.20:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let claimed_len = MIN_TLS_CLIENT_HELLO_SIZE as u16;
let mut header = vec![0x16, 0x03, 0x01];
header.extend_from_slice(&claimed_len.to_be_bytes());
client_side.write_all(&header).await.unwrap();
client_side
.write_all(&vec![0x42; MIN_TLS_CLIENT_HELLO_SIZE - 1])
.await
.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap();
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn invariant_acquire_reservation_ip_limit_rollback() {
let user = "rollback-test-user";
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 10);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let peer_a = "198.51.100.21:55000".parse().unwrap();
let _res_a = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer_a,
ip_tracker.clone(),
)
.await
.unwrap();
assert_eq!(stats.get_user_curr_connects(user), 1);
let peer_b = "203.0.113.22:55000".parse().unwrap();
let res_b = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer_b,
ip_tracker.clone(),
)
.await;
assert!(matches!(
res_b,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
assert_eq!(stats.get_user_curr_connects(user), 1);
}
#[tokio::test]
async fn invariant_quota_exact_boundary_inclusive() {
let user = "quota-strict-user";
let mut config = ProxyConfig::default();
config.access.user_data_quota.insert(user.to_string(), 1000);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let peer = "198.51.100.23:55000".parse().unwrap();
stats.add_user_octets_from(user, 999);
let res1 = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await;
assert!(res1.is_ok());
res1.unwrap().release().await;
stats.add_user_octets_from(user, 1);
let res2 = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await;
assert!(matches!(res2, Err(ProxyError::DataQuotaExceeded { .. })));
}
#[tokio::test]
async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.25:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(&[0xEF, 0xEF, 0xEF]).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_err());
assert_eq!(stats.get_connects_bad(), 0);
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[expected_64_got_0]"));
}
#[tokio::test]
async fn invariant_route_mode_snapshot_picks_up_latest_mode() {
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
assert!(matches!(
route_runtime.snapshot().mode,
RelayRouteMode::Direct
));
route_runtime.set_mode(RelayRouteMode::Middle);
assert!(matches!(
route_runtime.snapshot().mode,
RelayRouteMode::Middle
));
}

View File

@@ -0,0 +1,257 @@
use super::*;
use crate::config::ProxyConfig;
use crate::stats::Stats;
use crate::transport::UpstreamManager;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
#[tokio::test]
async fn edge_mask_delay_bypassed_if_max_is_zero() {
let mut config = ProxyConfig::default();
config.censorship.server_hello_delay_min_ms = 10_000;
config.censorship.server_hello_delay_max_ms = 0;
let start = std::time::Instant::now();
maybe_apply_mask_reject_delay(&config).await;
assert!(start.elapsed() < Duration::from_millis(50));
}
#[test]
fn edge_beobachten_ttl_clamps_exactly_to_24_hours() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 100_000;
let ttl = beobachten_ttl(&config);
assert_eq!(ttl.as_secs(), 24 * 60 * 60);
}
#[test]
fn edge_wrap_tls_application_record_empty_payload() {
let wrapped = wrap_tls_application_record(&[]);
assert_eq!(wrapped.len(), 5);
assert_eq!(wrapped[0], TLS_RECORD_APPLICATION);
assert_eq!(&wrapped[3..5], &[0, 0]);
}
#[tokio::test]
async fn boundary_user_data_quota_exact_match_rejects() {
let user = "quota-boundary-user";
let mut config = ProxyConfig::default();
config.access.user_data_quota.insert(user.to_string(), 1024);
let stats = Arc::new(Stats::new());
stats.add_user_octets_from(user, 1024);
let ip_tracker = Arc::new(UserIpTracker::new());
let peer = "198.51.100.10:55000".parse().unwrap();
let result = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats,
peer,
ip_tracker,
)
.await;
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
}
#[tokio::test]
async fn boundary_user_expiration_in_past_rejects() {
let user = "expired-boundary-user";
let mut config = ProxyConfig::default();
let expired_time = chrono::Utc::now() - chrono::Duration::milliseconds(1);
config
.access
.user_expirations
.insert(user.to_string(), expired_time);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let peer = "198.51.100.11:55000".parse().unwrap();
let result = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats,
peer,
ip_tracker,
)
.await;
assert!(matches!(result, Err(ProxyError::UserExpired { .. })));
}
#[tokio::test]
async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() {
let mut cfg = ProxyConfig::default();
cfg.server.proxy_protocol_header_timeout_ms = 300;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.12:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
true,
));
client_side.write_all(&vec![b'A'; 2000]).await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.13:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap();
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn security_classic_mode_disabled_masks_valid_length_payload() {
let mut cfg = ProxyConfig::default();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
cfg.censorship.mask = true;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.15:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side.write_all(&vec![0xEF; 64]).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap();
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() {
let user = "rapid-churn-user";
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 10);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let peer = "198.51.100.16:55000".parse().unwrap();
for _ in 0..500 {
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await
.unwrap();
reservation.release().await;
}
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn quirk_read_with_progress_zero_length_buffer_returns_zero_immediately() {
let (mut server_side, _client_side) = duplex(4096);
let mut empty_buf = &mut [][..];
let result = tokio::time::timeout(
Duration::from_millis(50),
read_with_progress(&mut server_side, &mut empty_buf),
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().unwrap(), 0);
}
#[tokio::test]
async fn stress_read_with_progress_cancellation_safety() {
let (mut server_side, mut client_side) = duplex(4096);
client_side.write_all(b"12345").await.unwrap();
let mut buf = [0u8; 10];
let result = tokio::time::timeout(
Duration::from_millis(50),
read_with_progress(&mut server_side, &mut buf),
)
.await;
assert!(result.is_err());
client_side.write_all(b"67890").await.unwrap();
let mut buf2 = [0u8; 5];
server_side.read_exact(&mut buf2).await.unwrap();
assert_eq!(&buf2, b"67890");
}

View File

@@ -7,6 +7,9 @@ use crate::protocol::tls;
use crate::proxy::handshake::HandshakeSuccess;
use crate::stream::{CryptoReader, CryptoWriter};
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;
use std::net::Ipv4Addr;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream};
@@ -25,6 +28,202 @@ fn synthetic_local_addr_uses_configured_port_for_max() {
assert_eq!(addr.port(), u16::MAX);
}
#[test]
fn handshake_timeout_with_mask_grace_includes_mask_margin() {
let mut config = ProxyConfig::default();
config.timeouts.client_handshake = 2;
config.censorship.mask = false;
assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(2));
config.censorship.mask = true;
assert_eq!(
handshake_timeout_with_mask_grace(&config),
Duration::from_millis(2750),
"mask mode extends handshake timeout by 750 ms"
);
}
#[tokio::test]
async fn read_with_progress_reads_partial_buffers_before_eof() {
let data = vec![0xAA, 0xBB, 0xCC];
let mut reader = std::io::Cursor::new(data);
let mut buf = [0u8; 5];
let read = read_with_progress(&mut reader, &mut buf).await.unwrap();
assert_eq!(read, 3);
assert_eq!(&buf[..3], &[0xAA, 0xBB, 0xCC]);
}
#[test]
fn is_trusted_proxy_source_respects_cidr_list_and_empty_rejects_all() {
let peer: IpAddr = "10.10.10.10".parse().unwrap();
assert!(!is_trusted_proxy_source(peer, &[]));
let trusted = vec!["10.0.0.0/8".parse().unwrap()];
assert!(is_trusted_proxy_source(peer, &trusted));
let not_trusted = vec!["192.0.2.0/24".parse().unwrap()];
assert!(!is_trusted_proxy_source(peer, &not_trusted));
}
#[test]
fn is_trusted_proxy_source_accepts_cidr_zero_zero_as_global_cidr() {
let peer: IpAddr = "203.0.113.42".parse().unwrap();
let trust_all = vec!["0.0.0.0/0".parse().unwrap()];
assert!(is_trusted_proxy_source(peer, &trust_all));
let peer_v6: IpAddr = "2001:db8::1".parse().unwrap();
let trust_all_v6 = vec!["::/0".parse().unwrap()];
assert!(is_trusted_proxy_source(peer_v6, &trust_all_v6));
}
struct ErrorReader;
impl tokio::io::AsyncRead for ErrorReader {
fn poll_read(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
_buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "fake error")))
}
}
#[tokio::test]
async fn read_with_progress_returns_error_from_failed_reader() {
let mut reader = ErrorReader;
let mut buf = [0u8; 8];
let err = read_with_progress(&mut reader, &mut buf).await.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
}
#[test]
fn handshake_timeout_with_mask_grace_handles_maximum_values_without_overflow() {
let mut config = ProxyConfig::default();
config.timeouts.client_handshake = u64::MAX;
config.censorship.mask = true;
let timeout = handshake_timeout_with_mask_grace(&config);
assert!(timeout >= Duration::from_secs(u64::MAX));
}
#[tokio::test]
async fn read_with_progress_zero_length_buffer_returns_zero() {
let data = vec![1, 2, 3];
let mut reader = std::io::Cursor::new(data);
let mut buf = [];
let read = read_with_progress(&mut reader, &mut buf).await.unwrap();
assert_eq!(read, 0);
}
#[test]
fn handshake_timeout_without_mask_is_exact_base() {
let mut config = ProxyConfig::default();
config.timeouts.client_handshake = 7;
config.censorship.mask = false;
assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(7));
}
#[test]
fn handshake_timeout_mask_enabled_adds_750ms() {
let mut config = ProxyConfig::default();
config.timeouts.client_handshake = 3;
config.censorship.mask = true;
assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_millis(3750));
}
#[tokio::test]
async fn read_with_progress_full_then_empty_transition() {
let data = vec![0x10, 0x20];
let mut cursor = std::io::Cursor::new(data);
let mut buf = [0u8; 2];
assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 2);
assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 0);
}
#[tokio::test]
async fn read_with_progress_fragmented_io_works_over_multiple_calls() {
let mut cursor = std::io::Cursor::new(vec![1, 2, 3, 4, 5]);
let mut result = Vec::new();
for chunk_size in 1..=5 {
let mut b = vec![0u8; chunk_size];
let n = read_with_progress(&mut cursor, &mut b).await.unwrap();
result.extend_from_slice(&b[..n]);
if n == 0 { break; }
}
assert_eq!(result, vec![1,2,3,4,5]);
}
#[tokio::test]
async fn read_with_progress_stress_randomized_chunk_sizes() {
for i in 0..128 {
let mut rng = StdRng::seed_from_u64(i as u64 + 1);
let mut input: Vec<u8> = (0..(i % 41)).map(|_| rng.next_u32() as u8).collect();
let mut cursor = std::io::Cursor::new(input.clone());
let mut collected = Vec::new();
while cursor.position() < cursor.get_ref().len() as u64 {
let chunk = 1 + (rng.next_u32() as usize % 8);
let mut b = vec![0u8; chunk];
let read = read_with_progress(&mut cursor, &mut b).await.unwrap();
collected.extend_from_slice(&b[..read]);
if read == 0 { break; }
}
assert_eq!(collected, input);
}
}
#[test]
fn is_trusted_proxy_source_boundary_narrow_ipv4() {
let matching = "172.16.0.1".parse().unwrap();
let not_matching = "172.15.255.255".parse().unwrap();
let cidr = vec!["172.16.0.0/12".parse().unwrap()];
assert!(is_trusted_proxy_source(matching, &cidr));
assert!(!is_trusted_proxy_source(not_matching, &cidr));
}
#[test]
fn is_trusted_proxy_source_rejects_out_of_family_ipv6_v4_cidr() {
let peer = "2001:db8::1".parse().unwrap();
let cidr = vec!["10.0.0.0/8".parse().unwrap()];
assert!(!is_trusted_proxy_source(peer, &cidr));
}
#[test]
fn wrap_tls_application_record_reserved_chunks_look_reasonable() {
let payload = vec![0xAA; 1 + (u16::MAX as usize) + 2];
let wrapped = wrap_tls_application_record(&payload);
assert!(wrapped.len() > payload.len());
assert!(wrapped.contains(&0x17));
}
#[test]
fn wrap_tls_application_record_roundtrip_size_check() {
let payload_len = 3000;
let payload = vec![0x55; payload_len];
let wrapped = wrap_tls_application_record(&payload);
let mut idx = 0;
let mut consumed = 0;
while idx + 5 <= wrapped.len() {
assert_eq!(wrapped[idx], 0x17);
let len = u16::from_be_bytes([wrapped[idx+3], wrapped[idx+4]]) as usize;
consumed += len;
idx += 5 + len;
if idx >= wrapped.len() { break; }
}
assert_eq!(consumed, payload_len);
}
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where
R: tokio::io::AsyncRead + Unpin,

View File

@@ -0,0 +1,647 @@
use super::*;
use crate::crypto::{sha256, sha256_hmac, AesCtr};
use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
// --- Helpers ---
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg.general.modes.classic = true;
cfg.general.modes.tls = true;
cfg
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn make_valid_tls_client_hello_with_alpn(
secret: &[u8],
timestamp: u32,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &record);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
record
}
// --- Category 1: Edge Cases & Protocol Boundaries ---
#[tokio::test]
async fn tls_minimum_viable_length_boundary() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x11u8; 16];
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1;
let mut exact_min_handshake = vec![0x42u8; min_len];
exact_min_handshake[min_len - 1] = 0;
exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let digest = sha256_hmac(&secret, &exact_min_handshake);
exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
let res = handle_tls_handshake(
&exact_min_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res, HandshakeResult::Success(_)), "Exact minimum length TLS handshake must succeed");
let short_handshake = vec![0x42u8; min_len - 1];
let res_short = handle_tls_handshake(
&short_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res_short, HandshakeResult::BadClient { .. }), "Handshake 1 byte shorter than minimum must fail closed");
}
#[tokio::test]
async fn mtproto_extreme_dc_index_serialization() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "22222222222222222222222222222222";
let config = test_config_with_secret_hex(secret_hex);
for (idx, extreme_dc) in [i16::MIN, i16::MAX, -1, 0].into_iter().enumerate() {
// Keep replay state independent per case so we validate dc_idx encoding,
// not duplicate-handshake rejection behavior.
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 2)), 12345 + idx as u16);
let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, extreme_dc);
let res = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
match res {
HandshakeResult::Success((_, _, success)) => {
assert_eq!(success.dc_idx, extreme_dc, "Extreme DC index {} must serialize/deserialize perfectly", extreme_dc);
}
_ => panic!("MTProto handshake with extreme DC index {} failed", extreme_dc),
}
}
}
#[tokio::test]
async fn alpn_strict_case_and_padding_rejection() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x33u8; 16];
let mut config = test_config_with_secret_hex("33333333333333333333333333333333");
config.censorship.alpn_enforce = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap();
let bad_alpns: &[&[u8]] = &[b"H2", b"h2\0", b" http/1.1", b"http/1.1\n"];
for bad_alpn in bad_alpns {
let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[*bad_alpn]);
let res = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }), "ALPN strict enforcement must reject {:?}", bad_alpn);
}
}
#[test]
fn ipv4_mapped_ipv6_bucketing_anomaly() {
let ipv4_mapped_1 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201));
let ipv4_mapped_2 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc633, 0x6402));
let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1);
let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2);
assert_eq!(norm_1, norm_2, "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)");
assert_eq!(norm_1, IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), "The bucket must be exactly ::0");
}
// --- Category 2: Adversarial & Black Hat ---
#[tokio::test]
async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "55555555555555555555555555555555";
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.5:12345".parse().unwrap();
let valid_handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
let mut invalid_handshake = valid_handshake;
invalid_handshake[SKIP_LEN + PREKEY_LEN + IV_LEN + 1] ^= 0xFF;
let res_invalid = handle_mtproto_handshake(
&invalid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(res_invalid, HandshakeResult::BadClient { .. }));
let res_valid = handle_mtproto_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid MTProto ciphertext must not poison the replay cache");
}
#[tokio::test]
async fn tls_invalid_session_does_not_poison_replay_cache() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x66u8; 16];
let config = test_config_with_secret_hex("66666666666666666666666666666666");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.6:12345".parse().unwrap();
let valid_handshake = make_valid_tls_handshake(&secret, 0);
let mut invalid_handshake = valid_handshake.clone();
let session_idx = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1;
invalid_handshake[session_idx] ^= 0xFF;
let res_invalid = handle_tls_handshake(
&invalid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res_invalid, HandshakeResult::BadClient { .. }));
let res_valid = handle_tls_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid TLS payload must not poison the replay cache");
}
#[tokio::test]
async fn server_hello_delay_timing_neutrality_on_hmac_failure() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x77u8; 16];
let mut config = test_config_with_secret_hex("77777777777777777777777777777777");
config.censorship.server_hello_delay_min_ms = 50;
config.censorship.server_hello_delay_max_ms = 50;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.7:12345".parse().unwrap();
let mut invalid_handshake = make_valid_tls_handshake(&secret, 0);
invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF;
let start = Instant::now();
let res = handle_tls_handshake(
&invalid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
let elapsed = start.elapsed();
assert!(matches!(res, HandshakeResult::BadClient { .. }));
assert!(elapsed >= Duration::from_millis(45), "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels");
}
#[tokio::test]
async fn server_hello_delay_inversion_resilience() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x88u8; 16];
let mut config = test_config_with_secret_hex("88888888888888888888888888888888");
config.censorship.server_hello_delay_min_ms = 100;
config.censorship.server_hello_delay_max_ms = 10;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.8:12345".parse().unwrap();
let valid_handshake = make_valid_tls_handshake(&secret, 0);
let start = Instant::now();
let res = handle_tls_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
let elapsed = start.elapsed();
assert!(matches!(res, HandshakeResult::Success(_)));
assert!(elapsed >= Duration::from_millis(90), "Delay logic must gracefully handle min > max inversions via max.max(min)");
}
#[tokio::test]
async fn mixed_valid_and_invalid_user_secrets_configuration() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let _warn_guard = warned_secrets_test_lock().lock().unwrap();
clear_warned_secrets_for_testing();
let mut config = ProxyConfig::default();
config.access.ignore_time_skew = true;
for i in 0..9 {
let bad_secret = if i % 2 == 0 { "badhex!" } else { "1122" };
config.access.users.insert(format!("bad_user_{}", i), bad_secret.to_string());
}
let valid_secret_hex = "99999999999999999999999999999999";
config.access.users.insert("good_user".to_string(), valid_secret_hex.to_string());
config.general.modes.secure = true;
config.general.modes.classic = true;
config.general.modes.tls = true;
let secret = [0x99u8; 16];
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.9:12345".parse().unwrap();
let valid_handshake = make_valid_tls_handshake(&secret, 0);
let res = handle_tls_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res, HandshakeResult::Success(_)), "Proxy must gracefully skip invalid secrets and authenticate the valid one");
}
#[tokio::test]
async fn tls_emulation_fallback_when_cache_missing() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0xAAu8; 16];
let mut config = test_config_with_secret_hex("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
config.censorship.tls_emulation = true;
config.general.modes.tls = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap();
let valid_handshake = make_valid_tls_handshake(&secret, 0);
let res = handle_tls_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res, HandshakeResult::Success(_)), "TLS emulation must gracefully fall back to standard ServerHello if cache is missing");
}
#[tokio::test]
async fn classic_mode_over_tls_transport_protocol_confusion() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
let mut config = test_config_with_secret_hex(secret_hex);
config.general.modes.classic = true;
config.general.modes.tls = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.11:12345".parse().unwrap();
let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Intermediate, 1);
let res = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
true,
None,
)
.await;
assert!(matches!(res, HandshakeResult::Success(_)), "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior");
}
#[test]
fn generate_tg_nonce_never_emits_reserved_bytes() {
let client_enc_key = [0xCCu8; 32];
let client_enc_iv = 123456789u128;
let rng = SecureRandom::new();
for _ in 0..10_000 {
let (nonce, _, _, _, _) = generate_tg_nonce(
ProtoTag::Secure,
1,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
assert!(!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), "Nonce must never start with reserved bytes");
let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]];
assert!(!RESERVED_NONCE_BEGINNINGS.contains(&first_four), "Nonce must never match reserved 4-byte beginnings");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn dashmap_concurrent_saturation_stress() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let ip_a: IpAddr = "192.0.2.13".parse().unwrap();
let ip_b: IpAddr = "198.51.100.13".parse().unwrap();
let mut tasks = Vec::new();
for i in 0..100 {
let target_ip = if i % 2 == 0 { ip_a } else { ip_b };
tasks.push(tokio::spawn(async move {
for _ in 0..50 {
auth_probe_record_failure(target_ip, Instant::now());
}
}));
}
for task in tasks {
task.await.expect("Task panicked during concurrent DashMap stress");
}
assert!(auth_probe_is_throttled_for_testing(ip_a), "IP A must be throttled after concurrent stress");
assert!(auth_probe_is_throttled_for_testing(ip_b), "IP B must be throttled after concurrent stress");
}
#[test]
fn prototag_invalid_bytes_fail_closed() {
let invalid_tags: [[u8; 4]; 5] = [
[0, 0, 0, 0],
[0xFF, 0xFF, 0xFF, 0xFF],
[0xDE, 0xAD, 0xBE, 0xEF],
[0xDD, 0xDD, 0xDD, 0xDE],
[0x11, 0x22, 0x33, 0x44],
];
for tag in invalid_tags {
assert_eq!(ProtoTag::from_bytes(tag), None, "Invalid ProtoTag bytes {:?} must fail closed", tag);
}
}
#[test]
fn auth_probe_eviction_hash_collision_stress() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let state = auth_probe_state_map();
let now = Instant::now();
for i in 0..10_000u32 {
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, (i >> 8) as u8, (i & 0xFF) as u8));
auth_probe_record_failure_with_state(state, ip, now);
}
assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, "Eviction logic must successfully bound the map size under heavy insertion stress");
}
#[test]
fn encrypt_tg_nonce_with_ciphers_advances_counter_correctly() {
let client_enc_key = [0xDDu8; 32];
let client_enc_iv = 987654321u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) = generate_tg_nonce(
ProtoTag::Secure,
2,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
let (_, mut returned_encryptor, _) = encrypt_tg_nonce_with_ciphers(&nonce);
let zeros = [0u8; 64];
let returned_keystream = returned_encryptor.encrypt(&zeros);
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let mut expected_enc_key = [0u8; 32];
expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
let mut expected_enc_iv_arr = [0u8; IV_LEN];
expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]);
let expected_enc_iv = u128::from_be_bytes(expected_enc_iv_arr);
let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv);
let mut manual_input = Vec::new();
manual_input.extend_from_slice(&nonce);
manual_input.extend_from_slice(&zeros);
let manual_output = manual_encryptor.encrypt(&manual_input);
assert_eq!(
returned_keystream,
&manual_output[64..128],
"encrypt_tg_nonce_with_ciphers must correctly advance the AES-CTR counter by exactly the nonce length"
);
}

View File

@@ -0,0 +1,614 @@
use super::*;
use crate::crypto::{sha256, sha256_hmac, AesCtr};
use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Barrier;
// --- Helpers ---
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg.general.modes.classic = true;
cfg.general.modes.tls = true;
cfg
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn make_valid_tls_client_hello_with_sni_and_alpn(
secret: &[u8],
timestamp: u32,
sni_host: &str,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
let host_bytes = sni_host.as_bytes();
let mut sni_payload = Vec::new();
sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
sni_payload.push(0);
sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
sni_payload.extend_from_slice(host_bytes);
ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &record);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
record
}
// --- Category 1: Timing & Delay Invariants ---
#[tokio::test]
async fn server_hello_delay_bypassed_if_max_is_zero_despite_high_min() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x1Au8; 16];
let mut config = test_config_with_secret_hex("1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a");
config.censorship.server_hello_delay_min_ms = 5000;
config.censorship.server_hello_delay_max_ms = 0;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.101:12345".parse().unwrap();
let mut invalid_handshake = make_valid_tls_handshake(&secret, 0);
invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF;
let fut = handle_tls_handshake(
&invalid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
);
// Deterministic assertion: with max_ms == 0 there must be no sleep path,
// so the handshake should complete promptly under a generous timeout budget.
let res = tokio::time::timeout(Duration::from_millis(250), fut)
.await
.expect("max_ms=0 should bypass artificial delay and complete quickly");
assert!(matches!(res, HandshakeResult::BadClient { .. }));
}
#[test]
fn auth_probe_backoff_extreme_fail_streak_clamps_safely() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let state = auth_probe_state_map();
let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 99));
let now = Instant::now();
state.insert(
peer_ip,
AuthProbeState {
fail_streak: u32::MAX - 1,
blocked_until: now,
last_seen: now,
},
);
auth_probe_record_failure_with_state(&state, peer_ip, now);
let updated = state.get(&peer_ip).unwrap();
assert_eq!(updated.fail_streak, u32::MAX);
let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS);
assert_eq!(updated.blocked_until, expected_blocked_until, "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS");
}
#[test]
fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() {
let client_enc_key = [0x2Bu8; 32];
let client_enc_iv = 1337u128;
let rng = SecureRandom::new();
let mut nonces = HashSet::new();
let mut total_set_bits = 0usize;
let iterations = 5_000;
for _ in 0..iterations {
let (nonce, _, _, _, _) = generate_tg_nonce(
ProtoTag::Secure,
2,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
for byte in nonce.iter() {
total_set_bits += byte.count_ones() as usize;
}
assert!(nonces.insert(nonce), "generate_tg_nonce emitted a duplicate nonce! RNG is stuck.");
}
let total_bits = iterations * HANDSHAKE_LEN * 8;
let ratio = (total_set_bits as f64) / (total_bits as f64);
assert!(ratio > 0.48 && ratio < 0.52, "Nonce entropy is degraded. Set bit ratio: {}", ratio);
}
#[tokio::test]
async fn mtproto_multi_user_decryption_isolation() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let mut config = ProxyConfig::default();
config.general.modes.secure = true;
config.access.ignore_time_skew = true;
config.access.users.insert("user_a".to_string(), "11111111111111111111111111111111".to_string());
config.access.users.insert("user_b".to_string(), "22222222222222222222222222222222".to_string());
let good_secret_hex = "33333333333333333333333333333333";
config.access.users.insert("user_c".to_string(), good_secret_hex.to_string());
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap();
let valid_handshake = make_valid_mtproto_handshake(good_secret_hex, ProtoTag::Secure, 1);
let res = handle_mtproto_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
match res {
HandshakeResult::Success((_, _, success)) => {
assert_eq!(success.user, "user_c", "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user");
}
_ => panic!("Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."),
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn invalid_secret_warning_lock_contention_and_bound() {
let _guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_warned_secrets_for_testing();
let tasks = 50;
let iterations_per_task = 100;
let barrier = Arc::new(Barrier::new(tasks));
let mut handles = Vec::new();
for t in 0..tasks {
let b = barrier.clone();
handles.push(tokio::spawn(async move {
b.wait().await;
for i in 0..iterations_per_task {
let user_name = format!("contention_user_{}_{}", t, i);
warn_invalid_secret_once(&user_name, "invalid_hex", ACCESS_SECRET_BYTES, None);
}
}));
}
for handle in handles {
handle.await.unwrap();
}
let warned = INVALID_SECRET_WARNED.get().unwrap();
let guard = warned.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
assert_eq!(
guard.len(),
WARNED_SECRET_MAX_ENTRIES,
"Concurrent spam of invalid secrets must strictly bound the HashSet memory to WARNED_SECRET_MAX_ENTRIES"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn mtproto_strict_concurrent_replay_race_condition() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A";
let config = Arc::new(test_config_with_secret_hex(secret_hex));
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
let valid_handshake = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1));
let tasks = 100;
let barrier = Arc::new(Barrier::new(tasks));
let mut handles = Vec::new();
for i in 0..tasks {
let b = barrier.clone();
let cfg = config.clone();
let rc = replay_checker.clone();
let hs = valid_handshake.clone();
handles.push(tokio::spawn(async move {
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), 10000 + i as u16);
b.wait().await;
handle_mtproto_handshake(
&hs,
tokio::io::empty(),
tokio::io::sink(),
peer,
&cfg,
&rc,
false,
None,
)
.await
}));
}
let mut successes = 0;
let mut failures = 0;
for handle in handles {
match handle.await.unwrap() {
HandshakeResult::Success(_) => successes += 1,
HandshakeResult::BadClient { .. } => failures += 1,
_ => panic!("Unexpected error result in concurrent MTProto replay test"),
}
}
assert_eq!(successes, 1, "Replay cache race condition allowed multiple identical MTProto handshakes to succeed");
assert_eq!(failures, tasks - 1, "Replay cache failed to forcefully reject concurrent duplicates");
}
#[tokio::test]
async fn tls_alpn_zero_length_protocol_handled_safely() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x5Bu8; 16];
let mut config = test_config_with_secret_hex("5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b");
config.censorship.alpn_enforce = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap();
let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]);
let res = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }), "0-length ALPN must be safely rejected without panicking");
}
#[tokio::test]
async fn tls_sni_massive_hostname_does_not_panic() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x6Cu8; 16];
let config = test_config_with_secret_hex("6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap();
let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap();
let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]);
let res = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res, HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }), "Massive SNI hostname must be processed or ignored without stack overflow or panic");
}
#[tokio::test]
async fn tls_progressive_truncation_fuzzing_no_panics() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x7Du8; 16];
let config = test_config_with_secret_hex("7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap();
let valid_handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]);
let full_len = valid_handshake.len();
// Truncated corpus only: full_len is a valid baseline and should not be
// asserted as BadClient in a truncation-specific test.
for i in (0..full_len).rev() {
let truncated = &valid_handshake[..i];
let res = handle_tls_handshake(
truncated,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }), "Truncated TLS handshake at len {} must fail safely without panicking", i);
}
}
#[tokio::test]
async fn mtproto_pure_entropy_fuzzing_no_panics() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.110:12345".parse().unwrap();
let mut seeded = StdRng::seed_from_u64(0xDEADBEEFCAFE);
for _ in 0..10_000 {
let mut noise = [0u8; HANDSHAKE_LEN];
seeded.fill_bytes(&mut noise);
let res = handle_mtproto_handshake(
&noise,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }), "Pure entropy MTProto payload must fail closed and never panic");
}
}
#[test]
fn decode_user_secret_odd_length_hex_rejection() {
let _guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_warned_secrets_for_testing();
let mut config = ProxyConfig::default();
config.access.users.clear();
config.access.users.insert("odd_user".to_string(), "1234567890123456789012345678901".to_string());
let decoded = decode_user_secrets(&config, None);
assert!(decoded.is_empty(), "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping");
}
#[test]
fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let state = auth_probe_state_map();
let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 112));
let now = Instant::now();
let extreme_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + 5;
state.insert(
peer_ip,
AuthProbeState {
fail_streak: extreme_streak,
blocked_until: now + Duration::from_secs(5),
last_seen: now,
},
);
{
let mut guard = auth_probe_saturation_state_lock();
*guard = Some(AuthProbeSaturationState {
fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
blocked_until: now + Duration::from_secs(5),
last_seen: now,
});
}
let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now);
assert!(is_throttled, "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period");
}
#[test]
fn auth_probe_saturation_note_resets_retention_window() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let base_time = Instant::now();
auth_probe_note_saturation(base_time);
let later = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS - 1);
auth_probe_note_saturation(later);
let check_time = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 5);
// This call may return false if backoff has elapsed, but it must not clear
// the saturation state because `later` refreshed last_seen.
let _ = auth_probe_saturation_is_throttled_at_for_testing(check_time);
let guard = auth_probe_saturation_state_lock();
assert!(
guard.is_some(),
"Ongoing saturation notes must refresh last_seen so saturation state remains retained past the original window"
);
}
#[test]
fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() {
let mut config = ProxyConfig::default();
config.general.modes.classic = false;
config.general.modes.secure = true;
config.general.modes.tls = false;
assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false));
assert!(!mode_enabled_for_proto(&config, ProtoTag::Intermediate, false));
}
#[test]
fn mtproto_secure_tag_rejected_when_only_classic_mode_enabled() {
let mut config = ProxyConfig::default();
config.general.modes.classic = true;
config.general.modes.secure = false;
config.general.modes.tls = false;
assert!(!mode_enabled_for_proto(&config, ProtoTag::Secure, false));
}
#[test]
fn ipv6_localhost_and_unspecified_normalization() {
let localhost = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
let unspecified = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0));
let norm_local = normalize_auth_probe_ip(localhost);
let norm_unspec = normalize_auth_probe_ip(unspecified);
let expected_bucket = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0));
assert_eq!(norm_local, expected_bucket);
assert_eq!(norm_unspec, expected_bucket);
}

View File

@@ -0,0 +1,337 @@
use super::*;
use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom};
use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Barrier;
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg.general.modes.classic = true;
cfg.general.modes.tls = true;
cfg
}
fn make_valid_tls_client_hello_with_alpn(
secret: &[u8],
timestamp: u32,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &record);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
record
}
fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
#[tokio::test]
async fn tls_alpn_reject_does_not_pollute_replay_cache() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x11u8; 16];
let mut config = test_config_with_secret_hex("11111111111111111111111111111111");
config.censorship.alpn_enforce = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.201:12345".parse().unwrap();
let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
let before = replay_checker.stats();
let res = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
let after = replay_checker.stats();
assert!(matches!(res, HandshakeResult::BadClient { .. }));
assert_eq!(
before.total_additions, after.total_additions,
"ALPN policy reject must not add TLS digest into replay cache"
);
}
#[tokio::test]
async fn tls_truncated_session_id_len_fails_closed_without_panic() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("33333333333333333333333333333333");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.203:12345".parse().unwrap();
let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1;
let mut malicious = vec![0x42u8; min_len];
malicious[min_len - 1] = u8::MAX;
let res = handle_tls_handshake(
&malicious,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }));
}
#[test]
fn auth_probe_eviction_identical_timestamps_keeps_map_bounded() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let state = auth_probe_state_map();
let same = Instant::now();
for i in 0..AUTH_PROBE_TRACK_MAX_ENTRIES {
let ip = IpAddr::V4(Ipv4Addr::new(10, 1, (i >> 8) as u8, (i & 0xFF) as u8));
state.insert(
ip,
AuthProbeState {
fail_streak: 7,
blocked_until: same,
last_seen: same,
},
);
}
let new_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 21, 21));
auth_probe_record_failure_with_state(state, new_ip, same + Duration::from_millis(1));
assert_eq!(state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES);
assert!(state.contains_key(&new_ip));
}
#[test]
fn clear_auth_probe_state_recovers_from_poisoned_saturation_lock() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let saturation = auth_probe_saturation_state();
let poison_thread = std::thread::spawn(move || {
let _hold = saturation
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
panic!("intentional poison for regression coverage");
});
let _ = poison_thread.join();
clear_auth_probe_state_for_testing();
let guard = auth_probe_saturation_state()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
assert!(guard.is_none());
}
#[tokio::test]
async fn mtproto_invalid_length_secret_is_ignored_and_valid_user_still_auths() {
let _probe_guard = auth_probe_test_guard();
let _warn_guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
clear_warned_secrets_for_testing();
let mut config = ProxyConfig::default();
config.general.modes.secure = true;
config.access.ignore_time_skew = true;
config.access.users.insert(
"short_user".to_string(),
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(),
);
let valid_secret_hex = "77777777777777777777777777777777";
config
.access
.users
.insert("good_user".to_string(), valid_secret_hex.to_string());
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.207:12345".parse().unwrap();
let handshake = make_valid_mtproto_handshake(valid_secret_hex, ProtoTag::Secure, 1);
let res = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(res, HandshakeResult::Success(_)));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 80));
let now = Instant::now();
{
let mut guard = auth_probe_saturation_state()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
*guard = Some(AuthProbeSaturationState {
fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
blocked_until: now + Duration::from_secs(5),
last_seen: now,
});
}
let state = auth_probe_state_map();
state.insert(
peer_ip,
AuthProbeState {
fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1,
blocked_until: now,
last_seen: now,
},
);
let tasks = 32;
let barrier = Arc::new(Barrier::new(tasks));
let mut handles = Vec::new();
for _ in 0..tasks {
let b = barrier.clone();
handles.push(tokio::spawn(async move {
b.wait().await;
auth_probe_record_failure(peer_ip, Instant::now());
}));
}
for handle in handles {
handle.await.unwrap();
}
let final_state = state.get(&peer_ip).expect("state must exist");
assert!(
final_state.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
);
assert!(auth_probe_should_apply_preauth_throttle(peer_ip, Instant::now()));
}

View File

@@ -0,0 +1,318 @@
use super::*;
use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom};
use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
use std::net::SocketAddr;
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
salt: u8,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1).wrapping_add(salt);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn make_valid_tls_client_hello_with_sni_and_alpn(
secret: &[u8],
timestamp: u32,
sni_host: &str,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
let host_bytes = sni_host.as_bytes();
let mut sni_payload = Vec::new();
sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
sni_payload.push(0);
sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
sni_payload.extend_from_slice(host_bytes);
ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &record);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
record
}
fn median_ns(samples: &mut [u128]) -> u128 {
samples.sort_unstable();
samples[samples.len() / 2]
}
#[tokio::test]
#[ignore = "manual benchmark: timing-sensitive and host-dependent"]
async fn mtproto_user_scan_timing_manual_benchmark() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
const DECOY_USERS: usize = 8_000;
const ITERATIONS: usize = 250;
let preferred_user = "target_user";
let target_secret_hex = "dededededededededededededededede";
let mut config = ProxyConfig::default();
config.general.modes.secure = true;
config.access.ignore_time_skew = true;
for i in 0..DECOY_USERS {
config.access.users.insert(
format!("decoy_{i}"),
"00000000000000000000000000000000".to_string(),
);
}
config.access.users.insert(
preferred_user.to_string(),
target_secret_hex.to_string(),
);
let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60));
let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60));
let peer_a: SocketAddr = "192.0.2.241:12345".parse().unwrap();
let peer_b: SocketAddr = "192.0.2.242:12345".parse().unwrap();
let mut preferred_samples = Vec::with_capacity(ITERATIONS);
let mut full_scan_samples = Vec::with_capacity(ITERATIONS);
for i in 0..ITERATIONS {
let handshake = make_valid_mtproto_handshake(
target_secret_hex,
ProtoTag::Secure,
1 + i as i16,
(i % 251) as u8,
);
let started_preferred = Instant::now();
let preferred = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer_a,
&config,
&replay_checker_preferred,
false,
Some(preferred_user),
)
.await;
preferred_samples.push(started_preferred.elapsed().as_nanos());
assert!(matches!(preferred, HandshakeResult::Success(_)));
let started_scan = Instant::now();
let full_scan = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer_b,
&config,
&replay_checker_full_scan,
false,
None,
)
.await;
full_scan_samples.push(started_scan.elapsed().as_nanos());
assert!(matches!(full_scan, HandshakeResult::Success(_)));
}
let preferred_median = median_ns(&mut preferred_samples);
let full_scan_median = median_ns(&mut full_scan_samples);
let ratio = if preferred_median == 0 {
0.0
} else {
full_scan_median as f64 / preferred_median as f64
};
println!(
"manual timing benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, preferred_median_ns={preferred_median}, full_scan_median_ns={full_scan_median}, ratio={ratio:.3}"
);
assert!(
full_scan_median >= preferred_median,
"full user scan should not be faster than preferred-user path in this benchmark"
);
}
#[tokio::test]
#[ignore = "manual benchmark: timing-sensitive and host-dependent"]
async fn tls_sni_preferred_vs_no_sni_fallback_manual_benchmark() {
let _guard = auth_probe_test_guard();
const DECOY_USERS: usize = 8_000;
const ITERATIONS: usize = 250;
let preferred_user = "user-b";
let target_secret_hex = "abababababababababababababababab";
let target_secret = [0xABu8; 16];
let mut config = ProxyConfig::default();
config.general.modes.tls = true;
config.access.ignore_time_skew = true;
for i in 0..DECOY_USERS {
config.access.users.insert(
format!("decoy_{i}"),
"00000000000000000000000000000000".to_string(),
);
}
config
.access
.users
.insert(preferred_user.to_string(), target_secret_hex.to_string());
let mut sni_samples = Vec::with_capacity(ITERATIONS);
let mut no_sni_samples = Vec::with_capacity(ITERATIONS);
for i in 0..ITERATIONS {
let with_sni = make_valid_tls_client_hello_with_sni_and_alpn(
&target_secret,
i as u32,
preferred_user,
&[b"h2"],
);
let no_sni = make_valid_tls_handshake(&target_secret, (i as u32).wrapping_add(10_000));
let started_sni = Instant::now();
let sni_secrets = decode_user_secrets(&config, Some(preferred_user));
let sni_result = tls::validate_tls_handshake_with_replay_window(
&with_sni,
&sni_secrets,
config.access.ignore_time_skew,
config.access.replay_window_secs,
);
sni_samples.push(started_sni.elapsed().as_nanos());
assert!(sni_result.is_some());
let started_no_sni = Instant::now();
let no_sni_secrets = decode_user_secrets(&config, None);
let no_sni_result = tls::validate_tls_handshake_with_replay_window(
&no_sni,
&no_sni_secrets,
config.access.ignore_time_skew,
config.access.replay_window_secs,
);
no_sni_samples.push(started_no_sni.elapsed().as_nanos());
assert!(no_sni_result.is_some());
}
let sni_median = median_ns(&mut sni_samples);
let no_sni_median = median_ns(&mut no_sni_samples);
let ratio = if sni_median == 0 {
0.0
} else {
no_sni_median as f64 / sni_median as f64
};
println!(
"manual tls benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, sni_median_ns={sni_median}, no_sni_median_ns={no_sni_median}, ratio_no_sni_over_sni={ratio:.3}"
);
}

View File

@@ -1540,6 +1540,7 @@ async fn process_me_writer_response_ack_obeys_flush_policy() {
&stats,
"user",
None,
0,
&bytes_me2c,
77,
true,
@@ -1566,6 +1567,7 @@ async fn process_me_writer_response_ack_obeys_flush_policy() {
&stats,
"user",
None,
0,
&bytes_me2c,
77,
false,
@@ -1606,6 +1608,7 @@ async fn process_me_writer_response_data_updates_byte_accounting() {
&stats,
"user",
None,
0,
&bytes_me2c,
88,
false,
@@ -1652,6 +1655,7 @@ async fn process_me_writer_response_data_enforces_live_user_quota() {
&stats,
"quota-user",
Some(12),
0,
&bytes_me2c,
89,
false,
@@ -1700,6 +1704,7 @@ async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoo
&stats,
user,
Some(1),
0,
&bytes_me2c,
91,
false,
@@ -1717,6 +1722,7 @@ async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoo
&stats,
user,
Some(1),
0,
&bytes_me2c,
92,
false,
@@ -1765,6 +1771,7 @@ async fn process_me_writer_response_data_does_not_forward_partial_payload_when_r
&stats,
"partial-quota-user",
Some(4),
0,
&bytes_me2c,
90,
false,
@@ -1970,6 +1977,7 @@ async fn run_quota_race_attempt(
stats,
user,
Some(1),
0,
bytes_me2c,
conn_id,
false,

View File

@@ -26,6 +26,28 @@ enum RouteConnectionGauge {
Middle,
}
#[derive(Debug, Clone, Copy)]
pub enum MeD2cFlushReason {
QueueDrain,
BatchFrames,
BatchBytes,
MaxDelay,
AckImmediate,
Close,
}
#[derive(Debug, Clone, Copy)]
pub enum MeD2cWriteMode {
Coalesced,
Split,
}
#[derive(Debug, Clone, Copy)]
pub enum MeD2cQuotaRejectStage {
PreWrite,
PostWrite,
}
#[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"]
pub struct RouteConnectionLease {
stats: Arc<Stats>,
@@ -140,6 +162,44 @@ pub struct Stats {
me_route_drop_queue_full: AtomicU64,
me_route_drop_queue_full_base: AtomicU64,
me_route_drop_queue_full_high: AtomicU64,
me_d2c_batches_total: AtomicU64,
me_d2c_batch_frames_total: AtomicU64,
me_d2c_batch_bytes_total: AtomicU64,
me_d2c_flush_reason_queue_drain_total: AtomicU64,
me_d2c_flush_reason_batch_frames_total: AtomicU64,
me_d2c_flush_reason_batch_bytes_total: AtomicU64,
me_d2c_flush_reason_max_delay_total: AtomicU64,
me_d2c_flush_reason_ack_immediate_total: AtomicU64,
me_d2c_flush_reason_close_total: AtomicU64,
me_d2c_data_frames_total: AtomicU64,
me_d2c_ack_frames_total: AtomicU64,
me_d2c_payload_bytes_total: AtomicU64,
me_d2c_write_mode_coalesced_total: AtomicU64,
me_d2c_write_mode_split_total: AtomicU64,
me_d2c_quota_reject_pre_write_total: AtomicU64,
me_d2c_quota_reject_post_write_total: AtomicU64,
me_d2c_frame_buf_shrink_total: AtomicU64,
me_d2c_frame_buf_shrink_bytes_total: AtomicU64,
me_d2c_batch_frames_bucket_1: AtomicU64,
me_d2c_batch_frames_bucket_2_4: AtomicU64,
me_d2c_batch_frames_bucket_5_8: AtomicU64,
me_d2c_batch_frames_bucket_9_16: AtomicU64,
me_d2c_batch_frames_bucket_17_32: AtomicU64,
me_d2c_batch_frames_bucket_gt_32: AtomicU64,
me_d2c_batch_bytes_bucket_0_1k: AtomicU64,
me_d2c_batch_bytes_bucket_1k_4k: AtomicU64,
me_d2c_batch_bytes_bucket_4k_16k: AtomicU64,
me_d2c_batch_bytes_bucket_16k_64k: AtomicU64,
me_d2c_batch_bytes_bucket_64k_128k: AtomicU64,
me_d2c_batch_bytes_bucket_gt_128k: AtomicU64,
me_d2c_flush_duration_us_bucket_0_50: AtomicU64,
me_d2c_flush_duration_us_bucket_51_200: AtomicU64,
me_d2c_flush_duration_us_bucket_201_1000: AtomicU64,
me_d2c_flush_duration_us_bucket_1001_5000: AtomicU64,
me_d2c_flush_duration_us_bucket_5001_20000: AtomicU64,
me_d2c_flush_duration_us_bucket_gt_20000: AtomicU64,
me_d2c_batch_timeout_armed_total: AtomicU64,
me_d2c_batch_timeout_fired_total: AtomicU64,
me_writer_pick_sorted_rr_success_try_total: AtomicU64,
me_writer_pick_sorted_rr_success_fallback_total: AtomicU64,
me_writer_pick_sorted_rr_full_total: AtomicU64,
@@ -594,6 +654,215 @@ impl Stats {
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_d2c_batches_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_d2c_batches_total.fetch_add(1, Ordering::Relaxed);
}
}
pub fn add_me_d2c_batch_frames_total(&self, frames: u64) {
if self.telemetry_me_allows_normal() {
self.me_d2c_batch_frames_total
.fetch_add(frames, Ordering::Relaxed);
}
}
pub fn add_me_d2c_batch_bytes_total(&self, bytes: u64) {
if self.telemetry_me_allows_normal() {
self.me_d2c_batch_bytes_total
.fetch_add(bytes, Ordering::Relaxed);
}
}
pub fn increment_me_d2c_flush_reason(&self, reason: MeD2cFlushReason) {
if !self.telemetry_me_allows_normal() {
return;
}
match reason {
MeD2cFlushReason::QueueDrain => {
self.me_d2c_flush_reason_queue_drain_total
.fetch_add(1, Ordering::Relaxed);
}
MeD2cFlushReason::BatchFrames => {
self.me_d2c_flush_reason_batch_frames_total
.fetch_add(1, Ordering::Relaxed);
}
MeD2cFlushReason::BatchBytes => {
self.me_d2c_flush_reason_batch_bytes_total
.fetch_add(1, Ordering::Relaxed);
}
MeD2cFlushReason::MaxDelay => {
self.me_d2c_flush_reason_max_delay_total
.fetch_add(1, Ordering::Relaxed);
}
MeD2cFlushReason::AckImmediate => {
self.me_d2c_flush_reason_ack_immediate_total
.fetch_add(1, Ordering::Relaxed);
}
MeD2cFlushReason::Close => {
self.me_d2c_flush_reason_close_total
.fetch_add(1, Ordering::Relaxed);
}
}
}
pub fn increment_me_d2c_data_frames_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_d2c_data_frames_total.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_d2c_ack_frames_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_d2c_ack_frames_total.fetch_add(1, Ordering::Relaxed);
}
}
pub fn add_me_d2c_payload_bytes_total(&self, bytes: u64) {
if self.telemetry_me_allows_normal() {
self.me_d2c_payload_bytes_total
.fetch_add(bytes, Ordering::Relaxed);
}
}
pub fn increment_me_d2c_write_mode(&self, mode: MeD2cWriteMode) {
if !self.telemetry_me_allows_normal() {
return;
}
match mode {
MeD2cWriteMode::Coalesced => {
self.me_d2c_write_mode_coalesced_total
.fetch_add(1, Ordering::Relaxed);
}
MeD2cWriteMode::Split => {
self.me_d2c_write_mode_split_total
.fetch_add(1, Ordering::Relaxed);
}
}
}
pub fn increment_me_d2c_quota_reject_total(&self, stage: MeD2cQuotaRejectStage) {
if !self.telemetry_me_allows_normal() {
return;
}
match stage {
MeD2cQuotaRejectStage::PreWrite => {
self.me_d2c_quota_reject_pre_write_total
.fetch_add(1, Ordering::Relaxed);
}
MeD2cQuotaRejectStage::PostWrite => {
self.me_d2c_quota_reject_post_write_total
.fetch_add(1, Ordering::Relaxed);
}
}
}
pub fn observe_me_d2c_frame_buf_shrink(&self, bytes_freed: u64) {
if !self.telemetry_me_allows_normal() {
return;
}
self.me_d2c_frame_buf_shrink_total
.fetch_add(1, Ordering::Relaxed);
self.me_d2c_frame_buf_shrink_bytes_total
.fetch_add(bytes_freed, Ordering::Relaxed);
}
pub fn observe_me_d2c_batch_frames(&self, frames: u64) {
if !self.telemetry_me_allows_debug() {
return;
}
match frames {
0 => {}
1 => {
self.me_d2c_batch_frames_bucket_1
.fetch_add(1, Ordering::Relaxed);
}
2..=4 => {
self.me_d2c_batch_frames_bucket_2_4
.fetch_add(1, Ordering::Relaxed);
}
5..=8 => {
self.me_d2c_batch_frames_bucket_5_8
.fetch_add(1, Ordering::Relaxed);
}
9..=16 => {
self.me_d2c_batch_frames_bucket_9_16
.fetch_add(1, Ordering::Relaxed);
}
17..=32 => {
self.me_d2c_batch_frames_bucket_17_32
.fetch_add(1, Ordering::Relaxed);
}
_ => {
self.me_d2c_batch_frames_bucket_gt_32
.fetch_add(1, Ordering::Relaxed);
}
}
}
pub fn observe_me_d2c_batch_bytes(&self, bytes: u64) {
if !self.telemetry_me_allows_debug() {
return;
}
match bytes {
0..=1024 => {
self.me_d2c_batch_bytes_bucket_0_1k
.fetch_add(1, Ordering::Relaxed);
}
1025..=4096 => {
self.me_d2c_batch_bytes_bucket_1k_4k
.fetch_add(1, Ordering::Relaxed);
}
4097..=16_384 => {
self.me_d2c_batch_bytes_bucket_4k_16k
.fetch_add(1, Ordering::Relaxed);
}
16_385..=65_536 => {
self.me_d2c_batch_bytes_bucket_16k_64k
.fetch_add(1, Ordering::Relaxed);
}
65_537..=131_072 => {
self.me_d2c_batch_bytes_bucket_64k_128k
.fetch_add(1, Ordering::Relaxed);
}
_ => {
self.me_d2c_batch_bytes_bucket_gt_128k
.fetch_add(1, Ordering::Relaxed);
}
}
}
pub fn observe_me_d2c_flush_duration_us(&self, duration_us: u64) {
if !self.telemetry_me_allows_debug() {
return;
}
match duration_us {
0..=50 => {
self.me_d2c_flush_duration_us_bucket_0_50
.fetch_add(1, Ordering::Relaxed);
}
51..=200 => {
self.me_d2c_flush_duration_us_bucket_51_200
.fetch_add(1, Ordering::Relaxed);
}
201..=1000 => {
self.me_d2c_flush_duration_us_bucket_201_1000
.fetch_add(1, Ordering::Relaxed);
}
1001..=5000 => {
self.me_d2c_flush_duration_us_bucket_1001_5000
.fetch_add(1, Ordering::Relaxed);
}
5001..=20_000 => {
self.me_d2c_flush_duration_us_bucket_5001_20000
.fetch_add(1, Ordering::Relaxed);
}
_ => {
self.me_d2c_flush_duration_us_bucket_gt_20000
.fetch_add(1, Ordering::Relaxed);
}
}
}
pub fn increment_me_d2c_batch_timeout_armed_total(&self) {
if self.telemetry_me_allows_debug() {
self.me_d2c_batch_timeout_armed_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_d2c_batch_timeout_fired_total(&self) {
if self.telemetry_me_allows_debug() {
self.me_d2c_batch_timeout_fired_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_writer_pick_success_try_total(&self, mode: MeWriterPickMode) {
if !self.telemetry_me_allows_normal() {
return;
@@ -1229,6 +1498,142 @@ impl Stats {
pub fn get_me_route_drop_queue_full_high(&self) -> u64 {
self.me_route_drop_queue_full_high.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batches_total(&self) -> u64 {
self.me_d2c_batches_total.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_frames_total(&self) -> u64 {
self.me_d2c_batch_frames_total.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_total(&self) -> u64 {
self.me_d2c_batch_bytes_total.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_reason_queue_drain_total(&self) -> u64 {
self.me_d2c_flush_reason_queue_drain_total
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_reason_batch_frames_total(&self) -> u64 {
self.me_d2c_flush_reason_batch_frames_total
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_reason_batch_bytes_total(&self) -> u64 {
self.me_d2c_flush_reason_batch_bytes_total
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_reason_max_delay_total(&self) -> u64 {
self.me_d2c_flush_reason_max_delay_total
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_reason_ack_immediate_total(&self) -> u64 {
self.me_d2c_flush_reason_ack_immediate_total
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_reason_close_total(&self) -> u64 {
self.me_d2c_flush_reason_close_total.load(Ordering::Relaxed)
}
pub fn get_me_d2c_data_frames_total(&self) -> u64 {
self.me_d2c_data_frames_total.load(Ordering::Relaxed)
}
pub fn get_me_d2c_ack_frames_total(&self) -> u64 {
self.me_d2c_ack_frames_total.load(Ordering::Relaxed)
}
pub fn get_me_d2c_payload_bytes_total(&self) -> u64 {
self.me_d2c_payload_bytes_total.load(Ordering::Relaxed)
}
pub fn get_me_d2c_write_mode_coalesced_total(&self) -> u64 {
self.me_d2c_write_mode_coalesced_total
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_write_mode_split_total(&self) -> u64 {
self.me_d2c_write_mode_split_total.load(Ordering::Relaxed)
}
pub fn get_me_d2c_quota_reject_pre_write_total(&self) -> u64 {
self.me_d2c_quota_reject_pre_write_total
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_quota_reject_post_write_total(&self) -> u64 {
self.me_d2c_quota_reject_post_write_total
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_frame_buf_shrink_total(&self) -> u64 {
self.me_d2c_frame_buf_shrink_total.load(Ordering::Relaxed)
}
pub fn get_me_d2c_frame_buf_shrink_bytes_total(&self) -> u64 {
self.me_d2c_frame_buf_shrink_bytes_total
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_frames_bucket_1(&self) -> u64 {
self.me_d2c_batch_frames_bucket_1.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_frames_bucket_2_4(&self) -> u64 {
self.me_d2c_batch_frames_bucket_2_4.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_frames_bucket_5_8(&self) -> u64 {
self.me_d2c_batch_frames_bucket_5_8.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_frames_bucket_9_16(&self) -> u64 {
self.me_d2c_batch_frames_bucket_9_16.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_frames_bucket_17_32(&self) -> u64 {
self.me_d2c_batch_frames_bucket_17_32
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_frames_bucket_gt_32(&self) -> u64 {
self.me_d2c_batch_frames_bucket_gt_32
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_bucket_0_1k(&self) -> u64 {
self.me_d2c_batch_bytes_bucket_0_1k.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_bucket_1k_4k(&self) -> u64 {
self.me_d2c_batch_bytes_bucket_1k_4k.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_bucket_4k_16k(&self) -> u64 {
self.me_d2c_batch_bytes_bucket_4k_16k.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_bucket_16k_64k(&self) -> u64 {
self.me_d2c_batch_bytes_bucket_16k_64k
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_bucket_64k_128k(&self) -> u64 {
self.me_d2c_batch_bytes_bucket_64k_128k
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_bucket_gt_128k(&self) -> u64 {
self.me_d2c_batch_bytes_bucket_gt_128k
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_duration_us_bucket_0_50(&self) -> u64 {
self.me_d2c_flush_duration_us_bucket_0_50
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_duration_us_bucket_51_200(&self) -> u64 {
self.me_d2c_flush_duration_us_bucket_51_200
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_duration_us_bucket_201_1000(&self) -> u64 {
self.me_d2c_flush_duration_us_bucket_201_1000
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_duration_us_bucket_1001_5000(&self) -> u64 {
self.me_d2c_flush_duration_us_bucket_1001_5000
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_duration_us_bucket_5001_20000(&self) -> u64 {
self.me_d2c_flush_duration_us_bucket_5001_20000
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_flush_duration_us_bucket_gt_20000(&self) -> u64 {
self.me_d2c_flush_duration_us_bucket_gt_20000
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_timeout_armed_total(&self) -> u64 {
self.me_d2c_batch_timeout_armed_total
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_timeout_fired_total(&self) -> u64 {
self.me_d2c_batch_timeout_fired_total
.load(Ordering::Relaxed)
}
pub fn get_me_writer_pick_sorted_rr_success_try_total(&self) -> u64 {
self.me_writer_pick_sorted_rr_success_try_total
.load(Ordering::Relaxed)
@@ -1898,9 +2303,83 @@ mod tests {
stats.increment_me_crc_mismatch();
stats.increment_me_keepalive_sent();
stats.increment_me_route_drop_queue_full();
stats.increment_me_d2c_batches_total();
stats.add_me_d2c_batch_frames_total(4);
stats.add_me_d2c_batch_bytes_total(4096);
stats.increment_me_d2c_flush_reason(MeD2cFlushReason::BatchBytes);
stats.increment_me_d2c_write_mode(MeD2cWriteMode::Coalesced);
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
stats.observe_me_d2c_frame_buf_shrink(1024);
stats.observe_me_d2c_batch_frames(4);
stats.observe_me_d2c_batch_bytes(4096);
stats.observe_me_d2c_flush_duration_us(120);
stats.increment_me_d2c_batch_timeout_armed_total();
stats.increment_me_d2c_batch_timeout_fired_total();
assert_eq!(stats.get_me_crc_mismatch(), 0);
assert_eq!(stats.get_me_keepalive_sent(), 0);
assert_eq!(stats.get_me_route_drop_queue_full(), 0);
assert_eq!(stats.get_me_d2c_batches_total(), 0);
assert_eq!(stats.get_me_d2c_flush_reason_batch_bytes_total(), 0);
assert_eq!(stats.get_me_d2c_write_mode_coalesced_total(), 0);
assert_eq!(stats.get_me_d2c_quota_reject_pre_write_total(), 0);
assert_eq!(stats.get_me_d2c_frame_buf_shrink_total(), 0);
assert_eq!(stats.get_me_d2c_batch_frames_bucket_2_4(), 0);
assert_eq!(stats.get_me_d2c_batch_bytes_bucket_1k_4k(), 0);
assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_51_200(), 0);
assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 0);
assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 0);
}
#[test]
fn test_telemetry_policy_me_normal_blocks_d2c_debug_metrics() {
let stats = Stats::new();
stats.apply_telemetry_policy(TelemetryPolicy {
core_enabled: true,
user_enabled: true,
me_level: MeTelemetryLevel::Normal,
});
stats.increment_me_d2c_batches_total();
stats.add_me_d2c_batch_frames_total(2);
stats.add_me_d2c_batch_bytes_total(2048);
stats.increment_me_d2c_flush_reason(MeD2cFlushReason::QueueDrain);
stats.observe_me_d2c_batch_frames(2);
stats.observe_me_d2c_batch_bytes(2048);
stats.observe_me_d2c_flush_duration_us(100);
stats.increment_me_d2c_batch_timeout_armed_total();
stats.increment_me_d2c_batch_timeout_fired_total();
assert_eq!(stats.get_me_d2c_batches_total(), 1);
assert_eq!(stats.get_me_d2c_batch_frames_total(), 2);
assert_eq!(stats.get_me_d2c_batch_bytes_total(), 2048);
assert_eq!(stats.get_me_d2c_flush_reason_queue_drain_total(), 1);
assert_eq!(stats.get_me_d2c_batch_frames_bucket_2_4(), 0);
assert_eq!(stats.get_me_d2c_batch_bytes_bucket_1k_4k(), 0);
assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_51_200(), 0);
assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 0);
assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 0);
}
#[test]
fn test_telemetry_policy_me_debug_enables_d2c_debug_metrics() {
let stats = Stats::new();
stats.apply_telemetry_policy(TelemetryPolicy {
core_enabled: true,
user_enabled: true,
me_level: MeTelemetryLevel::Debug,
});
stats.observe_me_d2c_batch_frames(7);
stats.observe_me_d2c_batch_bytes(70_000);
stats.observe_me_d2c_flush_duration_us(1400);
stats.increment_me_d2c_batch_timeout_armed_total();
stats.increment_me_d2c_batch_timeout_fired_total();
assert_eq!(stats.get_me_d2c_batch_frames_bucket_5_8(), 1);
assert_eq!(stats.get_me_d2c_batch_bytes_bucket_64k_128k(), 1);
assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_1001_5000(), 1);
assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 1);
assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 1);
}
#[test]

View File

@@ -126,14 +126,10 @@ pub(crate) async fn reader_loop(
let data = body.slice(12..);
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
let data_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed);
let routed = if data_wait_ms == 0 {
reg.route_nowait(cid, MeResponse::Data { flags, data })
.await
} else {
reg.route_with_timeout(cid, MeResponse::Data { flags, data }, data_wait_ms)
.await
};
let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed);
let routed = reg
.route_with_timeout(cid, MeResponse::Data { flags, data }, route_wait_ms)
.await;
if !matches!(routed, RouteResult::Routed) {
match routed {
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),