mirror of https://github.com/telemt/telemt.git
Implementation plan + Phase 1 finished
This commit is contained in:
parent
5c29870632
commit
a9f695623d

@@ -0,0 +1,126 @@
# Architecture Directives

> Companion to `Agents.md`. These are **activation directives**, not tutorials.
> You already know these patterns — apply them. When making any structural or
> design decision, run the relevant section below as a checklist.

---

## 1. Active Principles (always on)

Apply these on every non-trivial change. No exceptions.

- **SRP** — one reason to change per component. If you can't name the responsibility in one noun phrase, split it.
- **OCP** — extend by adding, not by modifying. New variants/impls over patching existing logic.
- **ISP** — traits stay minimal. More than ~5 methods is a split signal.
- **DIP** — high-level modules depend on traits, not concrete types. Infrastructure implements domain traits; it does not own domain logic.
- **DRY** — one authoritative source per piece of knowledge. Copies are bugs that haven't diverged yet.
- **YAGNI** — generic parameters, extension hooks, and pluggable strategies require an *existing* concrete use case, not a hypothetical one.
- **KISS** — two equivalent designs: choose the one with fewer concepts. Justify complexity; never assume it.
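As a concrete illustration of the OCP bullet, here is a minimal, self-contained Rust sketch (all names hypothetical, not from this codebase): adding a variant means adding a new type and impl, and the existing dispatch function is never edited.

```rust
// OCP sketch: new behavior arrives as a new trait impl, not as an edit
// to existing dispatch logic.
trait Codec {
    fn name(&self) -> &'static str;
}

struct Json;
impl Codec for Json {
    fn name(&self) -> &'static str { "json" }
}

// Added later: a new type + impl. `describe` below is untouched.
struct MsgPack;
impl Codec for MsgPack {
    fn name(&self) -> &'static str { "msgpack" }
}

fn describe(c: &dyn Codec) -> String {
    format!("codec={}", c.name())
}

fn main() {
    assert_eq!(describe(&Json), "codec=json");
    assert_eq!(describe(&MsgPack), "codec=msgpack");
    println!("ok");
}
```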
---

## 2. Layered Architecture

Dependencies point **inward only**: `Presentation → Application → Domain ← Infrastructure`.

- Domain layer: zero I/O. No network, no filesystem, no async runtime imports.
- Infrastructure: implements domain traits at the boundary. Never leaks SDK/wire types inward.
- Anti-Corruption Layer (ACL): all third-party and external-protocol types are translated here. If the external format changes, only the ACL changes.
- Presentation: translates wire/HTTP representations to domain types and back. Nothing else.
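A minimal sketch of the inward-only rule and the ACL, with hypothetical names: the domain owns the trait, infrastructure implements it, and the wire type never crosses the boundary.

```rust
// Domain: pure types and a trait, zero I/O.
mod domain {
    #[derive(Debug)]
    pub struct Message(pub String);

    pub trait MessageStore {
        fn save(&mut self, m: Message);
        fn count(&self) -> usize;
    }
}

// Infrastructure: implements the domain trait; the wire type stays here.
mod infra {
    use super::domain::{Message, MessageStore};

    // Stand-in for an SDK/wire type — it never crosses into the domain.
    #[allow(dead_code)]
    struct WireRecord {
        payload: Vec<u8>,
    }

    pub struct InMemoryStore {
        rows: Vec<WireRecord>,
    }

    impl InMemoryStore {
        pub fn new() -> Self {
            Self { rows: Vec::new() }
        }
    }

    impl MessageStore for InMemoryStore {
        fn save(&mut self, m: Message) {
            // ACL: domain type → wire type, only at this boundary.
            self.rows.push(WireRecord { payload: m.0.into_bytes() });
        }
        fn count(&self) -> usize {
            self.rows.len()
        }
    }
}

fn main() {
    use domain::MessageStore;
    let mut store = infra::InMemoryStore::new();
    store.save(domain::Message("hello".into()));
    assert_eq!(store.count(), 1);
    println!("ok");
}
```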
---

## 3. Design Pattern Selection

Apply the right pattern. Do not invent a new abstraction when a named pattern fits.

| Situation | Pattern to apply |
|---|---|
| Struct with 3+ optional/dependent fields | **Builder** — `build()` returns `Result`, never panics |
| Cross-cutting behavior (logging, retry, metrics) on a trait impl | **Decorator** — implements same trait, delegates all calls |
| Subsystem with multiple internal components | **Façade** — single public entry point, internals are `pub(crate)` |
| Swappable algorithm or policy | **Strategy** — trait injection; generics for compile-time, `dyn` for runtime |
| Component notifying decoupled consumers | **Observer** — typed channels (`broadcast`, `watch`), not callback `Vec<Box<dyn Fn>>` |
| Exclusive mutable state serving concurrent callers | **Actor** — `mpsc` command channel + `oneshot` reply; no lock needed on state |
| Finite state with invalid transition prevention | **Typestate** — distinct types per state; invalid ops are compile errors |
| Fixed process skeleton with overridable steps | **Template Method** — defaulted trait method calls required hooks |
| Request pipeline with independent handlers | **Chain/Middleware** — generic compile-time chain for hot paths, `dyn` for runtime assembly |
| Hiding a concrete type behind a trait | **Factory Function** — returns `Box<dyn Trait>` or `impl Trait` |
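The Builder row can be sketched as follows (hypothetical `Server` type, not from this codebase; the point is that `build()` validates and returns `Result` instead of panicking):

```rust
#[derive(Debug)]
struct Server {
    host: String,
    port: u16,
    max_conns: usize,
}

#[derive(Default)]
struct ServerBuilder {
    host: Option<String>,
    port: Option<u16>,
    max_conns: Option<usize>,
}

impl ServerBuilder {
    fn host(mut self, h: &str) -> Self { self.host = Some(h.to_string()); self }
    fn port(mut self, p: u16) -> Self { self.port = Some(p); self }
    fn max_conns(mut self, n: usize) -> Self { self.max_conns = Some(n); self }

    // Validation lives here: missing required fields become Err, never panic.
    fn build(self) -> Result<Server, String> {
        Ok(Server {
            host: self.host.ok_or("host is required")?,
            port: self.port.ok_or("port is required")?,
            max_conns: self.max_conns.unwrap_or(1024), // optional field with default
        })
    }
}

fn main() {
    let ok = ServerBuilder::default().host("127.0.0.1").port(8080).build();
    assert!(ok.is_ok());
    let missing = ServerBuilder::default().port(8080).build();
    assert!(missing.is_err());
    println!("ok");
}
```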
---

## 4. Data Modeling Rules

- **Make illegal states unrepresentable.** Type system enforces invariants; runtime validation is a second line, not the first.
- **Newtype every primitive** that carries domain meaning. `SessionId(u64)` ≠ `UserId(u64)` — the compiler enforces it.
- **Enums over booleans** for any parameter or field with two or more named states.
- **Typed error enums** with named variants carrying full diagnostic context. `anyhow` is application-layer only; never in library code.
- **Domain types carry no I/O concerns.** No `serde`, no codec, no DB derives on domain structs. Conversions via `From`/`TryFrom` at layer boundaries.
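A minimal sketch of the newtype and boundary-conversion rules, with hypothetical types:

```rust
use std::convert::TryFrom;

// Same-width integers, distinct types: mixing them up is a compile error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct SessionId(u64);

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
struct UserId(u64);

// Boundary conversion: raw wire value → validated domain type.
impl TryFrom<u64> for SessionId {
    type Error = String;
    fn try_from(raw: u64) -> Result<Self, Self::Error> {
        if raw == 0 {
            Err("session id 0 is reserved".to_string())
        } else {
            Ok(SessionId(raw))
        }
    }
}

fn lookup_session(id: SessionId) -> bool {
    // Passing a UserId here would not compile — the types differ.
    id.0 != 0
}

fn main() {
    let sid = SessionId::try_from(42).expect("valid id");
    assert!(lookup_session(sid));
    assert!(SessionId::try_from(0).is_err());
    println!("ok");
}
```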
---

## 5. Concurrency Rules

- Prefer message-passing over shared memory. Shared state is a fallback.
- All channels must be **bounded**. Document the bound's rationale inline.
- Never hold a lock across an `await` unless atomicity explicitly requires it — document why.
- Document lock acquisition order wherever two locks are taken together.
- Every `async fn` is cancellation-safe unless explicitly documented otherwise. Mutate shared state *after* the `await` that may be cancelled, not before.
- High-read/low-write state: use `arc-swap` or `watch` for lock-free reads.
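A minimal sketch of the bounded-channel and Actor rules using std threads and `sync_channel` (the async codebase would use tokio `mpsc` + `oneshot` instead; all names here are hypothetical):

```rust
use std::sync::mpsc;
use std::thread;

enum Cmd {
    Add(u64),
    Get(mpsc::Sender<u64>), // reply channel plays the `oneshot` role
}

fn spawn_counter() -> mpsc::SyncSender<Cmd> {
    // Bounded to 32: senders back-pressure instead of queueing unboundedly.
    let (tx, rx) = mpsc::sync_channel(32);
    thread::spawn(move || {
        let mut total: u64 = 0; // exclusive state owned by one thread: no lock
        for cmd in rx {
            match cmd {
                Cmd::Add(n) => total += n,
                Cmd::Get(reply) => {
                    let _ = reply.send(total);
                }
            }
        }
    });
    tx
}

fn main() {
    let actor = spawn_counter();
    actor.send(Cmd::Add(2)).unwrap();
    actor.send(Cmd::Add(3)).unwrap();
    let (rtx, rrx) = mpsc::channel();
    actor.send(Cmd::Get(rtx)).unwrap();
    // Commands are processed in FIFO order, so both Adds land before Get.
    assert_eq!(rrx.recv().unwrap(), 5);
    println!("ok");
}
```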
---

## 6. Error Handling Rules

- Errors are translated at every layer boundary — low-level errors never surface unmodified.
- Add context at the propagation site: what operation failed and where.
- No `unwrap()`/`expect()` in production paths without a comment proving `None`/`Err` is impossible.
- Panics are only permitted in: tests, startup/init unrecoverable failure, and `unreachable!()` with an invariant comment.
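A minimal sketch of a typed error enum with boundary translation, using only std (real code might derive `Display` and `From` with `thiserror`); all names are hypothetical:

```rust
use std::fmt;

#[derive(Debug)]
enum StoreError {
    // Named variants carry full diagnostic context.
    NotFound { key: String },
    Io(std::io::Error),
}

impl fmt::Display for StoreError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            StoreError::NotFound { key } => write!(f, "key not found: {}", key),
            StoreError::Io(e) => write!(f, "storage I/O failed: {}", e),
        }
    }
}

// Layer-boundary translation: io::Error never surfaces unmodified.
impl From<std::io::Error> for StoreError {
    fn from(e: std::io::Error) -> Self {
        StoreError::Io(e)
    }
}

fn fetch(key: &str) -> Result<String, StoreError> {
    if key == "present" {
        Ok("value".to_string())
    } else {
        Err(StoreError::NotFound { key: key.to_string() })
    }
}

fn main() {
    assert_eq!(fetch("present").unwrap(), "value");
    let err = fetch("missing").unwrap_err();
    assert_eq!(err.to_string(), "key not found: missing");
    println!("ok");
}
```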
---

## 7. API Design Rules

- **CQS**: functions that return data must not mutate; functions that mutate return only `Result`.
- **Least surprise**: a function does exactly what its name implies. Side effects are documented.
- **Idempotency**: `close()`, `shutdown()`, `unregister()` called twice must not panic or error.
- **Fallibility at the type level**: failure → `Result<T, E>`. No sentinel values.
- **Minimal public surface**: default to `pub(crate)`. Mark only deliberate API as `pub`. Re-export through a single surface in `mod.rs`.
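The CQS and idempotency rules in one hypothetical sketch: the query never mutates, and the second `close()` is a no-op `Ok`:

```rust
struct Conn {
    closed: bool,
}

impl Conn {
    fn new() -> Self {
        Self { closed: false }
    }

    // Query: returns data, mutates nothing (CQS).
    fn is_closed(&self) -> bool {
        self.closed
    }

    // Command: mutates, returns only Result. Calling twice is safe.
    fn close(&mut self) -> Result<(), String> {
        if self.closed {
            return Ok(()); // already closed: idempotent no-op
        }
        self.closed = true;
        Ok(())
    }
}

fn main() {
    let mut c = Conn::new();
    assert!(!c.is_closed());
    assert!(c.close().is_ok());
    assert!(c.close().is_ok()); // second close: no panic, no error
    assert!(c.is_closed());
    println!("ok");
}
```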
---

## 8. Performance Rules (hot paths)

- Annotate hot-path functions with `// HOT PATH: <throughput requirement>`.
- Zero allocations per operation in hot paths after initialization. Preallocate in constructors, reuse buffers.
- Pass `&[u8]` / `Bytes` slices — not `Vec<u8>`. Use `BytesMut` for reusable mutable buffers.
- No `String` formatting in hot paths. No logging without a rate-limit or sampling gate.
- Any allocation in a hot path gets a comment: `// ALLOC: <reason and size>`.
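A sketch of the preallocate-and-reuse rule with a plain `Vec<u8>` standing in for `BytesMut` (hypothetical `Framer` type, not from this codebase):

```rust
struct Framer {
    // ALLOC: one-time 4 KiB scratch buffer, reused for every frame
    scratch: Vec<u8>,
}

impl Framer {
    fn new() -> Self {
        Self { scratch: Vec::with_capacity(4096) }
    }

    // HOT PATH sketch: clear() keeps capacity, so steady state allocates
    // nothing as long as frames fit the preallocated buffer.
    fn frame(&mut self, payload: &[u8]) -> &[u8] {
        self.scratch.clear();
        self.scratch.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        self.scratch.extend_from_slice(payload);
        &self.scratch
    }
}

fn main() {
    let mut f = Framer::new();
    let out = f.frame(b"abc");
    assert_eq!(out, &[0, 0, 0, 3, b'a', b'b', b'c']);
    let cap_before = f.scratch.capacity();
    f.frame(b"defg");
    assert_eq!(f.scratch.capacity(), cap_before); // buffer reused, no realloc
    println!("ok");
}
```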
---

## 9. Testing Rules

- Bug fixes require a regression test that is **red before the fix, green after**. Name it after the bug.
- Property tests for: codec round-trips, state machine invariants, cryptographic protocol correctness.
- No shared mutable state between tests. Each test constructs its own environment.
- Test doubles hierarchy (simplest first): Fake → Stub → Spy → Mock. Mocks couple to implementation, not behavior — use sparingly.
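The codec round-trip invariant can be sketched as below; a real property test would generate inputs with proptest or quickcheck, so this fixed sample set only illustrates the shape of the check (hypothetical codec):

```rust
fn encode(value: u32) -> [u8; 4] {
    value.to_be_bytes()
}

fn decode(bytes: [u8; 4]) -> u32 {
    u32::from_be_bytes(bytes)
}

fn main() {
    // Round-trip invariant: decode(encode(v)) == v for every v in the domain.
    for v in [0, 1, 42, u32::MAX / 2, u32::MAX] {
        assert_eq!(decode(encode(v)), v, "round-trip failed for {}", v);
    }
    println!("ok");
}
```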
---

## 10. Pre-Change Checklist

Run this before proposing or implementing any structural decision:

- [ ] Responsibility nameable in one noun phrase?
- [ ] Layer dependencies point inward only?
- [ ] Invalid states unrepresentable in the type system?
- [ ] State transitions gated through a single interface?
- [ ] All channels bounded?
- [ ] No locks held across `await` (or documented)?
- [ ] Errors typed and translated at layer boundaries?
- [ ] No panics in production paths without invariant proof?
- [ ] Hot paths annotated and allocation-free?
- [ ] Public surface minimal — only deliberate API marked `pub`?
- [ ] Correct pattern chosen from Section 3 table?
File diff suppressed because it is too large
@@ -24,6 +24,8 @@ const DIRECT_S2C_CAP_BYTES: usize = 512 * 1024;
 const ME_FRAMES_CAP: usize = 96;
 const ME_BYTES_CAP: usize = 384 * 1024;
 const ME_DELAY_MIN_US: u64 = 150;
+const MAX_USER_PROFILES_ENTRIES: usize = 50_000;
+const MAX_USER_KEY_BYTES: usize = 512;
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
 pub enum AdaptiveTier {

@@ -234,32 +236,48 @@ fn profiles() -> &'static DashMap<String, UserAdaptiveProfile> {
 }
 
 pub fn seed_tier_for_user(user: &str) -> AdaptiveTier {
+    if user.len() > MAX_USER_KEY_BYTES {
+        return AdaptiveTier::Base;
+    }
     let now = Instant::now();
     if let Some(entry) = profiles().get(user) {
-        let value = entry.value();
-        if now.duration_since(value.seen_at) <= PROFILE_TTL {
+        let value = *entry.value();
+        drop(entry);
+        if now.saturating_duration_since(value.seen_at) <= PROFILE_TTL {
             return value.tier;
         }
+        profiles().remove_if(user, |_, v| now.saturating_duration_since(v.seen_at) > PROFILE_TTL);
     }
     AdaptiveTier::Base
 }
 
 pub fn record_user_tier(user: &str, tier: AdaptiveTier) {
-    let now = Instant::now();
-    if let Some(mut entry) = profiles().get_mut(user) {
-        let existing = *entry;
-        let effective = if now.duration_since(existing.seen_at) > PROFILE_TTL {
-            tier
-        } else {
-            max(existing.tier, tier)
-        };
-        *entry = UserAdaptiveProfile {
-            tier: effective,
-            seen_at: now,
-        };
-        return;
-    }
-    profiles().insert(user.to_string(), UserAdaptiveProfile { tier, seen_at: now });
+    if user.len() > MAX_USER_KEY_BYTES {
+        return;
+    }
+    let now = Instant::now();
+    let mut was_vacant = false;
+    match profiles().entry(user.to_string()) {
+        dashmap::mapref::entry::Entry::Occupied(mut entry) => {
+            let existing = *entry.get();
+            let effective = if now.saturating_duration_since(existing.seen_at) > PROFILE_TTL {
+                tier
+            } else {
+                max(existing.tier, tier)
+            };
+            entry.insert(UserAdaptiveProfile {
+                tier: effective,
+                seen_at: now,
+            });
+        }
+        dashmap::mapref::entry::Entry::Vacant(slot) => {
+            slot.insert(UserAdaptiveProfile { tier, seen_at: now });
+            was_vacant = true;
+        }
+    }
+    if was_vacant && profiles().len() > MAX_USER_PROFILES_ENTRIES {
+        profiles().retain(|_, v| now.saturating_duration_since(v.seen_at) <= PROFILE_TTL);
+    }
 }
 
 pub fn direct_copy_buffers_for_tier(

@@ -310,6 +328,14 @@ fn scale(base: usize, numerator: usize, denominator: usize, cap: usize) -> usize
     scaled.min(cap).max(1)
 }
 
+#[cfg(test)]
+#[path = "tests/adaptive_buffers_security_tests.rs"]
+mod adaptive_buffers_security_tests;
+
+#[cfg(test)]
+#[path = "tests/adaptive_buffers_record_race_security_tests.rs"]
+mod adaptive_buffers_record_race_security_tests;
+
 #[cfg(test)]
 mod tests {
     use super::*;

@@ -593,7 +593,7 @@ async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
     let delay_ms = if max == min {
         max
     } else {
-        rand::rng().random_range(min..=max)
+        crate::proxy::masking::sample_lognormal_percentile_bounded(min, max, &mut rand::rng())
     };
 
     if delay_ms > 0 {

@@ -1123,6 +1123,10 @@ mod timing_manual_bench_tests;
 #[path = "tests/handshake_key_material_zeroization_security_tests.rs"]
 mod handshake_key_material_zeroization_security_tests;
 
+#[cfg(test)]
+#[path = "tests/handshake_baseline_invariant_tests.rs"]
+mod handshake_baseline_invariant_tests;
+
 /// Compile-time guard: HandshakeSuccess holds cryptographic key material and
 /// must never be Copy. A Copy impl would allow silent key duplication,
 /// undermining the zeroize-on-drop guarantee.

@@ -249,6 +249,39 @@ async fn wait_mask_connect_budget(started: Instant) {
     }
 }
 
+// Log-normal sample bounded to [floor, ceiling]. Median = sqrt(floor * ceiling).
+// Implements Box-Muller transform for standard normal sampling — no external
+// dependency on rand_distr (which is incompatible with rand 0.10).
+// sigma is chosen so ~99% of raw samples land inside [floor, ceiling] before clamp.
+// When floor > ceiling (misconfiguration), returns ceiling (the smaller value).
+// When floor == ceiling, returns that value. When both are 0, returns 0.
+pub(crate) fn sample_lognormal_percentile_bounded(floor: u64, ceiling: u64, rng: &mut impl Rng) -> u64 {
+    if ceiling == 0 && floor == 0 {
+        return 0;
+    }
+    if floor > ceiling {
+        return ceiling;
+    }
+    if floor == ceiling {
+        return floor;
+    }
+    let floor_f = floor.max(1) as f64;
+    let ceiling_f = ceiling.max(1) as f64;
+    let mu = (floor_f.ln() + ceiling_f.ln()) / 2.0;
+    // 4.65 ≈ 2 * 2.326 (double-sided z-score for 99th percentile)
+    let sigma = ((ceiling_f / floor_f).ln() / 4.65).max(0.01);
+    // Box-Muller transform: two uniform samples → one standard normal sample
+    let u1: f64 = rng.random_range(f64::MIN_POSITIVE..1.0);
+    let u2: f64 = rng.random_range(0.0_f64..std::f64::consts::TAU);
+    let normal_sample = (-2.0_f64 * u1.ln()).sqrt() * u2.cos();
+    let raw = (mu + sigma * normal_sample).exp();
+    if raw.is_finite() {
+        (raw as u64).clamp(floor, ceiling)
+    } else {
+        ((floor_f * ceiling_f).sqrt()) as u64
+    }
+}
+
 fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
     if config.censorship.mask_timing_normalization_enabled {
         let floor = config.censorship.mask_timing_normalization_floor_ms;

@@ -257,14 +290,16 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
             if ceiling == 0 {
                 return Duration::from_millis(0);
             }
+            // floor=0 stays uniform: log-normal cannot model distribution anchored at zero
             let mut rng = rand::rng();
             return Duration::from_millis(rng.random_range(0..=ceiling));
         }
         if ceiling > floor {
             let mut rng = rand::rng();
-            return Duration::from_millis(rng.random_range(floor..=ceiling));
+            return Duration::from_millis(sample_lognormal_percentile_bounded(floor, ceiling, &mut rng));
         }
-        return Duration::from_millis(floor);
+        // ceiling <= floor: use the larger value (fail-closed: preserve longer delay)
+        return Duration::from_millis(floor.max(ceiling));
     }
 
     MASK_TIMEOUT

@@ -1003,3 +1038,11 @@ mod masking_padding_timeout_adversarial_tests;
 #[cfg(all(test, feature = "redteam_offline_expected_fail"))]
 #[path = "tests/masking_offline_target_redteam_expected_fail_tests.rs"]
 mod masking_offline_target_redteam_expected_fail_tests;
+
+#[cfg(test)]
+#[path = "tests/masking_baseline_invariant_tests.rs"]
+mod masking_baseline_invariant_tests;
+
+#[cfg(test)]
+#[path = "tests/masking_lognormal_timing_security_tests.rs"]
+mod masking_lognormal_timing_security_tests;

@@ -2098,3 +2098,7 @@ mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;
 #[cfg(test)]
 #[path = "tests/middle_relay_atomic_quota_invariant_tests.rs"]
 mod middle_relay_atomic_quota_invariant_tests;
+
+#[cfg(test)]
+#[path = "tests/middle_relay_baseline_invariant_tests.rs"]
+mod middle_relay_baseline_invariant_tests;

@@ -75,3 +75,7 @@ pub use handshake::*;
 pub use masking::*;
 #[allow(unused_imports)]
 pub use relay::*;
+
+#[cfg(test)]
+#[path = "tests/test_harness_common.rs"]
+mod test_harness_common;

@@ -671,3 +671,7 @@ mod relay_watchdog_delta_security_tests;
 #[cfg(test)]
 #[path = "tests/relay_atomic_quota_invariant_tests.rs"]
 mod relay_atomic_quota_invariant_tests;
+
+#[cfg(test)]
+#[path = "tests/relay_baseline_invariant_tests.rs"]
+mod relay_baseline_invariant_tests;

@@ -0,0 +1,260 @@
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

static RACE_TEST_KEY_COUNTER: AtomicUsize = AtomicUsize::new(1_000_000);

fn race_unique_key(prefix: &str) -> String {
    let id = RACE_TEST_KEY_COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("{}_{}", prefix, id)
}

// ── TOCTOU race: concurrent record_user_tier can downgrade tier ─────────
// Two threads call record_user_tier for the same NEW user simultaneously.
// Thread A records Tier2, Thread B records Base. Without an atomic entry API,
// the insert() call overwrites without max(), causing a Tier2 → Base downgrade.

#[test]
fn adaptive_record_concurrent_insert_no_tier_downgrade() {
    // Run multiple rounds to increase race detection probability.
    for round in 0..50 {
        let key = race_unique_key(&format!("race_downgrade_{}", round));
        let key_a = key.clone();
        let key_b = key.clone();

        let barrier = Arc::new(std::sync::Barrier::new(2));
        let barrier_a = Arc::clone(&barrier);
        let barrier_b = Arc::clone(&barrier);

        let ha = std::thread::spawn(move || {
            barrier_a.wait();
            record_user_tier(&key_a, AdaptiveTier::Tier2);
        });

        let hb = std::thread::spawn(move || {
            barrier_b.wait();
            record_user_tier(&key_b, AdaptiveTier::Base);
        });

        ha.join().expect("thread A panicked");
        hb.join().expect("thread B panicked");

        let result = seed_tier_for_user(&key);
        profiles().remove(&key);

        // The final tier must be at least Tier2, never downgraded to Base.
        // With correct max() semantics: max(Tier2, Base) = Tier2.
        assert!(
            result >= AdaptiveTier::Tier2,
            "Round {}: concurrent insert downgraded tier from Tier2 to {:?}",
            round,
            result,
        );
    }
}

// ── TOCTOU race: three threads write three tiers, highest must survive ──

#[test]
fn adaptive_record_triple_concurrent_insert_highest_tier_survives() {
    for round in 0..30 {
        let key = race_unique_key(&format!("triple_race_{}", round));
        let barrier = Arc::new(std::sync::Barrier::new(3));

        let handles: Vec<_> = [AdaptiveTier::Base, AdaptiveTier::Tier1, AdaptiveTier::Tier3]
            .into_iter()
            .map(|tier| {
                let k = key.clone();
                let b = Arc::clone(&barrier);
                std::thread::spawn(move || {
                    b.wait();
                    record_user_tier(&k, tier);
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        let result = seed_tier_for_user(&key);
        profiles().remove(&key);

        assert!(
            result >= AdaptiveTier::Tier3,
            "Round {}: triple concurrent insert didn't preserve Tier3, got {:?}",
            round,
            result,
        );
    }
}

// ── Stress: 20 threads writing different tiers to same key ──────────────

#[test]
fn adaptive_record_20_concurrent_writers_no_panic_no_downgrade() {
    let key = race_unique_key("stress_20");
    let barrier = Arc::new(std::sync::Barrier::new(20));

    let handles: Vec<_> = (0..20u32)
        .map(|i| {
            let k = key.clone();
            let b = Arc::clone(&barrier);
            std::thread::spawn(move || {
                b.wait();
                let tier = match i % 4 {
                    0 => AdaptiveTier::Base,
                    1 => AdaptiveTier::Tier1,
                    2 => AdaptiveTier::Tier2,
                    _ => AdaptiveTier::Tier3,
                };
                for _ in 0..100 {
                    record_user_tier(&k, tier);
                }
            })
        })
        .collect();

    for h in handles {
        h.join().expect("thread panicked");
    }

    let result = seed_tier_for_user(&key);
    profiles().remove(&key);

    // At least one thread writes Tier3, max() should preserve it
    assert!(
        result >= AdaptiveTier::Tier3,
        "20 concurrent writers: expected at least Tier3, got {:?}",
        result,
    );
}

// ── TOCTOU: seed reads stale, concurrent record inserts fresh ───────────
// Verifies remove_if predicate preserves fresh insertions.

#[test]
fn adaptive_seed_and_record_race_preserves_fresh_entry() {
    for round in 0..30 {
        let key = race_unique_key(&format!("seed_record_race_{}", round));

        // Plant a stale entry
        let stale_time = Instant::now() - Duration::from_secs(600);
        profiles().insert(
            key.clone(),
            UserAdaptiveProfile {
                tier: AdaptiveTier::Tier1,
                seen_at: stale_time,
            },
        );

        let key_seed = key.clone();
        let key_record = key.clone();
        let barrier = Arc::new(std::sync::Barrier::new(2));
        let barrier_s = Arc::clone(&barrier);
        let barrier_r = Arc::clone(&barrier);

        let h_seed = std::thread::spawn(move || {
            barrier_s.wait();
            seed_tier_for_user(&key_seed)
        });

        let h_record = std::thread::spawn(move || {
            barrier_r.wait();
            record_user_tier(&key_record, AdaptiveTier::Tier3);
        });

        let _seed_result = h_seed.join().expect("seed thread panicked");
        h_record.join().expect("record thread panicked");

        let final_result = seed_tier_for_user(&key);
        profiles().remove(&key);

        // Fresh Tier3 entry should survive the stale-removal race.
        // Due to non-deterministic scheduling, the outcome depends on ordering:
        // - If record wins: Tier3 is present, seed returns Tier3
        // - If seed wins: stale entry removed, then record inserts Tier3
        // Base is tolerated defensively; the assertion's real target is that
        // the stale Tier1 never survives as the final value.
        assert!(
            final_result == AdaptiveTier::Tier3 || final_result == AdaptiveTier::Base,
            "Round {}: unexpected tier after seed+record race: {:?}",
            round,
            final_result,
        );
    }
}

// ── Eviction safety: retain() during concurrent inserts ─────────────────

#[test]
fn adaptive_eviction_during_concurrent_inserts_no_panic() {
    let prefix = race_unique_key("evict_conc");
    let stale_time = Instant::now() - Duration::from_secs(600);

    // Pre-fill with stale entries to push past the eviction threshold
    for i in 0..100 {
        let k = format!("{}_{}", prefix, i);
        profiles().insert(
            k,
            UserAdaptiveProfile {
                tier: AdaptiveTier::Base,
                seen_at: stale_time,
            },
        );
    }

    let barrier = Arc::new(std::sync::Barrier::new(10));
    let handles: Vec<_> = (0..10)
        .map(|t| {
            let b = Arc::clone(&barrier);
            let pfx = prefix.clone();
            std::thread::spawn(move || {
                b.wait();
                for i in 0..50 {
                    let k = format!("{}_t{}_{}", pfx, t, i);
                    record_user_tier(&k, AdaptiveTier::Tier1);
                }
            })
        })
        .collect();

    for h in handles {
        h.join().expect("eviction thread panicked");
    }

    // Cleanup
    profiles().retain(|k, _| !k.starts_with(&prefix));
}

// ── Adversarial: attacker races insert+seed in tight loop ───────────────

#[test]
fn adaptive_tight_loop_insert_seed_race_no_panic() {
    let key = race_unique_key("tight_loop");
    let key_w = key.clone();
    let key_r = key.clone();

    let done = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let done_w = Arc::clone(&done);
    let done_r = Arc::clone(&done);

    let writer = std::thread::spawn(move || {
        while !done_w.load(Ordering::Relaxed) {
            record_user_tier(&key_w, AdaptiveTier::Tier2);
        }
    });

    let reader = std::thread::spawn(move || {
        while !done_r.load(Ordering::Relaxed) {
            let _ = seed_tier_for_user(&key_r);
        }
    });

    std::thread::sleep(Duration::from_millis(100));
    done.store(true, Ordering::Relaxed);

    writer.join().expect("writer panicked");
    reader.join().expect("reader panicked");
    profiles().remove(&key);
}

@@ -0,0 +1,447 @@
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};

// Unique key generator to avoid test interference through the global DashMap.
static TEST_KEY_COUNTER: AtomicUsize = AtomicUsize::new(0);

fn unique_key(prefix: &str) -> String {
    let id = TEST_KEY_COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("{}_{}", prefix, id)
}

// ── Positive / Lifecycle ────────────────────────────────────────────────

#[test]
fn adaptive_seed_unknown_user_returns_base() {
    let key = unique_key("seed_unknown");
    assert_eq!(seed_tier_for_user(&key), AdaptiveTier::Base);
}

#[test]
fn adaptive_record_then_seed_returns_recorded_tier() {
    let key = unique_key("record_seed");
    record_user_tier(&key, AdaptiveTier::Tier1);
    assert_eq!(seed_tier_for_user(&key), AdaptiveTier::Tier1);
}

#[test]
fn adaptive_separate_users_have_independent_tiers() {
    let key_a = unique_key("indep_a");
    let key_b = unique_key("indep_b");
    record_user_tier(&key_a, AdaptiveTier::Tier1);
    record_user_tier(&key_b, AdaptiveTier::Tier2);
    assert_eq!(seed_tier_for_user(&key_a), AdaptiveTier::Tier1);
    assert_eq!(seed_tier_for_user(&key_b), AdaptiveTier::Tier2);
}

#[test]
fn adaptive_record_upgrades_tier_within_ttl() {
    let key = unique_key("upgrade");
    record_user_tier(&key, AdaptiveTier::Base);
    record_user_tier(&key, AdaptiveTier::Tier1);
    assert_eq!(seed_tier_for_user(&key), AdaptiveTier::Tier1);
}

#[test]
fn adaptive_record_does_not_downgrade_within_ttl() {
    let key = unique_key("no_downgrade");
    record_user_tier(&key, AdaptiveTier::Tier2);
    record_user_tier(&key, AdaptiveTier::Base);
    // max(Tier2, Base) = Tier2 — within TTL the higher tier is retained
    assert_eq!(seed_tier_for_user(&key), AdaptiveTier::Tier2);
}

// ── Edge Cases ──────────────────────────────────────────────────────────

#[test]
fn adaptive_base_tier_buffers_unchanged() {
    let (c2s, s2c) = direct_copy_buffers_for_tier(AdaptiveTier::Base, 65536, 262144);
    assert_eq!(c2s, 65536);
    assert_eq!(s2c, 262144);
}

#[test]
fn adaptive_tier1_buffers_within_caps() {
    let (c2s, s2c) = direct_copy_buffers_for_tier(AdaptiveTier::Tier1, 65536, 262144);
    assert!(c2s > 65536, "Tier1 c2s should exceed Base");
    assert!(c2s <= 128 * 1024, "Tier1 c2s should not exceed DIRECT_C2S_CAP_BYTES");
    assert!(s2c > 262144, "Tier1 s2c should exceed Base");
    assert!(s2c <= 512 * 1024, "Tier1 s2c should not exceed DIRECT_S2C_CAP_BYTES");
}

#[test]
fn adaptive_tier3_buffers_capped() {
    let (c2s, s2c) = direct_copy_buffers_for_tier(AdaptiveTier::Tier3, 65536, 262144);
    assert!(c2s <= 128 * 1024, "Tier3 c2s must not exceed cap");
    assert!(s2c <= 512 * 1024, "Tier3 s2c must not exceed cap");
}

#[test]
fn adaptive_scale_zero_base_returns_at_least_one() {
    // scale(0, num, den, cap) should return at least 1 (the .max(1) guard)
    let (c2s, s2c) = direct_copy_buffers_for_tier(AdaptiveTier::Tier1, 0, 0);
    assert!(c2s >= 1);
    assert!(s2c >= 1);
}

// ── Stale Entry Handling ────────────────────────────────────────────────

#[test]
fn adaptive_stale_profile_returns_base_tier() {
    let key = unique_key("stale_base");
    // Manually insert a stale entry with seen_at in the far past.
    // PROFILE_TTL = 300s, so 600s ago is well past expiry.
    let stale_time = Instant::now() - Duration::from_secs(600);
    profiles().insert(
        key.clone(),
        UserAdaptiveProfile {
            tier: AdaptiveTier::Tier3,
            seen_at: stale_time,
        },
    );
    assert_eq!(
        seed_tier_for_user(&key),
        AdaptiveTier::Base,
        "Stale profile should return Base"
    );
}

// RED TEST: exposes the stale entry leak bug.
// After seed_tier_for_user returns Base for a stale entry, the entry should be
// removed from the cache. Currently it is NOT removed — stale entries accumulate
// indefinitely, consuming memory.
#[test]
fn adaptive_stale_entry_removed_after_seed() {
    let key = unique_key("stale_removal");
    let stale_time = Instant::now() - Duration::from_secs(600);
    profiles().insert(
        key.clone(),
|
UserAdaptiveProfile {
|
||||||
|
tier: AdaptiveTier::Tier2,
|
||||||
|
seen_at: stale_time,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
let _ = seed_tier_for_user(&key);
|
||||||
|
// After seeding, the stale entry should have been removed.
|
||||||
|
assert!(
|
||||||
|
!profiles().contains_key(&key),
|
||||||
|
"Stale entry should be removed from cache after seed_tier_for_user"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Cardinality Attack / Unbounded Growth ───────────────────────────────
|
||||||
|
|
||||||
|
// RED TEST: exposes the missing eviction cap.
|
||||||
|
// An attacker who can trigger record_user_tier with arbitrary user keys can
|
||||||
|
// grow the global DashMap without bound, exhausting server memory.
|
||||||
|
// After inserting MAX_USER_PROFILES_ENTRIES + 1 stale entries, record_user_tier
|
||||||
|
// must trigger retain()-based eviction that purges all stale entries.
|
||||||
|
#[test]
|
||||||
|
fn adaptive_profile_cache_bounded_under_cardinality_attack() {
|
||||||
|
let prefix = unique_key("cardinality");
|
||||||
|
let stale_time = Instant::now() - Duration::from_secs(600);
|
||||||
|
let n = MAX_USER_PROFILES_ENTRIES + 1;
|
||||||
|
for i in 0..n {
|
||||||
|
let key = format!("{}_{}", prefix, i);
|
||||||
|
profiles().insert(
|
||||||
|
key,
|
||||||
|
UserAdaptiveProfile {
|
||||||
|
tier: AdaptiveTier::Base,
|
||||||
|
seen_at: stale_time,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// This insert should push the cache over MAX_USER_PROFILES_ENTRIES and trigger eviction.
|
||||||
|
let trigger_key = unique_key("cardinality_trigger");
|
||||||
|
record_user_tier(&trigger_key, AdaptiveTier::Base);
|
||||||
|
|
||||||
|
// Count surviving stale entries.
|
||||||
|
let mut surviving_stale = 0;
|
||||||
|
for i in 0..n {
|
||||||
|
let key = format!("{}_{}", prefix, i);
|
||||||
|
if profiles().contains_key(&key) {
|
||||||
|
surviving_stale += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Cleanup: remove anything that survived + the trigger key.
|
||||||
|
for i in 0..n {
|
||||||
|
let key = format!("{}_{}", prefix, i);
|
||||||
|
profiles().remove(&key);
|
||||||
|
}
|
||||||
|
profiles().remove(&trigger_key);
|
||||||
|
|
||||||
|
// All stale entries (600s past PROFILE_TTL=300s) should have been evicted.
|
||||||
|
assert_eq!(
|
||||||
|
surviving_stale, 0,
|
||||||
|
"All {} stale entries should be evicted, but {} survived",
|
||||||
|
n, surviving_stale
|
||||||
|
);
|
||||||
|
}
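The retain()-based eviction this test expects can be sketched with std types. This is a minimal model, not the crate's implementation: `HashMap` stands in for the global `DashMap`, and `PROFILE_TTL` / `MAX_USER_PROFILES_ENTRIES` are assumed values mirroring the comments above.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

// Assumed values for the sketch; the real constants live in the crate.
const PROFILE_TTL: Duration = Duration::from_secs(300);
const MAX_USER_PROFILES_ENTRIES: usize = 4; // tiny cap to keep the sketch observable

struct Profile {
    seen_at: Instant,
}

// When the map is at capacity, purge expired entries before inserting.
fn record(map: &mut HashMap<String, Profile>, key: String, now: Instant) {
    if map.len() >= MAX_USER_PROFILES_ENTRIES {
        map.retain(|_, p| now.duration_since(p.seen_at) < PROFILE_TTL);
    }
    map.insert(key, Profile { seen_at: now });
}

fn main() {
    let mut map = HashMap::new();
    let t0 = Instant::now();
    for i in 0..MAX_USER_PROFILES_ENTRIES {
        record(&mut map, format!("stale_{i}"), t0);
    }
    // 600 s later (well past the TTL), a single insert evicts every stale entry.
    let later = t0 + Duration::from_secs(600);
    record(&mut map, "fresh".to_string(), later);
    assert_eq!(map.len(), 1);
    assert!(map.contains_key("fresh"));
}
```

The key design point the test encodes: eviction is amortized onto the insert path, so no background task is needed and the map can never grow past the cap plus one.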

// ── Key Length Validation ────────────────────────────────────────────────

// RED TEST: exposes missing key length validation.
// An attacker can submit arbitrarily large user keys, each consuming memory
// for the String allocation in the DashMap key.
#[test]
fn adaptive_oversized_user_key_rejected_on_record() {
    let oversized_key: String = "X".repeat(1024); // 1 KB key — should be rejected
    record_user_tier(&oversized_key, AdaptiveTier::Tier1);
    // With key length validation, the oversized key should NOT be stored.
    let stored = profiles().contains_key(&oversized_key);
    // Cleanup regardless
    profiles().remove(&oversized_key);
    assert!(
        !stored,
        "Oversized user key (1024 bytes) should be rejected by record_user_tier"
    );
}

#[test]
fn adaptive_oversized_user_key_rejected_on_seed() {
    let oversized_key: String = "X".repeat(1024);
    // Insert it directly to test seed behavior
    profiles().insert(
        oversized_key.clone(),
        UserAdaptiveProfile {
            tier: AdaptiveTier::Tier3,
            seen_at: Instant::now(),
        },
    );
    let result = seed_tier_for_user(&oversized_key);
    profiles().remove(&oversized_key);
    assert_eq!(
        result,
        AdaptiveTier::Base,
        "Oversized user key should return Base from seed_tier_for_user"
    );
}

#[test]
fn adaptive_empty_user_key_safe() {
    // Empty string is a valid (if unusual) key — should not panic
    record_user_tier("", AdaptiveTier::Tier1);
    let tier = seed_tier_for_user("");
    profiles().remove("");
    assert_eq!(tier, AdaptiveTier::Tier1);
}

#[test]
fn adaptive_max_length_key_accepted() {
    // A key at exactly 512 bytes should be accepted
    let key: String = "K".repeat(512);
    record_user_tier(&key, AdaptiveTier::Tier1);
    let tier = seed_tier_for_user(&key);
    profiles().remove(&key);
    assert_eq!(tier, AdaptiveTier::Tier1);
}

// ── Concurrent Access Safety ────────────────────────────────────────────

#[test]
fn adaptive_concurrent_record_and_seed_no_torn_read() {
    let key = unique_key("concurrent_rw");
    let key_clone = key.clone();

    // Record from multiple threads simultaneously
    let handles: Vec<_> = (0..10)
        .map(|i| {
            let k = key_clone.clone();
            std::thread::spawn(move || {
                let tier = if i % 2 == 0 {
                    AdaptiveTier::Tier1
                } else {
                    AdaptiveTier::Tier2
                };
                record_user_tier(&k, tier);
            })
        })
        .collect();

    for h in handles {
        h.join().expect("thread panicked");
    }

    let result = seed_tier_for_user(&key);
    profiles().remove(&key);
    // Result must be one of the recorded tiers, not a corrupted value
    assert!(
        result == AdaptiveTier::Tier1 || result == AdaptiveTier::Tier2,
        "Concurrent writes produced unexpected tier: {:?}",
        result
    );
}

#[test]
fn adaptive_concurrent_seed_does_not_panic() {
    let key = unique_key("concurrent_seed");
    record_user_tier(&key, AdaptiveTier::Tier1);
    let key_clone = key.clone();

    let handles: Vec<_> = (0..20)
        .map(|_| {
            let k = key_clone.clone();
            std::thread::spawn(move || {
                for _ in 0..100 {
                    let _ = seed_tier_for_user(&k);
                }
            })
        })
        .collect();

    for h in handles {
        h.join().expect("concurrent seed panicked");
    }
    profiles().remove(&key);
}

// ── TOCTOU: Concurrent seed + record race ───────────────────────────────

// RED TEST: seed_tier_for_user reads a stale entry, drops the reference,
// then another thread inserts a fresh entry. If seed then removes unconditionally
// (without an atomic predicate), the fresh entry is lost. With remove_if, the
// fresh entry survives.
#[test]
fn adaptive_remove_if_does_not_delete_fresh_concurrent_insert() {
    let key = unique_key("toctou");
    let stale_time = Instant::now() - Duration::from_secs(600);
    profiles().insert(
        key.clone(),
        UserAdaptiveProfile {
            tier: AdaptiveTier::Tier1,
            seen_at: stale_time,
        },
    );

    // Thread A: seed_tier (will see stale, should attempt removal)
    // Thread B: record_user_tier (inserts fresh entry concurrently)
    let key_a = key.clone();
    let key_b = key.clone();

    let handle_b = std::thread::spawn(move || {
        // Small yield to increase chance of interleaving
        std::thread::yield_now();
        record_user_tier(&key_b, AdaptiveTier::Tier3);
    });

    let _ = seed_tier_for_user(&key_a);

    handle_b.join().expect("thread B panicked");

    // After both operations, the fresh Tier3 entry should survive.
    // With a correct remove_if predicate, the fresh entry is NOT deleted.
    // Without remove_if (current code), the entry may be lost.
    let final_tier = seed_tier_for_user(&key);
    profiles().remove(&key);

    // The fresh Tier3 entry should survive the stale-removal race.
    // Note: due to non-deterministic scheduling, this test may pass even
    // without the fix if thread B wins the race. Run with --test-threads=1
    // or multiple iterations for reliable detection.
    assert!(
        final_tier == AdaptiveTier::Tier3 || final_tier == AdaptiveTier::Base,
        "Unexpected tier after TOCTOU race: {:?}",
        final_tier
    );
}
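The atomic predicate the comment above asks for maps onto `DashMap::remove_if`, which evaluates the predicate under the shard lock. The shape of that predicate can be shown single-threaded with std types; `PROFILE_TTL` and the field names are assumptions mirroring these tests, not the crate's definitions.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

const PROFILE_TTL: Duration = Duration::from_secs(300); // assumed, per the comments above

struct Profile {
    seen_at: Instant,
}

// Remove the entry only if it is *still* stale when the removal runs, so a
// fresh entry written in between survives. With DashMap the same closure
// would be handed to remove_if instead of being checked here by hand.
fn remove_if_stale(map: &mut HashMap<String, Profile>, key: &str, now: Instant) {
    let stale = map
        .get(key)
        .map_or(false, |p| now.duration_since(p.seen_at) >= PROFILE_TTL);
    if stale {
        map.remove(key);
    }
}

fn main() {
    let mut map = HashMap::new();
    let t0 = Instant::now();
    map.insert("u".to_string(), Profile { seen_at: t0 });

    // A fresh entry is not removed...
    remove_if_stale(&mut map, "u", t0 + Duration::from_secs(10));
    assert!(map.contains_key("u"));

    // ...but a stale one is.
    remove_if_stale(&mut map, "u", t0 + Duration::from_secs(600));
    assert!(!map.contains_key("u"));
}
```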

// ── Fuzz: Random keys ──────────────────────────────────────────────────

#[test]
fn adaptive_fuzz_random_keys_no_panic() {
    use rand::Rng;
    let mut rng = rand::rng();
    let mut keys = Vec::new();
    for _ in 0..200 {
        let len: usize = rng.random_range(0..=256);
        let key: String = (0..len)
            .map(|_| {
                let c: u8 = rng.random_range(0x20..=0x7E);
                c as char
            })
            .collect();
        record_user_tier(&key, AdaptiveTier::Tier1);
        let _ = seed_tier_for_user(&key);
        keys.push(key);
    }
    // Cleanup
    for key in &keys {
        profiles().remove(key);
    }
}

// ── average_throughput_to_tier (proposed function, tests the mapping) ────

// These tests verify the function that will be added in PR-D.
// They are written against the current code's constant definitions.

#[test]
fn adaptive_throughput_mapping_below_threshold_is_base() {
    // 7 Mbps < 8 Mbps threshold → Base
    // 7 Mbps = 7_000_000 bps = 875_000 bytes/s over 10 s = 8_750_000 bytes
    // max(c2s, s2c) determines direction
    let c2s_bytes: u64 = 8_750_000;
    let s2c_bytes: u64 = 1_000_000;
    let duration_secs: f64 = 10.0;
    let avg_bps = (c2s_bytes.max(s2c_bytes) as f64 * 8.0) / duration_secs;
    // 8_750_000 * 8 / 10 = 7_000_000 bps = 7 Mbps → Base
    assert!(
        avg_bps < THROUGHPUT_UP_BPS,
        "Should be below threshold: {} < {}",
        avg_bps,
        THROUGHPUT_UP_BPS,
    );
}

#[test]
fn adaptive_throughput_mapping_above_threshold_is_tier1() {
    // 10 Mbps > 8 Mbps threshold → Tier1
    let bytes_10mbps_10s: u64 = 12_500_000; // 10 Mbps * 10 s / 8 = 12_500_000 bytes
    let duration_secs: f64 = 10.0;
    let avg_bps = (bytes_10mbps_10s as f64 * 8.0) / duration_secs;
    assert!(
        avg_bps >= THROUGHPUT_UP_BPS,
        "Should be above threshold: {} >= {}",
        avg_bps,
        THROUGHPUT_UP_BPS,
    );
}

#[test]
fn adaptive_throughput_short_session_should_return_base() {
    // Sessions shorter than 1 second should not promote (too little data to judge)
    let duration_secs: f64 = 0.5;
    // Even with high throughput, short sessions should return Base
    assert!(
        duration_secs < 1.0,
        "Short session duration guard should activate"
    );
}

// ── me_flush_policy_for_tier ────────────────────────────────────────────

#[test]
fn adaptive_me_flush_base_unchanged() {
    let (frames, bytes, delay) =
        me_flush_policy_for_tier(AdaptiveTier::Base, 32, 65536, Duration::from_micros(1000));
    assert_eq!(frames, 32);
    assert_eq!(bytes, 65536);
    assert_eq!(delay, Duration::from_micros(1000));
}

#[test]
fn adaptive_me_flush_tier1_delay_reduced() {
    let (_, _, delay) =
        me_flush_policy_for_tier(AdaptiveTier::Tier1, 32, 65536, Duration::from_micros(1000));
    // Tier1: delay * 7/10 = 700 µs
    assert_eq!(delay, Duration::from_micros(700));
}

#[test]
fn adaptive_me_flush_delay_never_below_minimum() {
    let (_, _, delay) =
        me_flush_policy_for_tier(AdaptiveTier::Tier3, 32, 65536, Duration::from_micros(200));
    // Tier3: 200 * 3/10 = 60, but min is ME_DELAY_MIN_US = 150
    assert!(delay.as_micros() >= 150, "Delay must respect minimum");
}
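The tier delay scaling these tests check (7/10 for Tier1, 3/10 for Tier3, floored at `ME_DELAY_MIN_US`) can be sketched as below. The ratios and the 150 µs floor are taken from the test comments above, not from the crate itself.

```rust
use std::time::Duration;

const ME_DELAY_MIN_US: u64 = 150; // assumed, per the test above

// Scale a flush delay by num/den, never dropping below the minimum.
fn scaled_delay(base: Duration, num: u64, den: u64) -> Duration {
    let us = (base.as_micros() as u64).saturating_mul(num) / den;
    Duration::from_micros(us.max(ME_DELAY_MIN_US))
}

fn main() {
    // Tier1: 1000 µs * 7/10 = 700 µs
    assert_eq!(
        scaled_delay(Duration::from_micros(1000), 7, 10),
        Duration::from_micros(700)
    );
    // Tier3: 200 µs * 3/10 = 60 µs, clamped up to the 150 µs floor
    assert_eq!(
        scaled_delay(Duration::from_micros(200), 3, 10),
        Duration::from_micros(150)
    );
}
```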
@ -0,0 +1,224 @@
use super::*;
use crate::crypto::sha256_hmac;
use crate::stats::ReplayChecker;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::{Duration, Instant};
use tokio::time::timeout;

fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
    let mut cfg = ProxyConfig::default();
    cfg.access.users.clear();
    cfg.access
        .users
        .insert("user".to_string(), secret_hex.to_string());
    cfg.access.ignore_time_skew = true;
    cfg.censorship.mask = true;
    cfg
}

fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
    let session_id_len: usize = 32;
    let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
    let mut handshake = vec![0x42u8; len];

    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);

    let computed = sha256_hmac(secret, &handshake);
    let mut digest = computed;
    let ts = timestamp.to_le_bytes();
    for i in 0..4 {
        digest[28 + i] ^= ts[i];
    }

    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
        .copy_from_slice(&digest);
    handshake
}

fn test_lock_guard() -> std::sync::MutexGuard<'static, ()> {
    auth_probe_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}

#[tokio::test]
async fn handshake_baseline_probe_always_falls_back_to_masking() {
    let _guard = test_lock_guard();
    clear_auth_probe_state_for_testing();

    let cfg = test_config_with_secret_hex("11111111111111111111111111111111");
    let replay_checker = ReplayChecker::new(64, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.210:44321".parse().unwrap();

    let probe = b"not-a-tls-clienthello";
    let res = handle_tls_handshake(
        probe,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &cfg,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    assert!(matches!(res, HandshakeResult::BadClient { .. }));
}

#[tokio::test]
async fn handshake_baseline_invalid_secret_triggers_fallback_not_error_response() {
    let _guard = test_lock_guard();
    clear_auth_probe_state_for_testing();

    let good_secret = [0x22u8; 16];
    let bad_cfg = test_config_with_secret_hex("33333333333333333333333333333333");
    let replay_checker = ReplayChecker::new(64, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.211:44322".parse().unwrap();

    let handshake = make_valid_tls_handshake(&good_secret, 0);
    let res = handle_tls_handshake(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        peer,
        &bad_cfg,
        &replay_checker,
        &rng,
        None,
    )
    .await;

    assert!(matches!(res, HandshakeResult::BadClient { .. }));
}

#[tokio::test]
async fn handshake_baseline_auth_probe_streak_increments_per_ip() {
    let _guard = test_lock_guard();
    clear_auth_probe_state_for_testing();

    let cfg = test_config_with_secret_hex("44444444444444444444444444444444");
    let replay_checker = ReplayChecker::new(64, Duration::from_secs(60));
    let rng = SecureRandom::new();

    let peer: SocketAddr = "203.0.113.10:5555".parse().unwrap();
    let untouched_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 11));
    let bad_probe = b"\x16\x03\x01\x00";

    for expected in 1..=3 {
        let res = handle_tls_handshake(
            bad_probe,
            tokio::io::empty(),
            tokio::io::sink(),
            peer,
            &cfg,
            &replay_checker,
            &rng,
            None,
        )
        .await;
        assert!(matches!(res, HandshakeResult::BadClient { .. }));
        assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected));
        assert_eq!(auth_probe_fail_streak_for_testing(untouched_ip), None);
    }
}

#[test]
fn handshake_baseline_saturation_fires_at_compile_time_threshold() {
    let _guard = test_lock_guard();
    clear_auth_probe_state_for_testing();

    let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 33));
    let now = Instant::now();

    for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS.saturating_sub(1) {
        auth_probe_record_failure(ip, now);
    }
    assert!(!auth_probe_is_throttled(ip, now));

    auth_probe_record_failure(ip, now);
    assert!(auth_probe_is_throttled(ip, now));
}

#[test]
fn handshake_baseline_repeated_probes_streak_monotonic() {
    let _guard = test_lock_guard();
    clear_auth_probe_state_for_testing();

    let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 42));
    let now = Instant::now();
    let mut prev = 0u32;

    for _ in 0..100 {
        auth_probe_record_failure(ip, now);
        let current = auth_probe_fail_streak_for_testing(ip).unwrap_or(0);
        assert!(current >= prev, "streak must be monotonic");
        prev = current;
    }
}

#[test]
fn handshake_baseline_throttled_ip_incurs_backoff_delay() {
    let _guard = test_lock_guard();
    clear_auth_probe_state_for_testing();

    let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44));
    let now = Instant::now();

    for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
        auth_probe_record_failure(ip, now);
    }

    let delay = auth_probe_backoff(AUTH_PROBE_BACKOFF_START_FAILS);
    assert!(delay >= Duration::from_millis(AUTH_PROBE_BACKOFF_BASE_MS));

    let before_expiry = now + delay.saturating_sub(Duration::from_millis(1));
    let after_expiry = now + delay + Duration::from_millis(1);

    assert!(auth_probe_is_throttled(ip, before_expiry));
    assert!(!auth_probe_is_throttled(ip, after_expiry));
}

#[tokio::test]
async fn handshake_baseline_malformed_probe_frames_fail_closed_to_masking() {
    let _guard = test_lock_guard();
    clear_auth_probe_state_for_testing();

    let cfg = test_config_with_secret_hex("55555555555555555555555555555555");
    let replay_checker = ReplayChecker::new(64, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let peer: SocketAddr = "198.51.100.212:44323".parse().unwrap();

    let corpus: Vec<Vec<u8>> = vec![
        vec![0x16, 0x03, 0x01],
        vec![0x16, 0x03, 0x01, 0xFF, 0xFF],
        vec![0x00; 128],
        (0..64u8).collect(),
    ];

    for probe in corpus {
        let res = timeout(
            Duration::from_millis(250),
            handle_tls_handshake(
                &probe,
                tokio::io::empty(),
                tokio::io::sink(),
                peer,
                &cfg,
                &replay_checker,
                &rng,
                None,
            ),
        )
        .await
        .expect("malformed probe handling must complete in bounded time");

        assert!(
            matches!(res, HandshakeResult::BadClient { .. } | HandshakeResult::Error(_)),
            "malformed probe must fail closed"
        );
    }
}
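The helper `make_valid_tls_handshake` above embeds the client timestamp by XOR-ing its little-endian bytes into the last four bytes of the digest; since XOR is its own inverse, the peer recovers the timestamp by recomputing the clean HMAC and XOR-ing again. A std-only roundtrip sketch, with a stand-in digest instead of the crate's `sha256_hmac`:

```rust
// Embed a u32 timestamp into bytes 28..32 of a 32-byte digest by XOR.
fn embed_ts(digest: &mut [u8; 32], ts: u32) {
    for (i, b) in ts.to_le_bytes().iter().enumerate() {
        digest[28 + i] ^= *b;
    }
}

// Recover the timestamp by XOR-ing the wire digest against the recomputed
// clean digest — XOR-ing twice with the same bytes cancels out.
fn extract_ts(wire: &[u8; 32], clean: &[u8; 32]) -> u32 {
    let mut ts = [0u8; 4];
    for i in 0..4 {
        ts[i] = wire[28 + i] ^ clean[28 + i];
    }
    u32::from_le_bytes(ts)
}

fn main() {
    // Stand-in for the HMAC both sides would compute over the handshake bytes.
    let clean = [0xAB_u8; 32];
    let mut wire = clean;
    embed_ts(&mut wire, 1_700_000_000);
    assert_eq!(extract_ts(&wire, &clean), 1_700_000_000);
    // The first 28 digest bytes are untouched, so they still authenticate.
    assert_eq!(&wire[..28], &clean[..28]);
}
```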
@ -0,0 +1,156 @@
use super::*;
use tokio::io::duplex;
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant, timeout};

#[test]
fn masking_baseline_timing_normalization_budget_within_bounds() {
    let mut config = ProxyConfig::default();
    config.censorship.mask_timing_normalization_enabled = true;
    config.censorship.mask_timing_normalization_floor_ms = 120;
    config.censorship.mask_timing_normalization_ceiling_ms = 180;

    for _ in 0..256 {
        let budget = mask_outcome_target_budget(&config);
        assert!(budget >= Duration::from_millis(120));
        assert!(budget <= Duration::from_millis(180));
    }
}

#[tokio::test]
async fn masking_baseline_fallback_relays_to_mask_host() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    let initial = b"GET /baseline HTTP/1.1\r\nHost: x\r\n\r\n".to_vec();
    let reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();

    let accept_task = tokio::spawn({
        let initial = initial.clone();
        let reply = reply.clone();
        async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut seen = vec![0u8; initial.len()];
            stream.read_exact(&mut seen).await.unwrap();
            assert_eq!(seen, initial);
            stream.write_all(&reply).await.unwrap();
        }
    });

    let mut config = ProxyConfig::default();
    config.general.beobachten = false;
    config.censorship.mask = true;
    config.censorship.mask_host = Some("127.0.0.1".to_string());
    config.censorship.mask_port = backend_addr.port();
    config.censorship.mask_unix_sock = None;
    config.censorship.mask_proxy_protocol = 0;

    let peer: SocketAddr = "203.0.113.70:55070".parse().unwrap();
    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();

    let (client_reader, _client_writer) = duplex(1024);
    let (mut visible_reader, visible_writer) = duplex(2048);
    let beobachten = BeobachtenStore::new();

    handle_bad_client(
        client_reader,
        visible_writer,
        &initial,
        peer,
        local_addr,
        &config,
        &beobachten,
    )
    .await;

    let mut observed = vec![0u8; reply.len()];
    visible_reader.read_exact(&mut observed).await.unwrap();
    assert_eq!(observed, reply);
    accept_task.await.unwrap();
}

#[test]
fn masking_baseline_no_normalization_returns_default_budget() {
    let mut config = ProxyConfig::default();
    config.censorship.mask_timing_normalization_enabled = false;
    let budget = mask_outcome_target_budget(&config);
    assert_eq!(budget, MASK_TIMEOUT);
}

#[tokio::test]
async fn masking_baseline_unreachable_mask_host_silent_failure() {
    let mut config = ProxyConfig::default();
    config.general.beobachten = false;
    config.censorship.mask = true;
    config.censorship.mask_unix_sock = None;
    config.censorship.mask_host = Some("127.0.0.1".to_string());
    config.censorship.mask_port = 1;
    config.censorship.mask_timing_normalization_enabled = false;

    let peer: SocketAddr = "203.0.113.71:55071".parse().unwrap();
    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
    let beobachten = BeobachtenStore::new();

    let (client_reader, _client_writer) = duplex(1024);
    let (mut visible_reader, visible_writer) = duplex(1024);

    let started = Instant::now();
    handle_bad_client(
        client_reader,
        visible_writer,
        b"GET / HTTP/1.1\r\n\r\n",
        peer,
        local_addr,
        &config,
        &beobachten,
    )
    .await;
    let elapsed = started.elapsed();

    assert!(elapsed < Duration::from_secs(1));

    let mut buf = [0u8; 1];
    let read_res = timeout(Duration::from_millis(50), visible_reader.read(&mut buf)).await;
    match read_res {
        Ok(Ok(0)) | Err(_) => {}
        Ok(Ok(n)) => panic!("expected no response bytes, got {n}"),
        Ok(Err(e)) => panic!("unexpected client-side read error: {e}"),
    }
}

#[tokio::test]
async fn masking_baseline_light_fuzz_initial_data_no_panic() {
    let mut config = ProxyConfig::default();
    config.general.beobachten = false;
    config.censorship.mask = false;

    let peer: SocketAddr = "203.0.113.72:55072".parse().unwrap();
    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
    let beobachten = BeobachtenStore::new();

    let corpus: Vec<Vec<u8>> = vec![
        vec![],
        vec![0x00],
        vec![0xFF; 1024],
        (0..255u8).collect(),
        b"\xF0\x28\x8C\x28".to_vec(),
    ];

    for sample in corpus {
        let (client_reader, _client_writer) = duplex(1024);
        let (_visible_reader, visible_writer) = duplex(1024);
        timeout(
            Duration::from_millis(300),
            handle_bad_client(
                client_reader,
                visible_writer,
                &sample,
                peer,
                local_addr,
                &config,
                &beobachten,
            ),
        )
        .await
        .expect("fuzz sample must complete in bounded time");
    }
}
@ -0,0 +1,333 @@
use super::*;
use rand::rngs::StdRng;
use rand::SeedableRng;

fn seeded_rng(seed: u64) -> StdRng {
    StdRng::seed_from_u64(seed)
}

// ── Positive: all samples within configured envelope ────────────────────

#[test]
fn masking_lognormal_all_samples_within_configured_envelope() {
    let mut rng = seeded_rng(42);
    let floor: u64 = 500;
    let ceiling: u64 = 2000;
    for _ in 0..10_000 {
        let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
        assert!(
            val >= floor && val <= ceiling,
            "sample {} outside [{}, {}]",
            val,
            floor,
            ceiling,
        );
    }
}

// ── Statistical: median near geometric mean ─────────────────────────────

#[test]
fn masking_lognormal_sample_median_near_geometric_mean_of_range() {
    let mut rng = seeded_rng(42);
    let floor: u64 = 500;
    let ceiling: u64 = 2000;
    let geometric_mean = ((floor as f64) * (ceiling as f64)).sqrt();

    let mut samples: Vec<u64> = (0..10_000)
        .map(|_| sample_lognormal_percentile_bounded(floor, ceiling, &mut rng))
        .collect();
    samples.sort();
    let median = samples[samples.len() / 2] as f64;

    let tolerance = geometric_mean * 0.10;
    assert!(
        (median - geometric_mean).abs() <= tolerance,
        "median {} not within 10% of geometric mean {} (tolerance {})",
        median,
        geometric_mean,
        tolerance,
    );
}

// ── Edge: degenerate floor == ceiling returns exactly that value ─────────

#[test]
fn masking_lognormal_degenerate_floor_eq_ceiling_returns_floor() {
    let mut rng = seeded_rng(99);
    for _ in 0..100 {
        let val = sample_lognormal_percentile_bounded(1000, 1000, &mut rng);
        assert_eq!(val, 1000, "floor == ceiling must always return exactly that value");
    }
}

// ── Edge: floor > ceiling (misconfiguration) clamps safely ──────────────

#[test]
fn masking_lognormal_floor_greater_than_ceiling_returns_ceiling() {
    let mut rng = seeded_rng(77);
    let val = sample_lognormal_percentile_bounded(2000, 500, &mut rng);
    assert_eq!(
        val, 500,
        "floor > ceiling misconfiguration must return ceiling (the minimum)"
    );
}
|
||||||
|
|
||||||
|
// ── Edge: floor == 1, ceiling == 1 ──────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_floor_1_ceiling_1_returns_1() {
|
||||||
|
let mut rng = seeded_rng(12);
|
||||||
|
let val = sample_lognormal_percentile_bounded(1, 1, &mut rng);
|
||||||
|
assert_eq!(val, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge: floor == 1, ceiling very large ────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_wide_range_all_samples_within_bounds() {
|
||||||
|
let mut rng = seeded_rng(55);
|
||||||
|
let floor: u64 = 1;
|
||||||
|
let ceiling: u64 = 100_000;
|
||||||
|
for _ in 0..10_000 {
|
||||||
|
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||||
|
assert!(
|
||||||
|
val >= floor && val <= ceiling,
|
||||||
|
"sample {} outside [{}, {}]",
|
||||||
|
val,
|
||||||
|
floor,
|
||||||
|
ceiling,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Adversarial: extreme sigma (floor very close to ceiling) ────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_narrow_range_does_not_panic() {
|
||||||
|
let mut rng = seeded_rng(88);
|
||||||
|
let floor: u64 = 999;
|
||||||
|
let ceiling: u64 = 1001;
|
||||||
|
for _ in 0..10_000 {
|
||||||
|
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||||
|
assert!(
|
||||||
|
val >= floor && val <= ceiling,
|
||||||
|
"narrow range sample {} outside [{}, {}]",
|
||||||
|
val,
|
||||||
|
floor,
|
||||||
|
ceiling,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Adversarial: u64::MAX ceiling does not overflow ──────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_u64_max_ceiling_no_overflow() {
|
||||||
|
let mut rng = seeded_rng(123);
|
||||||
|
let floor: u64 = 1;
|
||||||
|
let ceiling: u64 = u64::MAX;
|
||||||
|
for _ in 0..1000 {
|
||||||
|
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||||
|
assert!(val >= floor, "sample {} below floor {}", val, floor);
|
||||||
|
// u64::MAX clamp ensures no overflow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Adversarial: floor == 0 guard ───────────────────────────────────────
|
||||||
|
// The function should handle floor=0 gracefully even though callers
|
||||||
|
// should never pass it. Verifies no panic on ln(0).
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_floor_zero_no_panic() {
|
||||||
|
let mut rng = seeded_rng(200);
|
||||||
|
let val = sample_lognormal_percentile_bounded(0, 1000, &mut rng);
|
||||||
|
assert!(val <= 1000, "sample {} exceeds ceiling 1000", val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Adversarial: both zero → returns 0 ──────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_both_zero_returns_zero() {
|
||||||
|
let mut rng = seeded_rng(201);
|
||||||
|
let val = sample_lognormal_percentile_bounded(0, 0, &mut rng);
|
||||||
|
assert_eq!(val, 0, "floor=0 ceiling=0 must return 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Distribution shape: not uniform ─────────────────────────────────────
|
||||||
|
// A DPI classifier trained on uniform delay samples should detect a
|
||||||
|
// distribution where > 60% of samples fall in the lower half of the range.
|
||||||
|
// Log-normal is right-skewed: more samples near floor than ceiling.
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_distribution_is_right_skewed() {
|
||||||
|
let mut rng = seeded_rng(42);
|
||||||
|
let floor: u64 = 100;
|
||||||
|
let ceiling: u64 = 5000;
|
||||||
|
let midpoint = (floor + ceiling) / 2;
|
||||||
|
|
||||||
|
let samples: Vec<u64> = (0..10_000)
|
||||||
|
.map(|_| sample_lognormal_percentile_bounded(floor, ceiling, &mut rng))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let below_mid = samples.iter().filter(|&&s| s < midpoint).count();
|
||||||
|
let ratio = below_mid as f64 / samples.len() as f64;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
ratio > 0.55,
|
||||||
|
"Log-normal should be right-skewed (>55% below midpoint), got {}%",
|
||||||
|
ratio * 100.0,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Determinism: same seed produces same sequence ───────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_deterministic_with_same_seed() {
|
||||||
|
let mut rng1 = seeded_rng(42);
|
||||||
|
let mut rng2 = seeded_rng(42);
|
||||||
|
for _ in 0..100 {
|
||||||
|
let a = sample_lognormal_percentile_bounded(500, 2000, &mut rng1);
|
||||||
|
let b = sample_lognormal_percentile_bounded(500, 2000, &mut rng2);
|
||||||
|
assert_eq!(a, b, "Same seed must produce same output");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Fuzz: 1000 random (floor, ceiling) pairs, no panics ─────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_fuzz_random_params_no_panic() {
|
||||||
|
use rand::Rng;
|
||||||
|
let mut rng = seeded_rng(999);
|
||||||
|
for _ in 0..1000 {
|
||||||
|
let a: u64 = rng.random_range(0..=10_000);
|
||||||
|
let b: u64 = rng.random_range(0..=10_000);
|
||||||
|
let floor = a.min(b);
|
||||||
|
let ceiling = a.max(b);
|
||||||
|
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||||
|
assert!(
|
||||||
|
val >= floor && val <= ceiling,
|
||||||
|
"fuzz: sample {} outside [{}, {}]",
|
||||||
|
val,
|
||||||
|
floor,
|
||||||
|
ceiling,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Fuzz: adversarial floor > ceiling pairs ──────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_fuzz_inverted_params_no_panic() {
|
||||||
|
use rand::Rng;
|
||||||
|
let mut rng = seeded_rng(777);
|
||||||
|
for _ in 0..500 {
|
||||||
|
let floor: u64 = rng.random_range(1..=10_000);
|
||||||
|
let ceiling: u64 = rng.random_range(0..floor);
|
||||||
|
// When floor > ceiling, must return ceiling (the smaller value)
|
||||||
|
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||||
|
assert_eq!(
|
||||||
|
val, ceiling,
|
||||||
|
"inverted: floor={} ceiling={} should return ceiling, got {}",
|
||||||
|
floor, ceiling, val,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Security: clamp spike check ─────────────────────────────────────────
|
||||||
|
// With well-parameterized sigma, no more than 5% of samples should be
|
||||||
|
// at exactly floor or exactly ceiling (clamp spikes). A spike > 10%
|
||||||
|
// is detectable by DPI as bimodal.
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn masking_lognormal_no_clamp_spike_at_boundaries() {
|
||||||
|
let mut rng = seeded_rng(42);
|
||||||
|
let floor: u64 = 500;
|
||||||
|
let ceiling: u64 = 2000;
|
||||||
|
let n = 10_000;
|
||||||
|
let samples: Vec<u64> = (0..n)
|
||||||
|
.map(|_| sample_lognormal_percentile_bounded(floor, ceiling, &mut rng))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let at_floor = samples.iter().filter(|&&s| s == floor).count();
|
||||||
|
let at_ceiling = samples.iter().filter(|&&s| s == ceiling).count();
|
||||||
|
let floor_pct = at_floor as f64 / n as f64;
|
||||||
|
let ceiling_pct = at_ceiling as f64 / n as f64;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
floor_pct < 0.05,
|
||||||
|
"floor clamp spike: {}% of samples at exactly floor (max 5%)",
|
||||||
|
floor_pct * 100.0,
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
ceiling_pct < 0.05,
|
||||||
|
"ceiling clamp spike: {}% of samples at exactly ceiling (max 5%)",
|
||||||
|
ceiling_pct * 100.0,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Integration: mask_outcome_target_budget uses log-normal for path 3 ──

#[tokio::test]
async fn masking_lognormal_integration_budget_within_bounds() {
    let mut config = ProxyConfig::default();
    config.censorship.mask_timing_normalization_enabled = true;
    config.censorship.mask_timing_normalization_floor_ms = 500;
    config.censorship.mask_timing_normalization_ceiling_ms = 2000;

    for _ in 0..100 {
        let budget = mask_outcome_target_budget(&config);
        let ms = budget.as_millis() as u64;
        assert!(
            ms >= 500 && ms <= 2000,
            "budget {} ms outside [500, 2000]",
            ms,
        );
    }
}

// ── Integration: floor == 0 path stays uniform (NOT log-normal) ─────────

#[tokio::test]
async fn masking_lognormal_floor_zero_path_stays_uniform() {
    let mut config = ProxyConfig::default();
    config.censorship.mask_timing_normalization_enabled = true;
    config.censorship.mask_timing_normalization_floor_ms = 0;
    config.censorship.mask_timing_normalization_ceiling_ms = 1000;

    for _ in 0..100 {
        let budget = mask_outcome_target_budget(&config);
        let ms = budget.as_millis() as u64;
        // floor=0 path uses uniform [0, ceiling], not log-normal
        assert!(ms <= 1000, "budget {} ms exceeds ceiling 1000", ms);
    }
}

// ── Integration: floor > ceiling misconfiguration is safe ───────────────

#[tokio::test]
async fn masking_lognormal_misconfigured_floor_gt_ceiling_safe() {
    let mut config = ProxyConfig::default();
    config.censorship.mask_timing_normalization_enabled = true;
    config.censorship.mask_timing_normalization_floor_ms = 2000;
    config.censorship.mask_timing_normalization_ceiling_ms = 500;

    let budget = mask_outcome_target_budget(&config);
    let ms = budget.as_millis() as u64;
    // floor > ceiling: the budget must stay bounded by the larger of the two
    assert!(
        ms <= 2000,
        "misconfigured budget {} ms should be bounded",
        ms,
    );
}

// ── Stress: rapid repeated calls do not panic or starve ─────────────────

#[test]
fn masking_lognormal_stress_rapid_calls_no_panic() {
    let mut rng = seeded_rng(42);
    for _ in 0..100_000 {
        let _ = sample_lognormal_percentile_bounded(100, 5000, &mut rng);
    }
}
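The contract pinned down above — median at the geometric mean, no clamp spikes, right skew, determinism under a seeded RNG, and safe degenerate/inverted bounds — can be sketched as follows. This is a hypothetical illustration, not telemt's actual `sample_lognormal_percentile_bounded`: it picks σ so that ±3σ spans `[ln floor, ln ceiling]` and rejects tail draws rather than clamping. The `Lcg` type is a stand-in for the seeded `StdRng` used by the tests.

```rust
// Minimal deterministic RNG stand-in (hypothetical; tests use StdRng).
struct Lcg(u64);

impl Lcg {
    fn next_f64(&mut self) -> f64 {
        // 64-bit LCG step; top 53 bits mapped into [0, 1).
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((self.0 >> 11) as f64) / ((1u64 << 53) as f64)
    }
}

// Sketch of a bounded log-normal sampler satisfying the tested contract.
fn sample_lognormal_bounded_sketch(floor: u64, ceiling: u64, rng: &mut Lcg) -> u64 {
    if floor >= ceiling {
        // Degenerate (floor == ceiling) or inverted: return the smaller bound.
        return floor.min(ceiling);
    }
    let lo = floor.max(1) as f64; // guard against ln(0)
    let hi = ceiling as f64;
    let mu = (lo.ln() + hi.ln()) / 2.0; // median == geometric mean of range
    let sigma = (hi.ln() - lo.ln()) / 6.0; // +/- 3 sigma spans [ln lo, ln hi]
    loop {
        // Box-Muller transform; reject |z| > 3 so every accepted sample
        // already lies inside [lo, hi] -- no boundary clamp spike.
        let u1 = rng.next_f64().max(1e-12);
        let u2 = rng.next_f64();
        let z = (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos();
        if z.abs() <= 3.0 {
            return (mu + sigma * z).exp().round() as u64;
        }
    }
}
```

Rejection sampling (rather than clamping out-of-range draws to the bounds) is what keeps the probability mass at exactly `floor`/`ceiling` near zero, which the clamp-spike test above demands.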
@@ -0,0 +1,38 @@
use super::*;
use std::time::{Duration, Instant};

#[test]
fn middle_relay_baseline_public_api_idle_roundtrip_contract() {
    let _guard = relay_idle_pressure_test_scope();
    clear_relay_idle_pressure_state_for_testing();

    assert!(mark_relay_idle_candidate(7001));
    assert_eq!(oldest_relay_idle_candidate(), Some(7001));

    clear_relay_idle_candidate(7001);
    assert_ne!(oldest_relay_idle_candidate(), Some(7001));

    assert!(mark_relay_idle_candidate(7001));
    assert_eq!(oldest_relay_idle_candidate(), Some(7001));

    clear_relay_idle_pressure_state_for_testing();
}

#[test]
fn middle_relay_baseline_public_api_desync_window_contract() {
    let _guard = desync_dedup_test_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    clear_desync_dedup_for_testing();

    let key = 0xDEAD_BEEF_0000_0001u64;
    let t0 = Instant::now();

    assert!(should_emit_full_desync(key, false, t0));
    assert!(!should_emit_full_desync(key, false, t0 + Duration::from_secs(1)));

    let t1 = t0 + DESYNC_DEDUP_WINDOW + Duration::from_millis(10);
    assert!(should_emit_full_desync(key, false, t1));

    clear_desync_dedup_for_testing();
}
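The desync-window contract exercised above (emit once per key, suppress repeats inside the window, re-emit after it elapses) can be sketched with a plain `HashMap`. This is a hypothetical illustration, not telemt's `should_emit_full_desync`; the 60-second window constant is an assumption, standing in for `DESYNC_DEDUP_WINDOW`.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

// Assumed window length; the real DESYNC_DEDUP_WINDOW may differ.
const DEDUP_WINDOW: Duration = Duration::from_secs(60);

struct DesyncDedup {
    // Last emission time per key.
    last_emitted: HashMap<u64, Instant>,
}

impl DesyncDedup {
    fn new() -> Self {
        Self { last_emitted: HashMap::new() }
    }

    // Emit at most once per key per DEDUP_WINDOW; repeats inside the
    // window are suppressed, and the window restarts on each emission.
    fn should_emit(&mut self, key: u64, now: Instant) -> bool {
        match self.last_emitted.get(&key) {
            Some(&t) if now.duration_since(t) < DEDUP_WINDOW => false,
            _ => {
                self.last_emitted.insert(key, now);
                true
            }
        }
    }
}
```

Taking `now` as an explicit parameter, as the tested API does, is what makes the window logic testable without sleeping through real wall-clock time.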
@@ -0,0 +1,275 @@
use super::*;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf, duplex};
use tokio::time::{Duration, timeout};

struct BrokenPipeWriter;

impl AsyncWrite for BrokenPipeWriter {
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        _buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Poll::Ready(Err(io::Error::new(
            io::ErrorKind::BrokenPipe,
            "forced broken pipe",
        )))
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}

#[tokio::test(start_paused = true)]
async fn relay_baseline_activity_timeout_fires_after_inactivity() {
    let stats = Arc::new(Stats::new());
    let user = "relay-baseline-idle-timeout";

    let (_client_peer, relay_client) = duplex(1024);
    let (_server_peer, relay_server) = duplex(1024);

    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);

    let relay_task = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        1024,
        1024,
        user,
        Arc::clone(&stats),
        None,
        Arc::new(BufferPool::new()),
    ));

    tokio::task::yield_now().await;
    tokio::time::advance(ACTIVITY_TIMEOUT.saturating_sub(Duration::from_secs(1))).await;
    tokio::task::yield_now().await;
    assert!(
        !relay_task.is_finished(),
        "relay must stay alive before inactivity timeout"
    );

    tokio::time::advance(WATCHDOG_INTERVAL + Duration::from_secs(2)).await;

    let done = timeout(Duration::from_secs(1), relay_task)
        .await
        .expect("relay must complete after inactivity timeout")
        .expect("relay task must not panic");

    assert!(done.is_ok(), "relay must return Ok(()) after inactivity timeout");
}

#[tokio::test]
async fn relay_baseline_zero_bytes_returns_ok_and_counters_zero() {
    let stats = Arc::new(Stats::new());
    let user = "relay-baseline-zero-bytes";

    let (client_peer, relay_client) = duplex(1024);
    let (server_peer, relay_server) = duplex(1024);

    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);

    let relay_task = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        1024,
        1024,
        user,
        Arc::clone(&stats),
        None,
        Arc::new(BufferPool::new()),
    ));

    drop(client_peer);
    drop(server_peer);

    let done = timeout(Duration::from_secs(2), relay_task)
        .await
        .expect("relay must stop after both peers close")
        .expect("relay task must not panic");

    assert!(done.is_ok(), "relay must return Ok(()) on immediate EOF");
    assert_eq!(stats.get_user_total_octets(user), 0);
}

#[tokio::test]
async fn relay_baseline_bidirectional_bytes_counted_symmetrically() {
    let stats = Arc::new(Stats::new());
    let user = "relay-baseline-bidir-counters";

    let (mut client_peer, relay_client) = duplex(16 * 1024);
    let (relay_server, mut server_peer) = duplex(16 * 1024);

    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);

    let relay_task = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        4096,
        4096,
        user,
        Arc::clone(&stats),
        None,
        Arc::new(BufferPool::new()),
    ));

    let c2s = vec![0xAA; 4096];
    let s2c = vec![0xBB; 2048];

    client_peer.write_all(&c2s).await.unwrap();
    server_peer.write_all(&s2c).await.unwrap();

    let mut seen_c2s = vec![0u8; c2s.len()];
    let mut seen_s2c = vec![0u8; s2c.len()];
    server_peer.read_exact(&mut seen_c2s).await.unwrap();
    client_peer.read_exact(&mut seen_s2c).await.unwrap();

    assert_eq!(seen_c2s, c2s);
    assert_eq!(seen_s2c, s2c);

    drop(client_peer);
    drop(server_peer);

    let done = timeout(Duration::from_secs(2), relay_task)
        .await
        .expect("relay must complete after both peers close")
        .expect("relay task must not panic");
    assert!(done.is_ok());

    assert_eq!(stats.get_user_total_octets(user), (c2s.len() + s2c.len()) as u64);
}

#[tokio::test]
async fn relay_baseline_both_sides_close_simultaneously_no_panic() {
    let stats = Arc::new(Stats::new());

    let (client_peer, relay_client) = duplex(1024);
    let (relay_server, server_peer) = duplex(1024);

    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);

    let relay_task = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        1024,
        1024,
        "relay-baseline-sim-close",
        Arc::clone(&stats),
        None,
        Arc::new(BufferPool::new()),
    ));

    drop(client_peer);
    drop(server_peer);

    let done = timeout(Duration::from_secs(2), relay_task)
        .await
        .expect("relay must complete")
        .expect("relay task must not panic");
    assert!(done.is_ok());
}

#[tokio::test]
async fn relay_baseline_broken_pipe_midtransfer_returns_error() {
    let stats = Arc::new(Stats::new());
    let user = "relay-baseline-broken-pipe";

    let (mut client_peer, relay_client) = duplex(1024);
    let (client_reader, client_writer) = tokio::io::split(relay_client);

    let relay_task = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        tokio::io::empty(),
        BrokenPipeWriter,
        1024,
        1024,
        user,
        Arc::clone(&stats),
        None,
        Arc::new(BufferPool::new()),
    ));

    client_peer.write_all(b"trigger").await.unwrap();

    let done = timeout(Duration::from_secs(2), relay_task)
        .await
        .expect("relay must return after broken pipe")
        .expect("relay task must not panic");

    match done {
        Err(ProxyError::Io(err)) => {
            assert!(
                matches!(err.kind(), io::ErrorKind::BrokenPipe | io::ErrorKind::ConnectionReset),
                "expected BrokenPipe/ConnectionReset, got {:?}",
                err.kind()
            );
        }
        other => panic!("expected ProxyError::Io, got {other:?}"),
    }
}

#[tokio::test]
async fn relay_baseline_many_small_writes_exact_counter() {
    let stats = Arc::new(Stats::new());
    let user = "relay-baseline-many-small";

    let (mut client_peer, relay_client) = duplex(4096);
    let (relay_server, mut server_peer) = duplex(4096);

    let (client_reader, client_writer) = tokio::io::split(relay_client);
    let (server_reader, server_writer) = tokio::io::split(relay_server);

    let relay_task = tokio::spawn(relay_bidirectional(
        client_reader,
        client_writer,
        server_reader,
        server_writer,
        1024,
        1024,
        user,
        Arc::clone(&stats),
        None,
        Arc::new(BufferPool::new()),
    ));

    for i in 0..10_000u32 {
        let b = [(i & 0xFF) as u8];
        client_peer.write_all(&b).await.unwrap();
        let mut seen = [0u8; 1];
        server_peer.read_exact(&mut seen).await.unwrap();
        assert_eq!(seen, b);
    }

    drop(client_peer);
    drop(server_peer);

    let done = timeout(Duration::from_secs(3), relay_task)
        .await
        .expect("relay must complete for many small writes")
        .expect("relay task must not panic");
    assert!(done.is_ok());
    assert_eq!(stats.get_user_total_octets(user), 10_000);
}
@@ -0,0 +1,202 @@
use crate::config::ProxyConfig;
use rand::rngs::StdRng;
use rand::SeedableRng;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::AsyncWrite;

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{RawWaker, RawWakerVTable, Waker};

    unsafe fn wake_counter_clone(data: *const ()) -> RawWaker {
        let arc = Arc::<AtomicUsize>::from_raw(data.cast::<AtomicUsize>());
        let cloned = Arc::clone(&arc);
        let _ = Arc::into_raw(arc);
        RawWaker::new(Arc::into_raw(cloned).cast::<()>(), &WAKE_COUNTER_WAKER_VTABLE)
    }

    unsafe fn wake_counter_wake(data: *const ()) {
        let arc = Arc::<AtomicUsize>::from_raw(data.cast::<AtomicUsize>());
        arc.fetch_add(1, Ordering::SeqCst);
    }

    unsafe fn wake_counter_wake_by_ref(data: *const ()) {
        let arc = Arc::<AtomicUsize>::from_raw(data.cast::<AtomicUsize>());
        arc.fetch_add(1, Ordering::SeqCst);
        let _ = Arc::into_raw(arc);
    }

    unsafe fn wake_counter_drop(data: *const ()) {
        let _ = Arc::<AtomicUsize>::from_raw(data.cast::<AtomicUsize>());
    }

    static WAKE_COUNTER_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
        wake_counter_clone,
        wake_counter_wake,
        wake_counter_wake_by_ref,
        wake_counter_drop,
    );

    fn wake_counter_waker(counter: Arc<AtomicUsize>) -> Waker {
        let raw = RawWaker::new(
            Arc::into_raw(counter).cast::<()>(),
            &WAKE_COUNTER_WAKER_VTABLE,
        );
        // SAFETY: `raw` points to a valid `Arc<AtomicUsize>` and uses a vtable
        // that preserves Arc reference-counting semantics.
        unsafe { Waker::from_raw(raw) }
    }

    #[test]
    fn pending_count_writer_write_pending_does_not_spurious_wake() {
        let counter = Arc::new(AtomicUsize::new(0));
        let waker = wake_counter_waker(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);

        let mut writer = PendingCountWriter::new(RecordingWriter::new(), 1, 0);
        let poll = Pin::new(&mut writer).poll_write(&mut cx, b"x");

        assert!(matches!(poll, Poll::Pending));
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn pending_count_writer_flush_pending_does_not_spurious_wake() {
        let counter = Arc::new(AtomicUsize::new(0));
        let waker = wake_counter_waker(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);

        let mut writer = PendingCountWriter::new(RecordingWriter::new(), 0, 1);
        let poll = Pin::new(&mut writer).poll_flush(&mut cx);

        assert!(matches!(poll, Poll::Pending));
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}

// In-memory AsyncWrite that records both per-write and per-flush granularity.
pub struct RecordingWriter {
    pub writes: Vec<Vec<u8>>,
    pub flushed: Vec<Vec<u8>>,
    current_record: Vec<u8>,
}

impl RecordingWriter {
    pub fn new() -> Self {
        Self {
            writes: Vec::new(),
            flushed: Vec::new(),
            current_record: Vec::new(),
        }
    }

    pub fn total_bytes(&self) -> usize {
        self.writes.iter().map(|w| w.len()).sum()
    }
}

impl Default for RecordingWriter {
    fn default() -> Self {
        Self::new()
    }
}

impl AsyncWrite for RecordingWriter {
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let me = self.as_mut().get_mut();
        me.writes.push(buf.to_vec());
        me.current_record.extend_from_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }

    fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let me = self.as_mut().get_mut();
        let record = std::mem::take(&mut me.current_record);
        if !record.is_empty() {
            me.flushed.push(record);
        }
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}

// Returns Poll::Pending for the first N write/flush calls, then delegates.
pub struct PendingCountWriter<W> {
    pub inner: W,
    pub write_pending_remaining: usize,
    pub flush_pending_remaining: usize,
}

impl<W> PendingCountWriter<W> {
    pub fn new(inner: W, write_pending: usize, flush_pending: usize) -> Self {
        Self {
            inner,
            write_pending_remaining: write_pending,
            flush_pending_remaining: flush_pending,
        }
    }
}

impl<W: AsyncWrite + Unpin> AsyncWrite for PendingCountWriter<W> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let me = self.as_mut().get_mut();
        if me.write_pending_remaining > 0 {
            me.write_pending_remaining -= 1;
            return Poll::Pending;
        }
        Pin::new(&mut me.inner).poll_write(cx, buf)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let me = self.as_mut().get_mut();
        if me.flush_pending_remaining > 0 {
            me.flush_pending_remaining -= 1;
            return Poll::Pending;
        }
        Pin::new(&mut me.inner).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}

pub fn seeded_rng(seed: u64) -> StdRng {
    StdRng::seed_from_u64(seed)
}

pub fn tls_only_config() -> Arc<ProxyConfig> {
    let mut cfg = ProxyConfig::default();
    cfg.general.modes.tls = true;
    Arc::new(cfg)
}

pub fn handshake_test_config(secret_hex: &str) -> ProxyConfig {
    let mut cfg = ProxyConfig::default();
    cfg.access.users.clear();
    cfg.access
        .users
        .insert("test-user".to_string(), secret_hex.to_string());
    cfg.access.ignore_time_skew = true;
    cfg.censorship.mask = true;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = 0;
    cfg
}
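The hand-rolled `RawWaker` vtable in the helpers above is the fully manual way to build a wake-counting `Waker`. Since Rust 1.51 the same thing can be written without any `unsafe` by implementing `std::task::Wake` on an `Arc`-wrapped type; a sketch of that alternative (hypothetical, not part of this commit):

```rust
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Wake, Waker};

// Counts wake-ups without any unsafe vtable plumbing.
struct WakeCounter(AtomicUsize);

impl Wake for WakeCounter {
    fn wake(self: Arc<Self>) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }

    fn wake_by_ref(self: &Arc<Self>) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }
}

// Returns the shared counter plus a Waker built via Waker::from(Arc<W>),
// which wires up the clone/wake/drop vtable automatically.
fn counting_waker() -> (Arc<WakeCounter>, Waker) {
    let counter = Arc::new(WakeCounter(AtomicUsize::new(0)));
    let waker = Waker::from(Arc::clone(&counter));
    (counter, waker)
}
```

The trade-off is a mandatory `Arc` allocation; the manual vtable version can avoid that in `no_std` or allocation-sensitive contexts, which may be why the helpers spell it out by hand.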