diff --git a/src/api/http_utils.rs b/src/api/http_utils.rs index e04bd04..9dfe526 100644 --- a/src/api/http_utils.rs +++ b/src/api/http_utils.rs @@ -24,10 +24,7 @@ pub(super) fn success_response( .unwrap() } -pub(super) fn error_response( - request_id: u64, - failure: ApiFailure, -) -> hyper::Response> { +pub(super) fn error_response(request_id: u64, failure: ApiFailure) -> hyper::Response> { let payload = ErrorResponse { ok: false, error: ErrorBody { diff --git a/src/api/mod.rs b/src/api/mod.rs index 0e2edd4..b622c5e 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -19,8 +19,8 @@ use crate::ip_tracker::UserIpTracker; use crate::proxy::route_mode::RouteRuntimeController; use crate::startup::StartupTracker; use crate::stats::Stats; -use crate::transport::middle_proxy::MePool; use crate::transport::UpstreamManager; +use crate::transport::middle_proxy::MePool; mod config_store; mod events; @@ -36,8 +36,8 @@ mod runtime_zero; mod users; use config_store::{current_revision, parse_if_match}; -use http_utils::{error_response, read_json, read_optional_json, success_response}; use events::ApiEventStore; +use http_utils::{error_response, read_json, read_optional_json, success_response}; use model::{ ApiFailure, CreateUserRequest, HealthData, PatchUserRequest, RotateSecretRequest, SummaryData, }; @@ -55,11 +55,11 @@ use runtime_stats::{ MinimalCacheEntry, build_dcs_data, build_me_writers_data, build_minimal_all_data, build_upstreams_data, build_zero_all_data, }; +use runtime_watch::spawn_runtime_watchers; use runtime_zero::{ build_limits_effective_data, build_runtime_gates_data, build_security_posture_data, build_system_info_data, }; -use runtime_watch::spawn_runtime_watchers; use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config}; pub(super) struct ApiRuntimeState { @@ -208,15 +208,15 @@ async fn handle( )); } - if !api_cfg.whitelist.is_empty() - && !api_cfg - .whitelist - .iter() - .any(|net| net.contains(peer.ip())) + if !api_cfg.whitelist.is_empty() && !api_cfg.whitelist.iter().any(|net| net.contains(peer.ip())) { return Ok(error_response( request_id, - ApiFailure::new(StatusCode::FORBIDDEN, "forbidden", "Source IP is not allowed"), + ApiFailure::new( + StatusCode::FORBIDDEN, + "forbidden", + "Source IP is not allowed", + ), )); } @@ -347,7 +347,8 @@ async fn handle( } ("GET", "/v1/runtime/connections/summary") => { let revision = current_revision(&shared.config_path).await?; - let data = build_runtime_connections_summary_data(shared.as_ref(), cfg.as_ref()).await; + let data = + build_runtime_connections_summary_data(shared.as_ref(), cfg.as_ref()).await; Ok(success_response(StatusCode::OK, data, revision)) } ("GET", "/v1/runtime/events/recent") => { @@ -389,13 +390,16 @@ async fn handle( let (data, revision) = match result { Ok(ok) => ok, Err(error) => { - shared.runtime_events.record("api.user.create.failed", error.code); + shared + .runtime_events + .record("api.user.create.failed", error.code); return Err(error); } }; - shared - .runtime_events - .record("api.user.create.ok", format!("username={}", data.user.username)); + shared.runtime_events.record( + "api.user.create.ok", + format!("username={}", data.user.username), + ); Ok(success_response(StatusCode::CREATED, data, revision)) } _ => { @@ -414,7 +418,8 @@ async fn handle( detected_ip_v6, ) .await; - if let Some(user_info) = users.into_iter().find(|entry| entry.username == user) + if let Some(user_info) = + users.into_iter().find(|entry| entry.username == user) { return Ok(success_response(StatusCode::OK, user_info, revision)); } @@ -435,7 +440,8 @@ async fn handle( )); } let expected_revision = parse_if_match(req.headers()); - let body = read_json::(req.into_body(), body_limit).await?; + let body = + read_json::(req.into_body(), body_limit).await?; let result = patch_user(user, body, expected_revision, &shared).await; let (data, revision) = match result { Ok(ok) => ok, @@ -475,10 +481,9 @@ async fn handle( return Err(error); } }; - shared.runtime_events.record( - "api.user.delete.ok", - format!("username={}", deleted_user), - ); + shared + .runtime_events + .record("api.user.delete.ok", format!("username={}", deleted_user)); return Ok(success_response(StatusCode::OK, deleted_user, revision)); } if method == Method::POST diff --git a/src/api/runtime_init.rs b/src/api/runtime_init.rs index 4bd8943..b7601f5 100644 --- a/src/api/runtime_init.rs +++ b/src/api/runtime_init.rs @@ -167,11 +167,7 @@ async fn current_me_pool_stage_progress(shared: &ApiShared) -> Option { let pool = shared.me_pool.read().await.clone()?; let status = pool.api_status_snapshot().await; let configured_dc_groups = status.configured_dc_groups; - let covered_dc_groups = status - .dcs - .iter() - .filter(|dc| dc.alive_writers > 0) - .count(); + let covered_dc_groups = status.dcs.iter().filter(|dc| dc.alive_writers > 0).count(); let dc_coverage = ratio_01(covered_dc_groups, configured_dc_groups); let writer_coverage = ratio_01(status.alive_writers, status.required_writers); diff --git a/src/api/runtime_stats.rs b/src/api/runtime_stats.rs index b646567..94f27a9 100644 --- a/src/api/runtime_stats.rs +++ b/src/api/runtime_stats.rs @@ -2,8 +2,8 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use crate::config::ApiConfig; use crate::stats::Stats; -use crate::transport::upstream::IpPreference; use crate::transport::UpstreamRouteKind; +use crate::transport::upstream::IpPreference; use super::ApiShared; use super::model::{ diff --git a/src/api/runtime_zero.rs b/src/api/runtime_zero.rs index ba89302..a6eb163 100644 --- a/src/api/runtime_zero.rs +++ b/src/api/runtime_zero.rs @@ -128,7 +128,8 @@ pub(super) fn build_system_info_data( .runtime_state .last_config_reload_epoch_secs .load(Ordering::Relaxed); - let last_config_reload_epoch_secs = (last_reload_epoch_secs > 0).then_some(last_reload_epoch_secs); + let last_config_reload_epoch_secs = + (last_reload_epoch_secs > 0).then_some(last_reload_epoch_secs); let git_commit = option_env!("TELEMT_GIT_COMMIT") .or(option_env!("VERGEN_GIT_SHA")) @@ -153,7 +154,10 @@ pub(super) fn build_system_info_data( uptime_seconds: shared.stats.uptime_secs(), config_path: shared.config_path.display().to_string(), config_hash: revision.to_string(), - config_reload_count: shared.runtime_state.config_reload_count.load(Ordering::Relaxed), + config_reload_count: shared + .runtime_state + .config_reload_count + .load(Ordering::Relaxed), last_config_reload_epoch_secs, } } @@ -233,9 +237,7 @@ pub(super) fn build_limits_effective_data(cfg: &ProxyConfig) -> EffectiveLimitsD adaptive_floor_writers_per_core_total: cfg .general .me_adaptive_floor_writers_per_core_total, - adaptive_floor_cpu_cores_override: cfg - .general - .me_adaptive_floor_cpu_cores_override, + adaptive_floor_cpu_cores_override: cfg.general.me_adaptive_floor_cpu_cores_override, adaptive_floor_max_extra_writers_single_per_core: cfg .general .me_adaptive_floor_max_extra_writers_single_per_core, diff --git a/src/api/users.rs b/src/api/users.rs index f339806..4793f89 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -46,7 +46,9 @@ pub(super) async fn create_user( None => random_user_secret(), }; - if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) { + if let Some(ad_tag) = body.user_ad_tag.as_ref() + && !is_valid_ad_tag(ad_tag) + { return Err(ApiFailure::bad_request( "user_ad_tag must be exactly 32 hex characters", )); @@ -65,12 +67,18 @@ pub(super) async fn create_user( )); } - cfg.access.users.insert(body.username.clone(), secret.clone()); + cfg.access + .users + .insert(body.username.clone(), secret.clone()); if let Some(ad_tag) = body.user_ad_tag { - cfg.access.user_ad_tags.insert(body.username.clone(), ad_tag); + cfg.access + .user_ad_tags + .insert(body.username.clone(), ad_tag); } if let Some(limit) = body.max_tcp_conns { - cfg.access.user_max_tcp_conns.insert(body.username.clone(), limit); + cfg.access + .user_max_tcp_conns + .insert(body.username.clone(), limit); } if let Some(expiration) = expiration { cfg.access @@ -78,7 +86,9 @@ pub(super) async fn create_user( .insert(body.username.clone(), expiration); } if let Some(quota) = body.data_quota_bytes { - cfg.access.user_data_quota.insert(body.username.clone(), quota); + cfg.access + .user_data_quota + .insert(body.username.clone(), quota); } let updated_limit = body.max_unique_ips; @@ -108,11 +118,15 @@ pub(super) async fn create_user( touched_sections.push(AccessSection::UserMaxUniqueIps); } - let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; + let revision = + save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; drop(_guard); if let Some(limit) = updated_limit { - shared.ip_tracker.set_user_limit(&body.username, limit).await; + shared + .ip_tracker + .set_user_limit(&body.username, limit) + .await; } let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); @@ -140,12 +154,7 @@ pub(super) async fn create_user( recent_unique_ips: 0, recent_unique_ips_list: Vec::new(), total_octets: 0, - links: build_user_links( - &cfg, - &secret, - detected_ip_v4, - detected_ip_v6, - ), + links: build_user_links(&cfg, &secret, detected_ip_v4, detected_ip_v6), }); Ok((CreateUserResponse { user, secret }, revision)) @@ -157,12 +166,16 @@ pub(super) async fn patch_user( expected_revision: Option, shared: &ApiShared, ) -> Result<(UserInfo, String), ApiFailure> { - if let Some(secret) = body.secret.as_ref() && !is_valid_user_secret(secret) { + if let Some(secret) = body.secret.as_ref() + && !is_valid_user_secret(secret) + { return Err(ApiFailure::bad_request( "secret must be exactly 32 hex characters", )); } - if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) { + if let Some(ad_tag) = body.user_ad_tag.as_ref() + && !is_valid_ad_tag(ad_tag) + { return Err(ApiFailure::bad_request( "user_ad_tag must be exactly 32 hex characters", )); @@ -187,10 +200,14 @@ pub(super) async fn patch_user( cfg.access.user_ad_tags.insert(user.to_string(), ad_tag); } if let Some(limit) = body.max_tcp_conns { - cfg.access.user_max_tcp_conns.insert(user.to_string(), limit); + cfg.access + .user_max_tcp_conns + .insert(user.to_string(), limit); } if let Some(expiration) = expiration { - cfg.access.user_expirations.insert(user.to_string(), expiration); + cfg.access + .user_expirations + .insert(user.to_string(), expiration); } if let Some(quota) = body.data_quota_bytes { cfg.access.user_data_quota.insert(user.to_string(), quota); @@ -198,7 +215,9 @@ pub(super) async fn patch_user( let mut updated_limit = None; if let Some(limit) = body.max_unique_ips { - cfg.access.user_max_unique_ips.insert(user.to_string(), limit); + cfg.access + .user_max_unique_ips + .insert(user.to_string(), limit); updated_limit = Some(limit); } @@ -263,7 +282,8 @@ pub(super) async fn rotate_secret( AccessSection::UserDataQuota, AccessSection::UserMaxUniqueIps, ]; - let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; + let revision = + save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; drop(_guard); let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); @@ -330,7 +350,8 @@ pub(super) async fn delete_user( AccessSection::UserDataQuota, AccessSection::UserMaxUniqueIps, ]; - let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; + let revision = + save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; drop(_guard); shared.ip_tracker.remove_user_limit(user).await; shared.ip_tracker.clear_user_ips(user).await; @@ -365,12 +386,7 @@ pub(super) async fn users_from_config( .users .get(&username) .map(|secret| { - build_user_links( - cfg, - secret, - startup_detected_ip_v4, - startup_detected_ip_v6, - ) + build_user_links(cfg, secret, startup_detected_ip_v4, startup_detected_ip_v6) }) .unwrap_or(UserLinks { classic: Vec::new(), @@ -392,10 +408,8 @@ pub(super) async fn users_from_config( .get(&username) .copied() .filter(|limit| *limit > 0) - .or( - (cfg.access.user_max_unique_ips_global_each > 0) - .then_some(cfg.access.user_max_unique_ips_global_each), - ), + .or((cfg.access.user_max_unique_ips_global_each > 0) + .then_some(cfg.access.user_max_unique_ips_global_each)), current_connections: stats.get_user_curr_connects(&username), active_unique_ips: active_ip_list.len(), active_unique_ips_list: active_ip_list, diff --git a/src/cli.rs b/src/cli.rs index 035fe92..6dc0e2a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,9 +1,9 @@ //! CLI commands: --init (fire-and-forget setup) +use rand::RngExt; use std::fs; use std::path::{Path, PathBuf}; use std::process::Command; -use rand::RngExt; /// Options for the init command pub struct InitOptions { @@ -35,10 +35,10 @@ pub fn parse_init_args(args: &[String]) -> Option { if !args.iter().any(|a| a == "--init") { return None; } - + let mut opts = InitOptions::default(); let mut i = 0; - + while i < args.len() { match args[i].as_str() { "--port" => { @@ -78,7 +78,7 @@ pub fn parse_init_args(args: &[String]) -> Option { } i += 1; } - + Some(opts) } @@ -86,7 +86,7 @@ pub fn parse_init_args(args: &[String]) -> Option { pub fn run_init(opts: InitOptions) -> Result<(), Box> { eprintln!("[telemt] Fire-and-forget setup"); eprintln!(); - + // 1. Generate or validate secret let secret = match opts.secret { Some(s) => { @@ -98,28 +98,28 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box> { } None => generate_secret(), }; - + eprintln!("[+] Secret: {}", secret); eprintln!("[+] User: {}", opts.username); eprintln!("[+] Port: {}", opts.port); eprintln!("[+] Domain: {}", opts.domain); - + // 2. Create config directory fs::create_dir_all(&opts.config_dir)?; let config_path = opts.config_dir.join("config.toml"); - + // 3. Write config let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain); fs::write(&config_path, &config_content)?; eprintln!("[+] Config written to {}", config_path.display()); - + // 4. Write systemd unit - let exe_path = std::env::current_exe() - .unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt")); - + let exe_path = + std::env::current_exe().unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt")); + let unit_path = Path::new("/etc/systemd/system/telemt.service"); let unit_content = generate_systemd_unit(&exe_path, &config_path); - + match fs::write(unit_path, &unit_content) { Ok(()) => { eprintln!("[+] Systemd unit written to {}", unit_path.display()); @@ -128,31 +128,31 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box> { eprintln!("[!] Cannot write systemd unit (run as root?): {}", e); eprintln!("[!] Manual unit file content:"); eprintln!("{}", unit_content); - + // Still print links and config print_links(&opts.username, &secret, opts.port, &opts.domain); return Ok(()); } } - + // 5. Reload systemd run_cmd("systemctl", &["daemon-reload"]); - + // 6. Enable service run_cmd("systemctl", &["enable", "telemt.service"]); eprintln!("[+] Service enabled"); - + // 7. Start service (unless --no-start) if !opts.no_start { run_cmd("systemctl", &["start", "telemt.service"]); eprintln!("[+] Service started"); - + // Brief delay then check status std::thread::sleep(std::time::Duration::from_secs(1)); let status = Command::new("systemctl") .args(["is-active", "telemt.service"]) .output(); - + match status { Ok(out) if out.status.success() => { eprintln!("[+] Service is running"); @@ -166,12 +166,12 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box> { eprintln!("[+] Service not started (--no-start)"); eprintln!("[+] Start manually: systemctl start telemt.service"); } - + eprintln!(); - + // 8. Print links print_links(&opts.username, &secret, opts.port, &opts.domain); - + Ok(()) } @@ -183,7 +183,7 @@ fn generate_secret() -> String { fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String { format!( -r#"# Telemt MTProxy — auto-generated config + r#"# Telemt MTProxy — auto-generated config # Re-run `telemt --init` to regenerate show_link = ["{username}"] @@ -266,7 +266,7 @@ weight = 10 fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String { format!( -r#"[Unit] + r#"[Unit] Description=Telemt MTProxy Documentation=https://github.com/telemt/telemt After=network-online.target @@ -309,11 +309,13 @@ fn run_cmd(cmd: &str, args: &[&str]) { fn print_links(username: &str, secret: &str, port: u16, domain: &str) { let domain_hex = hex::encode(domain); - + println!("=== Proxy Links ==="); println!("[{}]", username); - println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}", - port, secret, domain_hex); + println!( + " EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}", + port, secret, domain_hex + ); println!(); println!("Replace YOUR_SERVER_IP with your server's public IP."); println!("The proxy will auto-detect and display the correct link on startup."); diff --git a/src/config/defaults.rs b/src/config/defaults.rs index e3d729c..76b9e8b 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; use ipnetwork::IpNetwork; use serde::Deserialize; +use std::collections::HashMap; // Helper defaults kept private to the config module. const DEFAULT_NETWORK_IPV6: Option = Some(false); @@ -143,10 +143,7 @@ pub(crate) fn default_weight() -> u16 { } pub(crate) fn default_metrics_whitelist() -> Vec { - vec![ - "127.0.0.1/32".parse().unwrap(), - "::1/128".parse().unwrap(), - ] + vec!["127.0.0.1/32".parse().unwrap(), "::1/128".parse().unwrap()] } pub(crate) fn default_api_listen() -> String { @@ -169,10 +166,18 @@ pub(crate) fn default_api_minimal_runtime_cache_ttl_ms() -> u64 { 1000 } -pub(crate) fn default_api_runtime_edge_enabled() -> bool { false } -pub(crate) fn default_api_runtime_edge_cache_ttl_ms() -> u64 { 1000 } -pub(crate) fn default_api_runtime_edge_top_n() -> usize { 10 } -pub(crate) fn default_api_runtime_edge_events_capacity() -> usize { 256 } +pub(crate) fn default_api_runtime_edge_enabled() -> bool { + false +} +pub(crate) fn default_api_runtime_edge_cache_ttl_ms() -> u64 { + 1000 +} +pub(crate) fn default_api_runtime_edge_top_n() -> usize { + 10 +} +pub(crate) fn default_api_runtime_edge_events_capacity() -> usize { + 256 +} pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 { 500 diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 10fc976..39c31a1 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -31,11 +31,10 @@ use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher}; use tokio::sync::{mpsc, watch}; use tracing::{error, info, warn}; -use crate::config::{ - LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel, - MeWriterPickMode, -}; use super::load::{LoadedConfig, ProxyConfig}; +use crate::config::{ + LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel, MeWriterPickMode, +}; const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50); @@ -44,16 +43,16 @@ const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50); /// Fields that are safe to swap without restarting listeners. #[derive(Debug, Clone, PartialEq)] pub struct HotFields { - pub log_level: LogLevel, - pub ad_tag: Option, - pub dns_overrides: Vec, - pub desync_all_full: bool, - pub update_every_secs: u64, - pub me_reinit_every_secs: u64, - pub me_reinit_singleflight: bool, + pub log_level: LogLevel, + pub ad_tag: Option, + pub dns_overrides: Vec, + pub desync_all_full: bool, + pub update_every_secs: u64, + pub me_reinit_every_secs: u64, + pub me_reinit_singleflight: bool, pub me_reinit_coalesce_window_ms: u64, - pub hardswap: bool, - pub me_pool_drain_ttl_secs: u64, + pub hardswap: bool, + pub me_pool_drain_ttl_secs: u64, pub me_instadrain: bool, pub me_pool_drain_threshold: u64, pub me_pool_min_fresh_ratio: f32, @@ -113,12 +112,12 @@ pub struct HotFields { pub me_health_interval_ms_healthy: u64, pub me_admission_poll_ms: u64, pub me_warn_rate_limit_ms: u64, - pub users: std::collections::HashMap, - pub user_ad_tags: std::collections::HashMap, - pub user_max_tcp_conns: std::collections::HashMap, - pub user_expirations: std::collections::HashMap>, - pub user_data_quota: std::collections::HashMap, - pub user_max_unique_ips: std::collections::HashMap, + pub users: std::collections::HashMap, + pub user_ad_tags: std::collections::HashMap, + pub user_max_tcp_conns: std::collections::HashMap, + pub user_expirations: std::collections::HashMap>, + pub user_data_quota: std::collections::HashMap, + pub user_max_unique_ips: std::collections::HashMap, pub user_max_unique_ips_global_each: usize, pub user_max_unique_ips_mode: crate::config::UserMaxUniqueIpsMode, pub user_max_unique_ips_window_secs: u64, @@ -127,16 +126,16 @@ pub struct HotFields { impl HotFields { pub fn from_config(cfg: &ProxyConfig) -> Self { Self { - log_level: cfg.general.log_level.clone(), - ad_tag: cfg.general.ad_tag.clone(), - dns_overrides: cfg.network.dns_overrides.clone(), - desync_all_full: cfg.general.desync_all_full, - update_every_secs: cfg.general.effective_update_every_secs(), - me_reinit_every_secs: cfg.general.me_reinit_every_secs, - me_reinit_singleflight: cfg.general.me_reinit_singleflight, + log_level: cfg.general.log_level.clone(), + ad_tag: cfg.general.ad_tag.clone(), + dns_overrides: cfg.network.dns_overrides.clone(), + desync_all_full: cfg.general.desync_all_full, + update_every_secs: cfg.general.effective_update_every_secs(), + me_reinit_every_secs: cfg.general.me_reinit_every_secs, + me_reinit_singleflight: cfg.general.me_reinit_singleflight, me_reinit_coalesce_window_ms: cfg.general.me_reinit_coalesce_window_ms, - hardswap: cfg.general.hardswap, - me_pool_drain_ttl_secs: cfg.general.me_pool_drain_ttl_secs, + hardswap: cfg.general.hardswap, + me_pool_drain_ttl_secs: cfg.general.me_pool_drain_ttl_secs, me_instadrain: cfg.general.me_instadrain, me_pool_drain_threshold: cfg.general.me_pool_drain_threshold, me_pool_min_fresh_ratio: cfg.general.me_pool_min_fresh_ratio, @@ -189,15 +188,11 @@ impl HotFields { me_adaptive_floor_min_writers_multi_endpoint: cfg .general .me_adaptive_floor_min_writers_multi_endpoint, - me_adaptive_floor_recover_grace_secs: cfg - .general - .me_adaptive_floor_recover_grace_secs, + me_adaptive_floor_recover_grace_secs: cfg.general.me_adaptive_floor_recover_grace_secs, me_adaptive_floor_writers_per_core_total: cfg .general .me_adaptive_floor_writers_per_core_total, - me_adaptive_floor_cpu_cores_override: cfg - .general - .me_adaptive_floor_cpu_cores_override, + me_adaptive_floor_cpu_cores_override: cfg.general.me_adaptive_floor_cpu_cores_override, me_adaptive_floor_max_extra_writers_single_per_core: cfg .general .me_adaptive_floor_max_extra_writers_single_per_core, @@ -216,9 +211,15 @@ impl HotFields { me_adaptive_floor_max_warm_writers_global: cfg .general .me_adaptive_floor_max_warm_writers_global, - me_route_backpressure_base_timeout_ms: cfg.general.me_route_backpressure_base_timeout_ms, - me_route_backpressure_high_timeout_ms: cfg.general.me_route_backpressure_high_timeout_ms, - me_route_backpressure_high_watermark_pct: cfg.general.me_route_backpressure_high_watermark_pct, + me_route_backpressure_base_timeout_ms: cfg + .general + .me_route_backpressure_base_timeout_ms, + me_route_backpressure_high_timeout_ms: cfg + .general + .me_route_backpressure_high_timeout_ms, + me_route_backpressure_high_watermark_pct: cfg + .general + .me_route_backpressure_high_watermark_pct, me_reader_route_data_wait_ms: cfg.general.me_reader_route_data_wait_ms, me_d2c_flush_batch_max_frames: cfg.general.me_d2c_flush_batch_max_frames, me_d2c_flush_batch_max_bytes: cfg.general.me_d2c_flush_batch_max_bytes, @@ -230,12 +231,12 @@ impl HotFields { me_health_interval_ms_healthy: cfg.general.me_health_interval_ms_healthy, me_admission_poll_ms: cfg.general.me_admission_poll_ms, me_warn_rate_limit_ms: cfg.general.me_warn_rate_limit_ms, - users: cfg.access.users.clone(), - user_ad_tags: cfg.access.user_ad_tags.clone(), - user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(), - user_expirations: cfg.access.user_expirations.clone(), - user_data_quota: cfg.access.user_data_quota.clone(), - user_max_unique_ips: cfg.access.user_max_unique_ips.clone(), + users: cfg.access.users.clone(), + user_ad_tags: cfg.access.user_ad_tags.clone(), + user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(), + user_expirations: cfg.access.user_expirations.clone(), + user_data_quota: cfg.access.user_data_quota.clone(), + user_max_unique_ips: cfg.access.user_max_unique_ips.clone(), user_max_unique_ips_global_each: cfg.access.user_max_unique_ips_global_each, user_max_unique_ips_mode: cfg.access.user_max_unique_ips_mode, user_max_unique_ips_window_secs: cfg.access.user_max_unique_ips_window_secs, @@ -334,7 +335,9 @@ struct ReloadState { impl ReloadState { fn new(applied_snapshot_hash: Option) -> Self { - Self { applied_snapshot_hash } + Self { + applied_snapshot_hash, + } } fn is_applied(&self, hash: u64) -> bool { @@ -481,10 +484,14 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { new.general.me_adaptive_floor_writers_per_core_total; cfg.general.me_adaptive_floor_cpu_cores_override = new.general.me_adaptive_floor_cpu_cores_override; - cfg.general.me_adaptive_floor_max_extra_writers_single_per_core = - new.general.me_adaptive_floor_max_extra_writers_single_per_core; - cfg.general.me_adaptive_floor_max_extra_writers_multi_per_core = - new.general.me_adaptive_floor_max_extra_writers_multi_per_core; + cfg.general + .me_adaptive_floor_max_extra_writers_single_per_core = new + .general + .me_adaptive_floor_max_extra_writers_single_per_core; + cfg.general + .me_adaptive_floor_max_extra_writers_multi_per_core = new + .general + .me_adaptive_floor_max_extra_writers_multi_per_core; cfg.general.me_adaptive_floor_max_active_writers_per_core = new.general.me_adaptive_floor_max_active_writers_per_core; cfg.general.me_adaptive_floor_max_warm_writers_per_core = @@ -543,8 +550,7 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.server.api.minimal_runtime_cache_ttl_ms != new.server.api.minimal_runtime_cache_ttl_ms || old.server.api.runtime_edge_enabled != new.server.api.runtime_edge_enabled - || old.server.api.runtime_edge_cache_ttl_ms - != new.server.api.runtime_edge_cache_ttl_ms + || old.server.api.runtime_edge_cache_ttl_ms != new.server.api.runtime_edge_cache_ttl_ms || old.server.api.runtime_edge_top_n != new.server.api.runtime_edge_top_n || old.server.api.runtime_edge_events_capacity != new.server.api.runtime_edge_events_capacity @@ -583,10 +589,8 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.censorship.mask_shape_hardening != new.censorship.mask_shape_hardening || old.censorship.mask_shape_bucket_floor_bytes != new.censorship.mask_shape_bucket_floor_bytes - || old.censorship.mask_shape_bucket_cap_bytes - != new.censorship.mask_shape_bucket_cap_bytes - || old.censorship.mask_shape_above_cap_blur - != new.censorship.mask_shape_above_cap_blur + || old.censorship.mask_shape_bucket_cap_bytes != new.censorship.mask_shape_bucket_cap_bytes + || old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur || old.censorship.mask_shape_above_cap_blur_max_bytes != new.censorship.mask_shape_above_cap_blur_max_bytes || old.censorship.mask_timing_normalization_enabled @@ -870,8 +874,7 @@ fn log_changes( { info!( "config reload: me_bind_stale: mode={:?} ttl={}s", - new_hot.me_bind_stale_mode, - new_hot.me_bind_stale_ttl_secs + new_hot.me_bind_stale_mode, new_hot.me_bind_stale_ttl_secs ); } if old_hot.me_secret_atomic_snapshot != new_hot.me_secret_atomic_snapshot @@ -951,8 +954,7 @@ fn log_changes( if old_hot.me_socks_kdf_policy != new_hot.me_socks_kdf_policy { info!( "config reload: me_socks_kdf_policy: {:?} → {:?}", - old_hot.me_socks_kdf_policy, - new_hot.me_socks_kdf_policy, + old_hot.me_socks_kdf_policy, new_hot.me_socks_kdf_policy, ); } @@ -1006,8 +1008,7 @@ fn log_changes( || old_hot.me_route_backpressure_high_watermark_pct != new_hot.me_route_backpressure_high_watermark_pct || old_hot.me_reader_route_data_wait_ms != new_hot.me_reader_route_data_wait_ms - || old_hot.me_health_interval_ms_unhealthy - != new_hot.me_health_interval_ms_unhealthy + || old_hot.me_health_interval_ms_unhealthy != new_hot.me_health_interval_ms_unhealthy || old_hot.me_health_interval_ms_healthy != new_hot.me_health_interval_ms_healthy || old_hot.me_admission_poll_ms != new_hot.me_admission_poll_ms || old_hot.me_warn_rate_limit_ms != new_hot.me_warn_rate_limit_ms @@ -1044,19 +1045,27 @@ fn log_changes( } if old_hot.users != new_hot.users { - let mut added: Vec<&String> = new_hot.users.keys() + let mut added: Vec<&String> = new_hot + .users + .keys() .filter(|u| !old_hot.users.contains_key(*u)) .collect(); added.sort(); - let mut removed: Vec<&String> = old_hot.users.keys() + let mut removed: Vec<&String> = old_hot + .users + .keys() .filter(|u| !new_hot.users.contains_key(*u)) .collect(); removed.sort(); - let mut changed: Vec<&String> = new_hot.users.keys() + let mut changed: Vec<&String> = new_hot + .users + .keys() .filter(|u| { - old_hot.users.get(*u) + old_hot + .users + .get(*u) .map(|s| s != &new_hot.users[*u]) .unwrap_or(false) }) @@ -1066,10 +1075,18 @@ fn log_changes( if !added.is_empty() { info!( "config reload: users added: [{}]", - added.iter().map(|s| s.as_str()).collect::>().join(", ") + added + .iter() + .map(|s| s.as_str()) + .collect::>() + .join(", ") ); let host = resolve_link_host(new_cfg, detected_ip_v4, detected_ip_v6); - let port = new_cfg.general.links.public_port.unwrap_or(new_cfg.server.port); + let port = new_cfg + .general + .links + .public_port + .unwrap_or(new_cfg.server.port); for user in &added { if let Some(secret) = new_hot.users.get(*user) { print_user_links(user, secret, &host, port, new_cfg); @@ -1079,13 +1096,21 @@ fn log_changes( if !removed.is_empty() { info!( "config reload: users removed: [{}]", - removed.iter().map(|s| s.as_str()).collect::>().join(", ") + removed + .iter() + .map(|s| s.as_str()) + .collect::>() + .join(", ") ); } if !changed.is_empty() { info!( "config reload: users secret changed: [{}]", - changed.iter().map(|s| s.as_str()).collect::>().join(", ") + changed + .iter() + .map(|s| s.as_str()) + .collect::>() + .join(", ") ); } } @@ -1116,8 +1141,7 @@ fn log_changes( } if old_hot.user_max_unique_ips_global_each != new_hot.user_max_unique_ips_global_each || old_hot.user_max_unique_ips_mode != new_hot.user_max_unique_ips_mode - || old_hot.user_max_unique_ips_window_secs - != new_hot.user_max_unique_ips_window_secs + || old_hot.user_max_unique_ips_window_secs != new_hot.user_max_unique_ips_window_secs { info!( "config reload: user_max_unique_ips policy global_each={} mode={:?} window={}s", @@ -1152,7 +1176,10 @@ fn reload_config( let next_manifest = WatchManifest::from_source_files(&source_files); if let Err(e) = new_cfg.validate() { - error!("config reload: validation failed: {}; keeping old config", e); + error!( + "config reload: validation failed: {}; keeping old config", + e + ); return Some(next_manifest); } @@ -1217,7 +1244,7 @@ pub fn spawn_config_watcher( ) -> (watch::Receiver>, watch::Receiver) { let initial_level = initial.general.log_level.clone(); let (config_tx, config_rx) = watch::channel(initial); - let (log_tx, log_rx) = watch::channel(initial_level); + let (log_tx, log_rx) = watch::channel(initial_level); let config_path = normalize_watch_path(&config_path); let initial_loaded = ProxyConfig::load_with_metadata(&config_path).ok(); @@ -1234,25 +1261,29 @@ pub fn spawn_config_watcher( let tx_inotify = notify_tx.clone(); let manifest_for_inotify = manifest_state.clone(); - let mut inotify_watcher = match recommended_watcher(move |res: notify::Result| { - let Ok(event) = res else { return }; - if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) { - return; - } - let is_our_file = manifest_for_inotify - .read() - .map(|manifest| manifest.matches_event_paths(&event.paths)) - .unwrap_or(false); - if is_our_file { - let _ = tx_inotify.try_send(()); - } - }) { - Ok(watcher) => Some(watcher), - Err(e) => { - warn!("config watcher: inotify unavailable: {}", e); - None - } - }; + let mut inotify_watcher = + match recommended_watcher(move |res: notify::Result| { + let Ok(event) = res else { return }; + if !matches!( + event.kind, + EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) + ) { + return; + } + let is_our_file = manifest_for_inotify + .read() + .map(|manifest| manifest.matches_event_paths(&event.paths)) + .unwrap_or(false); + if is_our_file { + let _ = tx_inotify.try_send(()); + } + }) { + Ok(watcher) => Some(watcher), + Err(e) => { + warn!("config watcher: inotify unavailable: {}", e); + None + } + }; apply_watch_manifest( inotify_watcher.as_mut(), Option::<&mut notify::poll::PollWatcher>::None, @@ -1268,7 +1299,10 @@ pub fn spawn_config_watcher( let mut poll_watcher = match notify::poll::PollWatcher::new( move |res: notify::Result| { let Ok(event) = res else { return }; - if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) { + if !matches!( + event.kind, + EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) + ) { return; } let is_our_file = manifest_for_poll @@ -1316,7 +1350,9 @@ pub fn spawn_config_watcher( } } #[cfg(not(unix))] - if notify_rx.recv().await.is_none() { break; } + if notify_rx.recv().await.is_none() { + break; + } // Debounce: drain extra events that arrive within a short quiet window. tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; @@ -1418,7 +1454,10 @@ mod tests { new.server.port = old.server.port.saturating_add(1); let applied = overlay_hot_fields(&old, &new); - assert_eq!(HotFields::from_config(&old), HotFields::from_config(&applied)); + assert_eq!( + HotFields::from_config(&old), + HotFields::from_config(&applied) + ); assert_eq!(applied.server.port, old.server.port); } @@ -1437,7 +1476,10 @@ mod tests { applied.general.me_bind_stale_mode, new.general.me_bind_stale_mode ); - assert_ne!(HotFields::from_config(&old), HotFields::from_config(&applied)); + assert_ne!( + HotFields::from_config(&old), + HotFields::from_config(&applied) + ); } #[test] @@ -1451,7 +1493,10 @@ mod tests { applied.general.me_keepalive_interval_secs, old.general.me_keepalive_interval_secs ); - assert_eq!(HotFields::from_config(&old), HotFields::from_config(&applied)); + assert_eq!( + HotFields::from_config(&old), + HotFields::from_config(&applied) + ); } #[test] @@ -1463,7 +1508,10 @@ mod tests { let applied = overlay_hot_fields(&old, &new); assert_eq!(applied.general.hardswap, new.general.hardswap); - assert_eq!(applied.general.use_middle_proxy, old.general.use_middle_proxy); + assert_eq!( + applied.general.use_middle_proxy, + old.general.use_middle_proxy + ); assert!(!config_equal(&applied, &new)); } @@ -1475,14 +1523,19 @@ mod tests { write_reload_config(&path, Some(initial_tag), None); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); - let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; + let initial_hash = ProxyConfig::load_with_metadata(&path) + .unwrap() + .rendered_hash; let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let mut reload_state = ReloadState::new(Some(initial_hash)); write_reload_config(&path, Some(final_tag), None); reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); + assert_eq!( + config_tx.borrow().general.ad_tag.as_deref(), + Some(final_tag) + ); let _ = std::fs::remove_file(path); } @@ -1495,7 +1548,9 @@ mod tests { write_reload_config(&path, Some(initial_tag), None); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); - let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; + let initial_hash = ProxyConfig::load_with_metadata(&path) + .unwrap() + .rendered_hash; let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let mut reload_state = ReloadState::new(Some(initial_hash)); @@ -1518,7 +1573,9 @@ mod tests { write_reload_config(&path, Some(initial_tag), None); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); - let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; + let initial_hash = ProxyConfig::load_with_metadata(&path) + .unwrap() + .rendered_hash; let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let mut reload_state = ReloadState::new(Some(initial_hash)); @@ -1532,7 +1589,10 @@ mod tests { write_reload_config(&path, Some(final_tag), None); reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); + assert_eq!( + config_tx.borrow().general.ad_tag.as_deref(), + Some(final_tag) + ); let _ = std::fs::remove_file(path); } diff --git a/src/config/load.rs b/src/config/load.rs index 2c50f4e..30f1707 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -399,9 +399,7 @@ impl ProxyConfig { )); } - if config.censorship.mask_shape_above_cap_blur - && !config.censorship.mask_shape_hardening - { + if config.censorship.mask_shape_above_cap_blur && !config.censorship.mask_shape_hardening { return Err(ProxyError::Config( "censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true" .to_string(), @@ -419,8 +417,7 @@ impl ProxyConfig { if config.censorship.mask_shape_above_cap_blur_max_bytes > 1_048_576 { return Err(ProxyError::Config( - "censorship.mask_shape_above_cap_blur_max_bytes must be <= 1048576" - .to_string(), + "censorship.mask_shape_above_cap_blur_max_bytes must be <= 1048576".to_string(), )); } @@ -444,8 +441,7 @@ impl ProxyConfig { if config.censorship.mask_timing_normalization_ceiling_ms > 60_000 { return Err(ProxyError::Config( - "censorship.mask_timing_normalization_ceiling_ms must be <= 60000" - .to_string(), + "censorship.mask_timing_normalization_ceiling_ms must be <= 60000".to_string(), )); } @@ -461,8 +457,7 @@ impl ProxyConfig { )); } - if config.timeouts.relay_client_idle_hard_secs - < config.timeouts.relay_client_idle_soft_secs + if config.timeouts.relay_client_idle_hard_secs < config.timeouts.relay_client_idle_soft_secs { return Err(ProxyError::Config( "timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs" @@ -470,7 +465,9 @@ impl ProxyConfig { )); } - if config.timeouts.relay_idle_grace_after_downstream_activity_secs + if config + .timeouts + .relay_idle_grace_after_downstream_activity_secs > config.timeouts.relay_client_idle_hard_secs { return Err(ProxyError::Config( @@ -767,7 +764,8 @@ impl ProxyConfig { } if config.general.me_route_backpressure_base_timeout_ms > 5000 { return Err(ProxyError::Config( - "general.me_route_backpressure_base_timeout_ms must be within [1, 5000]".to_string(), + "general.me_route_backpressure_base_timeout_ms must be within [1, 5000]" + .to_string(), )); } @@ -780,7 +778,8 @@ impl ProxyConfig { } if config.general.me_route_backpressure_high_timeout_ms > 5000 { return Err(ProxyError::Config( - "general.me_route_backpressure_high_timeout_ms must be within [1, 5000]".to_string(), + "general.me_route_backpressure_high_timeout_ms must be within [1, 5000]" + .to_string(), )); } @@ -1828,7 +1827,9 @@ mod tests { let path = dir.join("telemt_me_route_backpressure_base_timeout_ms_out_of_range_test.toml"); std::fs::write(&path, toml).unwrap(); let err = ProxyConfig::load(&path).unwrap_err().to_string(); - assert!(err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]")); + assert!( + err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]") + ); let _ = std::fs::remove_file(path); } @@ -1849,7 +1850,9 @@ mod tests { let path = dir.join("telemt_me_route_backpressure_high_timeout_ms_out_of_range_test.toml"); std::fs::write(&path, toml).unwrap(); let err = ProxyConfig::load(&path).unwrap_err().to_string(); - assert!(err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]")); + assert!( + err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]") + ); let _ = std::fs::remove_file(path); } diff --git a/src/config/mod.rs b/src/config/mod.rs index c7187ad..dcb3bec 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,9 +1,9 @@ //! Configuration. pub(crate) mod defaults; -mod types; -mod load; pub mod hot_reload; +mod load; +mod types; pub use load::ProxyConfig; pub use types::*; diff --git a/src/config/tests/load_idle_policy_tests.rs b/src/config/tests/load_idle_policy_tests.rs index 087fd75..c6a4e86 100644 --- a/src/config/tests/load_idle_policy_tests.rs +++ b/src/config/tests/load_idle_policy_tests.rs @@ -30,7 +30,9 @@ relay_client_idle_hard_secs = 60 let err = ProxyConfig::load(&path).expect_err("config with hard= timeouts.relay_client_idle_soft_secs"), + msg.contains( + "timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs" + ), "error must explain the violated hard>=soft invariant, got: {msg}" ); diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs index 41df0f5..736fe05 100644 --- a/src/config/tests/load_mask_shape_security_tests.rs +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -91,11 +91,13 @@ mask_shape_above_cap_blur_max_bytes = 64 "#, ); - let err = ProxyConfig::load(&path) - .expect_err("above-cap blur must require shape hardening enabled"); + let err = + ProxyConfig::load(&path).expect_err("above-cap blur must require shape hardening enabled"); let msg = err.to_string(); assert!( - msg.contains("censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"), + msg.contains( + "censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true" + ), "error must explain blur prerequisite, got: {msg}" ); @@ -113,8 +115,8 @@ mask_shape_above_cap_blur_max_bytes = 0 "#, ); - let err = ProxyConfig::load(&path) - .expect_err("above-cap blur max bytes must be > 0 when enabled"); + let err = + ProxyConfig::load(&path).expect_err("above-cap blur max bytes must be > 0 when enabled"); let msg = err.to_string(); assert!( msg.contains("censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled"), @@ -135,8 +137,8 @@ mask_timing_normalization_ceiling_ms = 200 "#, ); - let err = ProxyConfig::load(&path) - .expect_err("timing normalization floor must be > 0 when enabled"); + let err = + ProxyConfig::load(&path).expect_err("timing normalization floor must be > 0 when enabled"); let msg = err.to_string(); assert!( msg.contains("censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true"), @@ -157,8 +159,7 @@ mask_timing_normalization_ceiling_ms = 200 "#, ); - let err = ProxyConfig::load(&path) - .expect_err("timing normalization ceiling must be >= floor"); + let err = ProxyConfig::load(&path).expect_err("timing normalization ceiling must be >= floor"); let msg = err.to_string(); assert!( msg.contains("censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms"), diff --git a/src/config/tests/load_security_tests.rs b/src/config/tests/load_security_tests.rs index a1a35ac..654a9c0 100644 --- a/src/config/tests/load_security_tests.rs +++ b/src/config/tests/load_security_tests.rs @@ -29,11 +29,13 @@ server_hello_delay_max_ms = 1000 "#, ); - let err = ProxyConfig::load(&path) - .expect_err("delay equal to handshake timeout must be rejected"); + let err = + ProxyConfig::load(&path).expect_err("delay equal to handshake timeout must be rejected"); let msg = err.to_string(); assert!( - msg.contains("censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"), + msg.contains( + "censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000" + ), "error must explain delay; @@ -42,33 +45,39 @@ impl AesCtr { cipher: Aes256Ctr::new(key.into(), (&iv_bytes).into()), } } - + /// Create from key and IV slices pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result { if key.len() != 32 { - return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); + return Err(ProxyError::InvalidKeyLength { + expected: 32, + got: key.len(), + }); } if iv.len() != 16 { - return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() }); + return Err(ProxyError::InvalidKeyLength { + expected: 16, + got: iv.len(), + }); } - + let key: [u8; 32] = key.try_into().unwrap(); let iv = u128::from_be_bytes(iv.try_into().unwrap()); Ok(Self::new(&key, iv)) } - + /// Encrypt/decrypt data in-place (CTR mode is symmetric) pub fn apply(&mut self, data: &mut [u8]) { self.cipher.apply_keystream(data); } - + /// Encrypt data, returning new buffer pub fn encrypt(&mut self, data: &[u8]) -> Vec { let mut output = data.to_vec(); self.apply(&mut output); output } - + /// Decrypt data (for CTR, identical to encrypt) pub fn decrypt(&mut self, data: &[u8]) -> Vec { self.encrypt(data) @@ -99,27 +108,33 @@ impl Drop for AesCbc { impl AesCbc { /// AES block size const BLOCK_SIZE: usize = 16; - + /// Create new AES-CBC cipher with key and IV pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self { Self { key, iv } } - + /// Create from slices pub fn from_slices(key: &[u8], iv: &[u8]) -> Result { if key.len() != 32 { - return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); + return Err(ProxyError::InvalidKeyLength { + expected: 32, + got: key.len(), + }); } if iv.len() != 16 { - return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() }); + return Err(ProxyError::InvalidKeyLength { + expected: 16, + got: iv.len(), + }); } - + Ok(Self { key: key.try_into().unwrap(), iv: iv.try_into().unwrap(), }) } - + /// Encrypt a single block using raw AES (no chaining) fn encrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] { use aes::cipher::BlockEncrypt; @@ -127,7 +142,7 @@ impl AesCbc { key_schedule.encrypt_block((&mut output).into()); output } - + /// Decrypt a single block using raw AES (no chaining) fn decrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] { use aes::cipher::BlockDecrypt; @@ -135,7 +150,7 @@ impl AesCbc { key_schedule.decrypt_block((&mut output).into()); output } - + /// XOR two 16-byte blocks fn xor_blocks(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] { let mut result = [0u8; 16]; @@ -144,27 +159,28 @@ impl AesCbc { } result } - + /// Encrypt data using CBC mode with proper chaining /// /// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV pub fn encrypt(&self, data: &[u8]) -> Result> { if !data.len().is_multiple_of(Self::BLOCK_SIZE) { - return Err(ProxyError::Crypto( - format!("CBC data must be aligned to 16 bytes, got {}", data.len()) - )); + return Err(ProxyError::Crypto(format!( + "CBC data must be aligned to 16 bytes, got {}", + data.len() + ))); } - + if data.is_empty() { return Ok(Vec::new()); } - + use aes::cipher::KeyInit; let key_schedule = aes::Aes256::new((&self.key).into()); - + let mut result = Vec::with_capacity(data.len()); let mut prev_ciphertext = self.iv; - + for chunk in data.chunks(Self::BLOCK_SIZE) { let plaintext: [u8; 16] = chunk.try_into().unwrap(); let xored = Self::xor_blocks(&plaintext, &prev_ciphertext); @@ -172,30 +188,31 @@ impl AesCbc { prev_ciphertext = ciphertext; result.extend_from_slice(&ciphertext); } - + Ok(result) } - + /// Decrypt data using CBC mode with proper chaining /// /// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV pub fn decrypt(&self, data: &[u8]) -> Result> { if !data.len().is_multiple_of(Self::BLOCK_SIZE) { - return Err(ProxyError::Crypto( - format!("CBC data must be aligned to 16 bytes, got {}", data.len()) - )); + return Err(ProxyError::Crypto(format!( + "CBC data must be aligned to 16 bytes, got {}", + data.len() + ))); } - + if data.is_empty() { return Ok(Vec::new()); } - + use aes::cipher::KeyInit; let key_schedule = aes::Aes256::new((&self.key).into()); - + let mut result = Vec::with_capacity(data.len()); let mut prev_ciphertext = self.iv; - + for chunk in data.chunks(Self::BLOCK_SIZE) { let ciphertext: [u8; 16] = chunk.try_into().unwrap(); let decrypted = self.decrypt_block(&ciphertext, &key_schedule); @@ -203,75 +220,77 @@ impl AesCbc { prev_ciphertext = ciphertext; result.extend_from_slice(&plaintext); } - + Ok(result) } - + /// Encrypt data in-place pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> { if !data.len().is_multiple_of(Self::BLOCK_SIZE) { - return Err(ProxyError::Crypto( - format!("CBC data must be aligned to 16 bytes, got {}", data.len()) - )); + return Err(ProxyError::Crypto(format!( + "CBC data must be aligned to 16 bytes, got {}", + data.len() + ))); } - + if data.is_empty() { return Ok(()); } - + use aes::cipher::KeyInit; let key_schedule = aes::Aes256::new((&self.key).into()); - + let mut prev_ciphertext = self.iv; - + for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { let block = &mut data[i..i + Self::BLOCK_SIZE]; - + for j in 0..Self::BLOCK_SIZE { block[j] ^= prev_ciphertext[j]; } - + let block_array: &mut [u8; 16] = block.try_into().unwrap(); *block_array = self.encrypt_block(block_array, &key_schedule); - + prev_ciphertext = *block_array; } - + Ok(()) } - + /// Decrypt data in-place pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> { if !data.len().is_multiple_of(Self::BLOCK_SIZE) { - return Err(ProxyError::Crypto( - format!("CBC data must be aligned to 16 bytes, got {}", data.len()) - )); + return Err(ProxyError::Crypto(format!( + "CBC data must be aligned to 16 bytes, got {}", + data.len() + ))); } - + if data.is_empty() { return Ok(()); } - + use aes::cipher::KeyInit; let key_schedule = aes::Aes256::new((&self.key).into()); - + let mut prev_ciphertext = self.iv; - + for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { let block = &mut data[i..i + Self::BLOCK_SIZE]; - + let current_ciphertext: [u8; 16] = block.try_into().unwrap(); - + let block_array: &mut [u8; 16] = block.try_into().unwrap(); *block_array = self.decrypt_block(block_array, &key_schedule); - + for j in 0..Self::BLOCK_SIZE { block[j] ^= prev_ciphertext[j]; } - + prev_ciphertext = current_ciphertext; } - + Ok(()) } } @@ -318,227 +337,227 @@ impl Decryptor for PassthroughEncryptor { #[cfg(test)] mod tests { use super::*; - + // ============= AES-CTR Tests ============= - + #[test] fn test_aes_ctr_roundtrip() { let key = [0u8; 32]; let iv = 12345u128; - + let original = b"Hello, MTProto!"; - + let mut enc = AesCtr::new(&key, iv); let encrypted = enc.encrypt(original); - + let mut dec = AesCtr::new(&key, iv); let decrypted = dec.decrypt(&encrypted); - + assert_eq!(original.as_slice(), decrypted.as_slice()); } - + #[test] fn test_aes_ctr_in_place() { let key = [0x42u8; 32]; let iv = 999u128; - + let original = b"Test data for in-place encryption"; let mut data = original.to_vec(); - + let mut cipher = AesCtr::new(&key, iv); cipher.apply(&mut data); - + assert_ne!(&data[..], original); - + let mut cipher = AesCtr::new(&key, iv); cipher.apply(&mut data); - + assert_eq!(&data[..], original); } - + // ============= AES-CBC Tests ============= - + #[test] fn test_aes_cbc_roundtrip() { let key = [0u8; 32]; let iv = [0u8; 16]; - + let original = [0u8; 32]; - + let cipher = AesCbc::new(key, iv); let encrypted = cipher.encrypt(&original).unwrap(); let decrypted = cipher.decrypt(&encrypted).unwrap(); - + assert_eq!(original.as_slice(), decrypted.as_slice()); } - + #[test] fn test_aes_cbc_chaining_works() { let key = [0x42u8; 32]; let iv = [0x00u8; 16]; - + let plaintext = [0xAAu8; 32]; - + let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); - + let block1 = &ciphertext[0..16]; let block2 = &ciphertext[16..32]; - + assert_ne!( block1, block2, "CBC chaining broken: identical plaintext blocks produced identical ciphertext" ); } - + #[test] fn test_aes_cbc_known_vector() { let key = [0u8; 32]; let iv = [0u8; 16]; let plaintext = [0u8; 16]; - + let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); - + let decrypted = cipher.decrypt(&ciphertext).unwrap(); assert_eq!(plaintext.as_slice(), decrypted.as_slice()); - + assert_ne!(ciphertext.as_slice(), plaintext.as_slice()); } - + #[test] fn test_aes_cbc_multi_block() { let key = [0x12u8; 32]; let iv = [0x34u8; 16]; - + let plaintext: Vec = (0..80).collect(); - + let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); let decrypted = cipher.decrypt(&ciphertext).unwrap(); - + assert_eq!(plaintext, decrypted); } - + #[test] fn test_aes_cbc_in_place() { let key = [0x12u8; 32]; let iv = [0x34u8; 16]; - + let original = [0x56u8; 48]; let mut buffer = original; - + let cipher = AesCbc::new(key, iv); - + cipher.encrypt_in_place(&mut buffer).unwrap(); assert_ne!(&buffer[..], &original[..]); - + cipher.decrypt_in_place(&mut buffer).unwrap(); assert_eq!(&buffer[..], &original[..]); } - + #[test] fn test_aes_cbc_empty_data() { let cipher = AesCbc::new([0u8; 32], [0u8; 16]); - + let encrypted = cipher.encrypt(&[]).unwrap(); assert!(encrypted.is_empty()); - + let decrypted = cipher.decrypt(&[]).unwrap(); assert!(decrypted.is_empty()); } - + #[test] fn test_aes_cbc_unaligned_error() { let cipher = AesCbc::new([0u8; 32], [0u8; 16]); - + let result = cipher.encrypt(&[0u8; 15]); assert!(result.is_err()); - + let result = cipher.encrypt(&[0u8; 17]); assert!(result.is_err()); } - + #[test] fn test_aes_cbc_avalanche_effect() { let key = [0xAB; 32]; let iv = [0xCD; 16]; - + let plaintext1 = [0u8; 32]; let mut plaintext2 = [0u8; 32]; plaintext2[0] = 0x01; - + let cipher = AesCbc::new(key, iv); - + let ciphertext1 = cipher.encrypt(&plaintext1).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext2).unwrap(); - + assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]); assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]); } - + #[test] fn test_aes_cbc_iv_matters() { let key = [0x55; 32]; let plaintext = [0x77u8; 16]; - + let cipher1 = AesCbc::new(key, [0u8; 16]); let cipher2 = AesCbc::new(key, [1u8; 16]); - + let ciphertext1 = cipher1.encrypt(&plaintext).unwrap(); let ciphertext2 = cipher2.encrypt(&plaintext).unwrap(); - + assert_ne!(ciphertext1, ciphertext2); } - + #[test] fn test_aes_cbc_deterministic() { let key = [0x99; 32]; let iv = [0x88; 16]; let plaintext = [0x77u8; 32]; - + let cipher = AesCbc::new(key, iv); - + let ciphertext1 = cipher.encrypt(&plaintext).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext).unwrap(); - + assert_eq!(ciphertext1, ciphertext2); } - + // ============= Zeroize Tests ============= - + #[test] fn test_aes_cbc_zeroize_on_drop() { let key = [0xAA; 32]; let iv = [0xBB; 16]; - + let cipher = AesCbc::new(key, iv); // Verify key/iv are set assert_eq!(cipher.key, [0xAA; 32]); assert_eq!(cipher.iv, [0xBB; 16]); - + drop(cipher); // After drop, key/iv are zeroized (can't observe directly, // but the Drop impl runs without panic) } - + // ============= Error Handling Tests ============= - + #[test] fn test_invalid_key_length() { let result = AesCtr::from_key_iv(&[0u8; 16], &[0u8; 16]); assert!(result.is_err()); - + let result = AesCbc::from_slices(&[0u8; 16], &[0u8; 16]); assert!(result.is_err()); } - + #[test] fn test_invalid_iv_length() { let result = AesCtr::from_key_iv(&[0u8; 32], &[0u8; 8]); assert!(result.is_err()); - + let result = AesCbc::from_slices(&[0u8; 32], &[0u8; 8]); assert!(result.is_err()); } -} \ No newline at end of file +} diff --git a/src/crypto/hash.rs b/src/crypto/hash.rs index fa3e441..9e1fa16 100644 --- a/src/crypto/hash.rs +++ b/src/crypto/hash.rs @@ -12,10 +12,10 @@ //! usages are intentional and protocol-mandated. use hmac::{Hmac, Mac}; -use sha2::Sha256; use md5::Md5; use sha1::Sha1; use sha2::Digest; +use sha2::Sha256; type HmacSha256 = Hmac; @@ -28,8 +28,7 @@ pub fn sha256(data: &[u8]) -> [u8; 32] { /// SHA-256 HMAC pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] { - let mut mac = HmacSha256::new_from_slice(key) - .expect("HMAC accepts any key length"); + let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length"); mac.update(data); mac.finalize().into_bytes().into() } @@ -124,27 +123,18 @@ pub fn derive_middleproxy_keys( srv_ipv6: Option<&[u8; 16]>, ) -> ([u8; 32], [u8; 16]) { let s = build_middleproxy_prekey( - nonce_srv, - nonce_clt, - clt_ts, - srv_ip, - clt_port, - purpose, - clt_ip, - srv_port, - secret, - clt_ipv6, - srv_ipv6, + nonce_srv, nonce_clt, clt_ts, srv_ip, clt_port, purpose, clt_ip, srv_port, secret, + clt_ipv6, srv_ipv6, ); let md5_1 = md5(&s[1..]); let sha1_sum = sha1(&s); let md5_2 = md5(&s[2..]); - + let mut key = [0u8; 32]; key[..12].copy_from_slice(&md5_1[..12]); key[12..].copy_from_slice(&sha1_sum); - + (key, md5_2) } @@ -164,17 +154,8 @@ mod tests { let secret = vec![0x55u8; 128]; let prekey = build_middleproxy_prekey( - &nonce_srv, - &nonce_clt, - &clt_ts, - srv_ip, - &clt_port, - b"CLIENT", - clt_ip, - &srv_port, - &secret, - None, - None, + &nonce_srv, &nonce_clt, &clt_ts, srv_ip, &clt_port, b"CLIENT", clt_ip, &srv_port, + &secret, None, None, ); let digest = sha256(&prekey); assert_eq!( diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 9108f34..cf2dcd2 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -4,7 +4,7 @@ pub mod aes; pub mod hash; pub mod random; -pub use aes::{AesCtr, AesCbc}; +pub use aes::{AesCbc, AesCtr}; pub use hash::{ build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, sha256, sha256_hmac, }; diff --git a/src/crypto/random.rs b/src/crypto/random.rs index 2f52188..760f120 100644 --- a/src/crypto/random.rs +++ b/src/crypto/random.rs @@ -3,11 +3,11 @@ #![allow(deprecated)] #![allow(dead_code)] -use rand::{Rng, RngExt, SeedableRng}; -use rand::rngs::StdRng; -use parking_lot::Mutex; -use zeroize::Zeroize; use crate::crypto::AesCtr; +use parking_lot::Mutex; +use rand::rngs::StdRng; +use rand::{Rng, RngExt, SeedableRng}; +use zeroize::Zeroize; /// Cryptographically secure PRNG with AES-CTR pub struct SecureRandom { @@ -34,16 +34,16 @@ impl SecureRandom { pub fn new() -> Self { let mut seed_source = rand::rng(); let mut rng = StdRng::from_rng(&mut seed_source); - + let mut key = [0u8; 32]; rng.fill_bytes(&mut key); let iv: u128 = rng.random(); - + let cipher = AesCtr::new(&key, iv); - + // Zeroize local key copy — cipher already consumed it key.zeroize(); - + Self { inner: Mutex::new(SecureRandomInner { rng, @@ -53,7 +53,7 @@ impl SecureRandom { }), } } - + /// Fill a caller-provided buffer with random bytes. pub fn fill(&self, out: &mut [u8]) { let mut inner = self.inner.lock(); @@ -94,7 +94,7 @@ impl SecureRandom { self.fill(&mut out); out } - + /// Generate random number in range [0, max) pub fn range(&self, max: usize) -> usize { if max == 0 { @@ -103,16 +103,16 @@ impl SecureRandom { let mut inner = self.inner.lock(); inner.rng.random_range(0..max) } - + /// Generate random bits pub fn bits(&self, k: usize) -> u64 { if k == 0 { return 0; } - + let bytes_needed = k.div_ceil(8); let bytes = self.bytes(bytes_needed.min(8)); - + let mut result = 0u64; for (i, &b) in bytes.iter().enumerate() { if i >= 8 { @@ -120,14 +120,14 @@ impl SecureRandom { } result |= (b as u64) << (i * 8); } - + if k < 64 { result &= (1u64 << k) - 1; } - + result } - + /// Choose random element from slice pub fn choose<'a, T>(&self, slice: &'a [T]) -> Option<&'a T> { if slice.is_empty() { @@ -136,7 +136,7 @@ impl SecureRandom { Some(&slice[self.range(slice.len())]) } } - + /// Shuffle slice in place pub fn shuffle(&self, slice: &mut [T]) { let mut inner = self.inner.lock(); @@ -145,13 +145,13 @@ impl SecureRandom { slice.swap(i, j); } } - + /// Generate random u32 pub fn u32(&self) -> u32 { let mut inner = self.inner.lock(); inner.rng.random() } - + /// Generate random u64 pub fn u64(&self) -> u64 { let mut inner = self.inner.lock(); @@ -169,7 +169,7 @@ impl Default for SecureRandom { mod tests { use super::*; use std::collections::HashSet; - + #[test] fn test_bytes_uniqueness() { let rng = SecureRandom::new(); @@ -177,7 +177,7 @@ mod tests { let b = rng.bytes(32); assert_ne!(a, b); } - + #[test] fn test_bytes_length() { let rng = SecureRandom::new(); @@ -186,63 +186,63 @@ mod tests { assert_eq!(rng.bytes(100).len(), 100); assert_eq!(rng.bytes(1000).len(), 1000); } - + #[test] fn test_range() { let rng = SecureRandom::new(); - + for _ in 0..1000 { let n = rng.range(10); assert!(n < 10); } - + assert_eq!(rng.range(1), 0); assert_eq!(rng.range(0), 0); } - + #[test] fn test_bits() { let rng = SecureRandom::new(); - + for _ in 0..100 { assert!(rng.bits(1) <= 1); } - + for _ in 0..100 { assert!(rng.bits(8) <= 255); } } - + #[test] fn test_choose() { let rng = SecureRandom::new(); let items = vec![1, 2, 3, 4, 5]; - + let mut seen = HashSet::new(); for _ in 0..1000 { if let Some(&item) = rng.choose(&items) { seen.insert(item); } } - + assert_eq!(seen.len(), 5); - + let empty: Vec = vec![]; assert!(rng.choose(&empty).is_none()); } - + #[test] fn test_shuffle() { let rng = SecureRandom::new(); let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; - + let mut shuffled = original.clone(); rng.shuffle(&mut shuffled); - + let mut sorted = shuffled.clone(); sorted.sort(); assert_eq!(sorted, original); - + assert_ne!(shuffled, original); } } diff --git a/src/error.rs b/src/error.rs index e4d66b9..d9aeb22 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,28 +12,15 @@ use thiserror::Error; #[derive(Debug)] pub enum StreamError { /// Partial read: got fewer bytes than expected - PartialRead { - expected: usize, - got: usize, - }, + PartialRead { expected: usize, got: usize }, /// Partial write: wrote fewer bytes than expected - PartialWrite { - expected: usize, - written: usize, - }, + PartialWrite { expected: usize, written: usize }, /// Stream is in poisoned state and cannot be used - Poisoned { - reason: String, - }, + Poisoned { reason: String }, /// Buffer overflow: attempted to buffer more than allowed - BufferOverflow { - limit: usize, - attempted: usize, - }, + BufferOverflow { limit: usize, attempted: usize }, /// Invalid frame format - InvalidFrame { - details: String, - }, + InvalidFrame { details: String }, /// Unexpected end of stream UnexpectedEof, /// Underlying I/O error @@ -47,13 +34,21 @@ impl fmt::Display for StreamError { write!(f, "partial read: expected {} bytes, got {}", expected, got) } Self::PartialWrite { expected, written } => { - write!(f, "partial write: expected {} bytes, wrote {}", expected, written) + write!( + f, + "partial write: expected {} bytes, wrote {}", + expected, written + ) } Self::Poisoned { reason } => { write!(f, "stream poisoned: {}", reason) } Self::BufferOverflow { limit, attempted } => { - write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted) + write!( + f, + "buffer overflow: limit {}, attempted {}", + limit, attempted + ) } Self::InvalidFrame { details } => { write!(f, "invalid frame: {}", details) @@ -90,9 +85,7 @@ impl From for std::io::Error { StreamError::UnexpectedEof => { std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err) } - StreamError::Poisoned { .. } => { - std::io::Error::other(err) - } + StreamError::Poisoned { .. } => std::io::Error::other(err), StreamError::BufferOverflow { .. } => { std::io::Error::new(std::io::ErrorKind::OutOfMemory, err) } @@ -112,7 +105,7 @@ impl From for std::io::Error { pub trait Recoverable { /// Check if error is recoverable (can retry operation) fn is_recoverable(&self) -> bool; - + /// Check if connection can continue after this error fn can_continue(&self) -> bool; } @@ -123,19 +116,22 @@ impl Recoverable for StreamError { Self::PartialRead { .. } | Self::PartialWrite { .. } => true, Self::Io(e) => matches!( e.kind(), - std::io::ErrorKind::WouldBlock - | std::io::ErrorKind::Interrupted - | std::io::ErrorKind::TimedOut + std::io::ErrorKind::WouldBlock + | std::io::ErrorKind::Interrupted + | std::io::ErrorKind::TimedOut ), - Self::Poisoned { .. } + Self::Poisoned { .. } | Self::BufferOverflow { .. } | Self::InvalidFrame { .. } | Self::UnexpectedEof => false, } } - + fn can_continue(&self) -> bool { - !matches!(self, Self::Poisoned { .. } | Self::UnexpectedEof | Self::BufferOverflow { .. }) + !matches!( + self, + Self::Poisoned { .. } | Self::UnexpectedEof | Self::BufferOverflow { .. } + ) } } @@ -143,19 +139,19 @@ impl Recoverable for std::io::Error { fn is_recoverable(&self) -> bool { matches!( self.kind(), - std::io::ErrorKind::WouldBlock - | std::io::ErrorKind::Interrupted - | std::io::ErrorKind::TimedOut + std::io::ErrorKind::WouldBlock + | std::io::ErrorKind::Interrupted + | std::io::ErrorKind::TimedOut ) } - + fn can_continue(&self) -> bool { !matches!( self.kind(), std::io::ErrorKind::BrokenPipe - | std::io::ErrorKind::ConnectionReset - | std::io::ErrorKind::ConnectionAborted - | std::io::ErrorKind::NotConnected + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::NotConnected ) } } @@ -165,96 +161,88 @@ impl Recoverable for std::io::Error { #[derive(Error, Debug)] pub enum ProxyError { // ============= Crypto Errors ============= - #[error("Crypto error: {0}")] Crypto(String), - + #[error("Invalid key length: expected {expected}, got {got}")] InvalidKeyLength { expected: usize, got: usize }, - + // ============= Stream Errors ============= - #[error("Stream error: {0}")] Stream(#[from] StreamError), - + // ============= Protocol Errors ============= - #[error("Invalid handshake: {0}")] InvalidHandshake(String), - + #[error("Invalid protocol tag: {0:02x?}")] InvalidProtoTag([u8; 4]), - + #[error("Invalid TLS record: type={record_type}, version={version:02x?}")] InvalidTlsRecord { record_type: u8, version: [u8; 2] }, - + #[error("Replay attack detected from {addr}")] ReplayAttack { addr: SocketAddr }, - + #[error("Time skew detected: client={client_time}, server={server_time}")] TimeSkew { client_time: u32, server_time: u32 }, - + #[error("Invalid message length: {len} (min={min}, max={max})")] InvalidMessageLength { len: usize, min: usize, max: usize }, - + #[error("Checksum mismatch: expected={expected:08x}, got={got:08x}")] ChecksumMismatch { expected: u32, got: u32 }, - + #[error("Sequence number mismatch: expected={expected}, got={got}")] SeqNoMismatch { expected: i32, got: i32 }, - + #[error("TLS handshake failed: {reason}")] TlsHandshakeFailed { reason: String }, - + #[error("Telegram handshake timeout")] TgHandshakeTimeout, - + // ============= Network Errors ============= - #[error("Connection timeout to {addr}")] ConnectionTimeout { addr: String }, - + #[error("Connection refused by {addr}")] ConnectionRefused { addr: String }, - + #[error("IO error: {0}")] Io(#[from] std::io::Error), - + // ============= Proxy Protocol Errors ============= - #[error("Invalid proxy protocol header")] InvalidProxyProtocol, - + #[error("Proxy error: {0}")] Proxy(String), - + // ============= Config Errors ============= - #[error("Config error: {0}")] Config(String), - + #[error("Invalid secret for user {user}: {reason}")] InvalidSecret { user: String, reason: String }, - + // ============= User Errors ============= - #[error("User {user} expired")] UserExpired { user: String }, - + #[error("User {user} exceeded connection limit")] ConnectionLimitExceeded { user: String }, - + #[error("User {user} exceeded data quota")] DataQuotaExceeded { user: String }, - + #[error("Unknown user")] UnknownUser, - + #[error("Rate limited")] RateLimited, - + // ============= General Errors ============= - #[error("Internal error: {0}")] Internal(String), } @@ -269,7 +257,7 @@ impl Recoverable for ProxyError { _ => false, } } - + fn can_continue(&self) -> bool { match self { Self::Stream(e) => e.can_continue(), @@ -301,17 +289,19 @@ impl HandshakeResult { pub fn is_success(&self) -> bool { matches!(self, HandshakeResult::Success(_)) } - + /// Check if bad client pub fn is_bad_client(&self) -> bool { matches!(self, HandshakeResult::BadClient { .. }) } - + /// Map the success value pub fn map U>(self, f: F) -> HandshakeResult { match self { HandshakeResult::Success(v) => HandshakeResult::Success(f(v)), - HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer }, + HandshakeResult::BadClient { reader, writer } => { + HandshakeResult::BadClient { reader, writer } + } HandshakeResult::Error(e) => HandshakeResult::Error(e), } } @@ -338,76 +328,104 @@ impl From for HandshakeResult { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_stream_error_display() { - let err = StreamError::PartialRead { expected: 100, got: 50 }; + let err = StreamError::PartialRead { + expected: 100, + got: 50, + }; assert!(err.to_string().contains("100")); assert!(err.to_string().contains("50")); - - let err = StreamError::Poisoned { reason: "test".into() }; + + let err = StreamError::Poisoned { + reason: "test".into(), + }; assert!(err.to_string().contains("test")); } - + #[test] fn test_stream_error_recoverable() { - assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable()); - assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable()); + assert!( + StreamError::PartialRead { + expected: 10, + got: 5 + } + .is_recoverable() + ); + assert!( + StreamError::PartialWrite { + expected: 10, + written: 5 + } + .is_recoverable() + ); assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable()); assert!(!StreamError::UnexpectedEof.is_recoverable()); } - + #[test] fn test_stream_error_can_continue() { assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue()); assert!(!StreamError::UnexpectedEof.can_continue()); - assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue()); + assert!( + StreamError::PartialRead { + expected: 10, + got: 5 + } + .can_continue() + ); } - + #[test] fn test_stream_error_to_io_error() { let stream_err = StreamError::UnexpectedEof; let io_err: std::io::Error = stream_err.into(); assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof); } - + #[test] fn test_handshake_result() { let success: HandshakeResult = HandshakeResult::Success(42); assert!(success.is_success()); assert!(!success.is_bad_client()); - - let bad: HandshakeResult = HandshakeResult::BadClient { reader: (), writer: () }; + + let bad: HandshakeResult = HandshakeResult::BadClient { + reader: (), + writer: (), + }; assert!(!bad.is_success()); assert!(bad.is_bad_client()); } - + #[test] fn test_handshake_result_map() { let success: HandshakeResult = HandshakeResult::Success(42); let mapped = success.map(|x| x * 2); - + match mapped { HandshakeResult::Success(v) => assert_eq!(v, 84), _ => panic!("Expected success"), } } - + #[test] fn test_proxy_error_recoverable() { let err = ProxyError::RateLimited; assert!(err.is_recoverable()); - + let err = ProxyError::InvalidHandshake("bad".into()); assert!(!err.is_recoverable()); } - + #[test] fn test_error_display() { - let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() }; + let err = ProxyError::ConnectionTimeout { + addr: "1.2.3.4:443".into(), + }; assert!(err.to_string().contains("1.2.3.4:443")); - + let err = ProxyError::InvalidProxyProtocol; assert!(err.to_string().contains("proxy protocol")); } -} \ No newline at end of file +} diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs index c35c587..c9a0681 100644 --- a/src/ip_tracker.rs +++ b/src/ip_tracker.rs @@ -5,9 +5,9 @@ use std::collections::HashMap; use std::net::IpAddr; use std::sync::Arc; +use std::sync::Mutex; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, Instant}; -use std::sync::Mutex; use tokio::sync::{Mutex as AsyncMutex, RwLock}; @@ -41,7 +41,6 @@ impl UserIpTracker { } } - pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) { match self.cleanup_queue.lock() { Ok(mut queue) => queue.push((user, ip)), @@ -129,7 +128,8 @@ impl UserIpTracker { let mut active_ips = self.active_ips.write().await; let mut recent_ips = self.recent_ips.write().await; - let mut users = Vec::::with_capacity(active_ips.len().saturating_add(recent_ips.len())); + let mut users = + Vec::::with_capacity(active_ips.len().saturating_add(recent_ips.len())); users.extend(active_ips.keys().cloned()); for user in recent_ips.keys() { if !active_ips.contains_key(user) { @@ -138,8 +138,14 @@ impl UserIpTracker { } for user in users { - let active_empty = active_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true); - let recent_empty = recent_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true); + let active_empty = active_ips + .get(&user) + .map(|ips| ips.is_empty()) + .unwrap_or(true); + let recent_empty = recent_ips + .get(&user) + .map(|ips| ips.is_empty()) + .unwrap_or(true); if active_empty && recent_empty { active_ips.remove(&user); recent_ips.remove(&user); diff --git a/src/maestro/connectivity.rs b/src/maestro/connectivity.rs index c843223..ee5fdb9 100644 --- a/src/maestro/connectivity.rs +++ b/src/maestro/connectivity.rs @@ -11,10 +11,10 @@ use crate::startup::{ COMPONENT_DC_CONNECTIVITY_PING, COMPONENT_ME_CONNECTIVITY_PING, COMPONENT_RUNTIME_READY, StartupTracker, }; +use crate::transport::UpstreamManager; use crate::transport::middle_proxy::{ MePingFamily, MePingSample, MePool, format_me_route, format_sample_line, run_me_ping, }; -use crate::transport::UpstreamManager; pub(crate) async fn run_startup_connectivity( config: &Arc, @@ -47,11 +47,15 @@ pub(crate) async fn run_startup_connectivity( let v4_ok = me_results.iter().any(|r| { matches!(r.family, MePingFamily::V4) - && r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some()) + && r.samples + .iter() + .any(|s| s.error.is_none() && s.handshake_ms.is_some()) }); let v6_ok = me_results.iter().any(|r| { matches!(r.family, MePingFamily::V6) - && r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some()) + && r.samples + .iter() + .any(|s| s.error.is_none() && s.handshake_ms.is_some()) }); info!("================= Telegram ME Connectivity ================="); @@ -131,8 +135,14 @@ pub(crate) async fn run_startup_connectivity( .await; for upstream_result in &ping_results { - let v6_works = upstream_result.v6_results.iter().any(|r| r.rtt_ms.is_some()); - let v4_works = upstream_result.v4_results.iter().any(|r| r.rtt_ms.is_some()); + let v6_works = upstream_result + .v6_results + .iter() + .any(|r| r.rtt_ms.is_some()); + let v4_works = upstream_result + .v4_results + .iter() + .any(|r| r.rtt_ms.is_some()); if upstream_result.both_available { if prefer_ipv6 { diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index f916633..ffa4d1b 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -1,5 +1,5 @@ -use std::time::Duration; use std::path::PathBuf; +use std::time::Duration; use tokio::sync::watch; use tracing::{debug, error, info, warn}; @@ -10,7 +10,10 @@ use crate::transport::middle_proxy::{ ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, }; -pub(crate) fn resolve_runtime_config_path(config_path_cli: &str, startup_cwd: &std::path::Path) -> PathBuf { +pub(crate) fn resolve_runtime_config_path( + config_path_cli: &str, + startup_cwd: &std::path::Path, +) -> PathBuf { let raw = PathBuf::from(config_path_cli); let absolute = if raw.is_absolute() { raw @@ -50,7 +53,9 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { } } s if s.starts_with("--data-path=") => { - data_path = Some(PathBuf::from(s.trim_start_matches("--data-path=").to_string())); + data_path = Some(PathBuf::from( + s.trim_start_matches("--data-path=").to_string(), + )); } "--silent" | "-s" => { silent = true; @@ -68,7 +73,9 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { eprintln!("Usage: telemt [config.toml] [OPTIONS]"); eprintln!(); eprintln!("Options:"); - eprintln!(" --data-path Set data directory (absolute path; overrides config value)"); + eprintln!( + " --data-path Set data directory (absolute path; overrides config value)" + ); eprintln!(" --silent, -s Suppress info logs"); eprintln!(" --log-level debug|verbose|normal|silent"); eprintln!(" --help, -h Show this help"); @@ -146,7 +153,12 @@ mod tests { pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { info!(target: "telemt::links", "--- Proxy Links ({}) ---", host); - for user_name in config.general.links.show.resolve_users(&config.access.users) { + for user_name in config + .general + .links + .show + .resolve_users(&config.access.users) + { if let Some(secret) = config.access.users.get(user_name) { info!(target: "telemt::links", "User: {}", user_name); if config.general.modes.classic { @@ -287,7 +299,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot( return Some(cfg); } - warn!(snapshot = label, url, "Startup proxy-config is empty; trying disk cache"); + warn!( + snapshot = label, + url, "Startup proxy-config is empty; trying disk cache" + ); if let Some(path) = cache_path { match load_proxy_config_cache(path).await { Ok(cached) if !cached.map.is_empty() => { @@ -302,8 +317,7 @@ pub(crate) async fn load_startup_proxy_config_snapshot( Ok(_) => { warn!( snapshot = label, - path, - "Startup proxy-config cache is empty; ignoring cache file" + path, "Startup proxy-config cache is empty; ignoring cache file" ); } Err(cache_err) => { @@ -347,8 +361,7 @@ pub(crate) async fn load_startup_proxy_config_snapshot( Ok(_) => { warn!( snapshot = label, - path, - "Startup proxy-config cache is empty; ignoring cache file" + path, "Startup proxy-config cache is empty; ignoring cache file" ); } Err(cache_err) => { diff --git a/src/maestro/listeners.rs b/src/maestro/listeners.rs index fe041d9..effaff8 100644 --- a/src/maestro/listeners.rs +++ b/src/maestro/listeners.rs @@ -12,17 +12,15 @@ use tracing::{debug, error, info, warn}; use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; -use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController}; use crate::proxy::ClientHandler; +use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController}; use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker}; use crate::stats::beobachten::BeobachtenStore; use crate::stats::{ReplayChecker, Stats}; use crate::stream::BufferPool; use crate::tls_front::TlsFrontCache; use crate::transport::middle_proxy::MePool; -use crate::transport::{ - ListenOptions, UpstreamManager, create_listener, find_listener_processes, -}; +use crate::transport::{ListenOptions, UpstreamManager, create_listener, find_listener_processes}; use super::helpers::{is_expected_handshake_eof, print_proxy_links}; @@ -81,8 +79,9 @@ pub(crate) async fn bind_listeners( Ok(socket) => { let listener = TcpListener::from_std(socket.into())?; info!("Listening on {}", addr); - let listener_proxy_protocol = - listener_conf.proxy_protocol.unwrap_or(config.server.proxy_protocol); + let listener_proxy_protocol = listener_conf + .proxy_protocol + .unwrap_or(config.server.proxy_protocol); let public_host = if let Some(ref announce) = listener_conf.announce { announce.clone() @@ -100,8 +99,14 @@ pub(crate) async fn bind_listeners( listener_conf.ip.to_string() }; - if config.general.links.public_host.is_none() && !config.general.links.show.is_empty() { - let link_port = config.general.links.public_port.unwrap_or(config.server.port); + if config.general.links.public_host.is_none() + && !config.general.links.show.is_empty() + { + let link_port = config + .general + .links + .public_port + .unwrap_or(config.server.port); print_proxy_links(&public_host, link_port, config); } @@ -145,12 +150,14 @@ pub(crate) async fn bind_listeners( let (host, port) = if let Some(ref h) = config.general.links.public_host { ( h.clone(), - config.general.links.public_port.unwrap_or(config.server.port), + config + .general + .links + .public_port + .unwrap_or(config.server.port), ) } else { - let ip = detected_ip_v4 - .or(detected_ip_v6) - .map(|ip| ip.to_string()); + let ip = detected_ip_v4.or(detected_ip_v6).map(|ip| ip.to_string()); if ip.is_none() { warn!( "show_link is configured but public IP could not be detected. Set public_host in config." @@ -158,7 +165,11 @@ pub(crate) async fn bind_listeners( } ( ip.unwrap_or_else(|| "UNKNOWN".to_string()), - config.general.links.public_port.unwrap_or(config.server.port), + config + .general + .links + .public_port + .unwrap_or(config.server.port), ) }; @@ -178,13 +189,19 @@ pub(crate) async fn bind_listeners( use std::os::unix::fs::PermissionsExt; let perms = std::fs::Permissions::from_mode(mode); if let Err(e) = std::fs::set_permissions(unix_path, perms) { - error!("Failed to set unix socket permissions to {}: {}", perm_str, e); + error!( + "Failed to set unix socket permissions to {}: {}", + perm_str, e + ); } else { info!("Listening on unix:{} (mode {})", unix_path, perm_str); } } Err(e) => { - warn!("Invalid listen_unix_sock_perm '{}': {}. Ignoring.", perm_str, e); + warn!( + "Invalid listen_unix_sock_perm '{}': {}. Ignoring.", + perm_str, e + ); info!("Listening on unix:{}", unix_path); } } @@ -218,10 +235,8 @@ pub(crate) async fn bind_listeners( drop(stream); continue; } - let accept_permit_timeout_ms = config_rx_unix - .borrow() - .server - .accept_permit_timeout_ms; + let accept_permit_timeout_ms = + config_rx_unix.borrow().server.accept_permit_timeout_ms; let permit = if accept_permit_timeout_ms == 0 { match max_connections_unix.clone().acquire_owned().await { Ok(permit) => permit, @@ -361,10 +376,8 @@ pub(crate) fn spawn_tcp_accept_loops( drop(stream); continue; } - let accept_permit_timeout_ms = config_rx - .borrow() - .server - .accept_permit_timeout_ms; + let accept_permit_timeout_ms = + config_rx.borrow().server.accept_permit_timeout_ms; let permit = if accept_permit_timeout_ms == 0 { match max_connections_tcp.clone().acquire_owned().await { Ok(permit) => permit, diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index bbe46a8..c668734 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -12,8 +12,8 @@ use crate::startup::{ COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH, StartupMeStatus, StartupTracker, }; use crate::stats::Stats; -use crate::transport::middle_proxy::MePool; use crate::transport::UpstreamManager; +use crate::transport::middle_proxy::MePool; use super::helpers::load_startup_proxy_config_snapshot; @@ -229,8 +229,12 @@ pub(crate) async fn initialize_me_pool( config.general.me_adaptive_floor_recover_grace_secs, config.general.me_adaptive_floor_writers_per_core_total, config.general.me_adaptive_floor_cpu_cores_override, - config.general.me_adaptive_floor_max_extra_writers_single_per_core, - config.general.me_adaptive_floor_max_extra_writers_multi_per_core, + config + .general + .me_adaptive_floor_max_extra_writers_single_per_core, + config + .general + .me_adaptive_floor_max_extra_writers_multi_per_core, config.general.me_adaptive_floor_max_active_writers_per_core, config.general.me_adaptive_floor_max_warm_writers_per_core, config.general.me_adaptive_floor_max_active_writers_global, @@ -457,64 +461,70 @@ pub(crate) async fn initialize_me_pool( "Middle-End pool initialized successfully" ); - // ── Supervised background tasks ────────────────── - let pool_clone = pool.clone(); - let rng_clone = rng.clone(); - let min_conns = pool_size; - tokio::spawn(async move { - loop { - let p = pool_clone.clone(); - let r = rng_clone.clone(); - let res = tokio::spawn(async move { - crate::transport::middle_proxy::me_health_monitor( - p, r, min_conns, - ) - .await; - }) + // ── Supervised background tasks ────────────────── + let pool_clone = pool.clone(); + let rng_clone = rng.clone(); + let min_conns = pool_size; + tokio::spawn(async move { + loop { + let p = pool_clone.clone(); + let r = rng_clone.clone(); + let res = tokio::spawn(async move { + crate::transport::middle_proxy::me_health_monitor( + p, r, min_conns, + ) .await; - match res { - Ok(()) => warn!("me_health_monitor exited unexpectedly, restarting"), - Err(e) => { - error!(error = %e, "me_health_monitor panicked, restarting in 1s"); - tokio::time::sleep(Duration::from_secs(1)).await; - } + }) + .await; + match res { + Ok(()) => warn!( + "me_health_monitor exited unexpectedly, restarting" + ), + Err(e) => { + error!(error = %e, "me_health_monitor panicked, restarting in 1s"); + tokio::time::sleep(Duration::from_secs(1)).await; } } - }); - let pool_drain_enforcer = pool.clone(); - tokio::spawn(async move { - loop { - let p = pool_drain_enforcer.clone(); - let res = tokio::spawn(async move { + } + }); + let pool_drain_enforcer = pool.clone(); + tokio::spawn(async move { + loop { + let p = pool_drain_enforcer.clone(); + let res = tokio::spawn(async move { crate::transport::middle_proxy::me_drain_timeout_enforcer(p).await; }) .await; - match res { - Ok(()) => warn!("me_drain_timeout_enforcer exited unexpectedly, restarting"), - Err(e) => { - error!(error = %e, "me_drain_timeout_enforcer panicked, restarting in 1s"); - tokio::time::sleep(Duration::from_secs(1)).await; - } + match res { + Ok(()) => warn!( + "me_drain_timeout_enforcer exited unexpectedly, restarting" + ), + Err(e) => { + error!(error = %e, "me_drain_timeout_enforcer panicked, restarting in 1s"); + tokio::time::sleep(Duration::from_secs(1)).await; } } - }); - let pool_watchdog = pool.clone(); - tokio::spawn(async move { - loop { - let p = pool_watchdog.clone(); - let res = tokio::spawn(async move { + } + }); + let pool_watchdog = pool.clone(); + tokio::spawn(async move { + loop { + let p = pool_watchdog.clone(); + let res = tokio::spawn(async move { crate::transport::middle_proxy::me_zombie_writer_watchdog(p).await; }) .await; - match res { - Ok(()) => warn!("me_zombie_writer_watchdog exited unexpectedly, restarting"), - Err(e) => { - error!(error = %e, "me_zombie_writer_watchdog panicked, restarting in 1s"); - tokio::time::sleep(Duration::from_secs(1)).await; - } + match res { + Ok(()) => warn!( + "me_zombie_writer_watchdog exited unexpectedly, restarting" + ), + Err(e) => { + error!(error = %e, "me_zombie_writer_watchdog panicked, restarting in 1s"); + tokio::time::sleep(Duration::from_secs(1)).await; } } - }); + } + }); break Some(pool); } diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index 7ba7b39..7d3b168 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -11,9 +11,9 @@ // - admission: conditional-cast gate and route mode switching. // - listeners: TCP/Unix listener bind and accept-loop orchestration. // - shutdown: graceful shutdown sequence and uptime logging. -mod helpers; mod admission; mod connectivity; +mod helpers; mod listeners; mod me_startup; mod runtime_tasks; @@ -33,18 +33,18 @@ use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe}; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; +use crate::startup::{ + COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD, COMPONENT_ME_POOL_CONSTRUCT, + COMPONENT_ME_POOL_INIT_STAGE1, COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6, + COMPONENT_ME_SECRET_FETCH, COMPONENT_NETWORK_PROBE, COMPONENT_TRACING_INIT, StartupMeStatus, + StartupTracker, +}; use crate::stats::beobachten::BeobachtenStore; use crate::stats::telemetry::TelemetryPolicy; use crate::stats::{ReplayChecker, Stats}; -use crate::startup::{ - COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD, - COMPONENT_ME_POOL_CONSTRUCT, COMPONENT_ME_POOL_INIT_STAGE1, - COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH, - COMPONENT_NETWORK_PROBE, COMPONENT_TRACING_INIT, StartupMeStatus, StartupTracker, -}; use crate::stream::BufferPool; -use crate::transport::middle_proxy::MePool; use crate::transport::UpstreamManager; +use crate::transport::middle_proxy::MePool; use helpers::{parse_cli, resolve_runtime_config_path}; /// Runs the full telemt runtime startup pipeline and blocks until shutdown. @@ -56,7 +56,10 @@ pub async fn run() -> std::result::Result<(), Box> { .as_secs(); let startup_tracker = Arc::new(StartupTracker::new(process_started_at_epoch_secs)); startup_tracker - .start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string())) + .start_component( + COMPONENT_CONFIG_LOAD, + Some("load and validate config".to_string()), + ) .await; let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli(); let startup_cwd = match std::env::current_dir() { @@ -77,7 +80,10 @@ pub async fn run() -> std::result::Result<(), Box> { } else { let default = ProxyConfig::default(); std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap(); - eprintln!("[telemt] Created default config at {}", config_path.display()); + eprintln!( + "[telemt] Created default config at {}", + config_path.display() + ); default } } @@ -94,24 +100,38 @@ pub async fn run() -> std::result::Result<(), Box> { if let Some(ref data_path) = config.general.data_path { if !data_path.is_absolute() { - eprintln!("[telemt] data_path must be absolute: {}", data_path.display()); + eprintln!( + "[telemt] data_path must be absolute: {}", + data_path.display() + ); std::process::exit(1); } if data_path.exists() { if !data_path.is_dir() { - eprintln!("[telemt] data_path exists but is not a directory: {}", data_path.display()); + eprintln!( + "[telemt] data_path exists but is not a directory: {}", + data_path.display() + ); std::process::exit(1); } } else { if let Err(e) = std::fs::create_dir_all(data_path) { - eprintln!("[telemt] Can't create data_path {}: {}", data_path.display(), e); + eprintln!( + "[telemt] Can't create data_path {}: {}", + data_path.display(), + e + ); std::process::exit(1); } } if let Err(e) = std::env::set_current_dir(data_path) { - eprintln!("[telemt] Can't use data_path {}: {}", data_path.display(), e); + eprintln!( + "[telemt] Can't use data_path {}: {}", + data_path.display(), + e + ); std::process::exit(1); } } @@ -135,7 +155,10 @@ pub async fn run() -> std::result::Result<(), Box> { let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info")); startup_tracker - .start_component(COMPONENT_TRACING_INIT, Some("initialize tracing subscriber".to_string())) + .start_component( + COMPONENT_TRACING_INIT, + Some("initialize tracing subscriber".to_string()), + ) .await; // Configure color output based on config @@ -150,7 +173,10 @@ pub async fn run() -> std::result::Result<(), Box> { .with(fmt_layer) .init(); startup_tracker - .complete_component(COMPONENT_TRACING_INIT, Some("tracing initialized".to_string())) + .complete_component( + COMPONENT_TRACING_INIT, + Some("tracing initialized".to_string()), + ) .await; info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION")); @@ -216,7 +242,8 @@ pub async fn run() -> std::result::Result<(), Box> { config.access.user_max_unique_ips_window_secs, ) .await; - if config.access.user_max_unique_ips_global_each > 0 || !config.access.user_max_unique_ips.is_empty() + if config.access.user_max_unique_ips_global_each > 0 + || !config.access.user_max_unique_ips.is_empty() { info!( global_each_limit = config.access.user_max_unique_ips_global_each, @@ -243,7 +270,10 @@ pub async fn run() -> std::result::Result<(), Box> { let route_runtime = Arc::new(RouteRuntimeController::new(initial_route_mode)); let api_me_pool = Arc::new(RwLock::new(None::>)); startup_tracker - .start_component(COMPONENT_API_BOOTSTRAP, Some("spawn API listener task".to_string())) + .start_component( + COMPONENT_API_BOOTSTRAP, + Some("spawn API listener task".to_string()), + ) .await; if config.server.api.enabled { @@ -326,7 +356,10 @@ pub async fn run() -> std::result::Result<(), Box> { .await; startup_tracker - .start_component(COMPONENT_NETWORK_PROBE, Some("probe network capabilities".to_string())) + .start_component( + COMPONENT_NETWORK_PROBE, + Some("probe network capabilities".to_string()), + ) .await; let probe = run_probe( &config.network, @@ -339,11 +372,8 @@ pub async fn run() -> std::result::Result<(), Box> { probe.detected_ipv4.map(IpAddr::V4), probe.detected_ipv6.map(IpAddr::V6), )); - let decision = decide_network_capabilities( - &config.network, - &probe, - config.general.middle_proxy_nat_ip, - ); + let decision = + decide_network_capabilities(&config.network, &probe, config.general.middle_proxy_nat_ip); log_probe_result(&probe, &decision); startup_tracker .complete_component( @@ -446,24 +476,16 @@ pub async fn run() -> std::result::Result<(), Box> { // If ME failed to initialize, force direct-only mode. if me_pool.is_some() { - startup_tracker - .set_transport_mode("middle_proxy") - .await; - startup_tracker - .set_degraded(false) - .await; + startup_tracker.set_transport_mode("middle_proxy").await; + startup_tracker.set_degraded(false).await; info!("Transport: Middle-End Proxy - all DC-over-RPC"); } else { let _ = use_middle_proxy; use_middle_proxy = false; // Make runtime config reflect direct-only mode for handlers. config.general.use_middle_proxy = false; - startup_tracker - .set_transport_mode("direct") - .await; - startup_tracker - .set_degraded(true) - .await; + startup_tracker.set_transport_mode("direct").await; + startup_tracker.set_degraded(true).await; if me2dc_fallback { startup_tracker .set_me_status(StartupMeStatus::Failed, "fallback_to_direct") diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index c2233c7..d553eb9 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -4,21 +4,24 @@ use std::sync::Arc; use tokio::sync::{mpsc, watch}; use tracing::{debug, warn}; -use tracing_subscriber::reload; use tracing_subscriber::EnvFilter; +use tracing_subscriber::reload; -use crate::config::{LogLevel, ProxyConfig}; use crate::config::hot_reload::spawn_config_watcher; +use crate::config::{LogLevel, ProxyConfig}; use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; use crate::metrics; use crate::network::probe::NetworkProbe; -use crate::startup::{COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY, StartupTracker}; +use crate::startup::{ + COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY, + StartupTracker, +}; use crate::stats::beobachten::BeobachtenStore; use crate::stats::telemetry::TelemetryPolicy; use crate::stats::{ReplayChecker, Stats}; -use crate::transport::middle_proxy::{MePool, MeReinitTrigger}; use crate::transport::UpstreamManager; +use crate::transport::middle_proxy::{MePool, MeReinitTrigger}; use super::helpers::write_beobachten_snapshot; @@ -79,15 +82,13 @@ pub(crate) async fn spawn_runtime_tasks( Some("spawn config hot-reload watcher".to_string()), ) .await; - let (config_rx, log_level_rx): ( - watch::Receiver>, - watch::Receiver, - ) = spawn_config_watcher( - config_path.to_path_buf(), - config.clone(), - detected_ip_v4, - detected_ip_v6, - ); + let (config_rx, log_level_rx): (watch::Receiver>, watch::Receiver) = + spawn_config_watcher( + config_path.to_path_buf(), + config.clone(), + detected_ip_v4, + detected_ip_v6, + ); startup_tracker .complete_component( COMPONENT_CONFIG_WATCHER_START, @@ -114,7 +115,8 @@ pub(crate) async fn spawn_runtime_tasks( break; } let cfg = config_rx_policy.borrow_and_update().clone(); - stats_policy.apply_telemetry_policy(TelemetryPolicy::from_config(&cfg.general.telemetry)); + stats_policy + .apply_telemetry_policy(TelemetryPolicy::from_config(&cfg.general.telemetry)); if let Some(pool) = &me_pool_for_policy { pool.update_runtime_transport_policy( cfg.general.me_socks_kdf_policy, @@ -130,7 +132,11 @@ pub(crate) async fn spawn_runtime_tasks( let ip_tracker_policy = ip_tracker.clone(); let mut config_rx_ip_limits = config_rx.clone(); tokio::spawn(async move { - let mut prev_limits = config_rx_ip_limits.borrow().access.user_max_unique_ips.clone(); + let mut prev_limits = config_rx_ip_limits + .borrow() + .access + .user_max_unique_ips + .clone(); let mut prev_global_each = config_rx_ip_limits .borrow() .access @@ -183,7 +189,9 @@ pub(crate) async fn spawn_runtime_tasks( let sleep_secs = cfg.general.beobachten_flush_secs.max(1); if cfg.general.beobachten { - let ttl = std::time::Duration::from_secs(cfg.general.beobachten_minutes.saturating_mul(60)); + let ttl = std::time::Duration::from_secs( + cfg.general.beobachten_minutes.saturating_mul(60), + ); let path = cfg.general.beobachten_file.clone(); let snapshot = beobachten_writer.snapshot_text(ttl); if let Err(e) = write_beobachten_snapshot(&path, &snapshot).await { @@ -227,8 +235,11 @@ pub(crate) async fn spawn_runtime_tasks( let config_rx_clone_rot = config_rx.clone(); let reinit_tx_rotation = reinit_tx.clone(); tokio::spawn(async move { - crate::transport::middle_proxy::me_rotation_task(config_rx_clone_rot, reinit_tx_rotation) - .await; + crate::transport::middle_proxy::me_rotation_task( + config_rx_clone_rot, + reinit_tx_rotation, + ) + .await; }); } diff --git a/src/maestro/shutdown.rs b/src/maestro/shutdown.rs index b73df30..243c772 100644 --- a/src/maestro/shutdown.rs +++ b/src/maestro/shutdown.rs @@ -16,8 +16,11 @@ pub(crate) async fn wait_for_shutdown(process_started_at: Instant, me_pool: Opti let uptime_secs = process_started_at.elapsed().as_secs(); info!("Uptime: {}", format_uptime(uptime_secs)); if let Some(pool) = &me_pool { - match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all()) - .await + match tokio::time::timeout( + Duration::from_secs(2), + pool.shutdown_send_close_conn_all(), + ) + .await { Ok(total) => { info!( diff --git a/src/main.rs b/src/main.rs index dff8c8a..e8b91a0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,11 +7,11 @@ mod crypto; mod error; mod ip_tracker; #[cfg(test)] -#[path = "tests/ip_tracker_regression_tests.rs"] -mod ip_tracker_regression_tests; -#[cfg(test)] #[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] mod ip_tracker_hotpath_adversarial_tests; +#[cfg(test)] +#[path = "tests/ip_tracker_regression_tests.rs"] +mod ip_tracker_regression_tests; mod maestro; mod metrics; mod network; diff --git a/src/metrics.rs b/src/metrics.rs index b7a16f0..2560294 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1,5 +1,5 @@ -use std::convert::Infallible; use std::collections::{BTreeSet, HashMap}; +use std::convert::Infallible; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -11,12 +11,12 @@ use hyper::service::service_fn; use hyper::{Request, Response, StatusCode}; use ipnetwork::IpNetwork; use tokio::net::TcpListener; -use tracing::{info, warn, debug}; +use tracing::{debug, info, warn}; use crate::config::ProxyConfig; use crate::ip_tracker::UserIpTracker; -use crate::stats::beobachten::BeobachtenStore; use crate::stats::Stats; +use crate::stats::beobachten::BeobachtenStore; use crate::transport::{ListenOptions, create_listener}; pub async fn serve( @@ -62,7 +62,10 @@ pub async fn serve( let addr_v4 = SocketAddr::from(([0, 0, 0, 0], port)); match bind_metrics_listener(addr_v4, false) { Ok(listener) => { - info!("Metrics endpoint: http://{}/metrics and /beobachten", addr_v4); + info!( + "Metrics endpoint: http://{}/metrics and /beobachten", + addr_v4 + ); listener_v4 = Some(listener); } Err(e) => { @@ -73,7 +76,10 @@ pub async fn serve( let addr_v6 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], port)); match bind_metrics_listener(addr_v6, true) { Ok(listener) => { - info!("Metrics endpoint: http://[::]:{}/metrics and /beobachten", port); + info!( + "Metrics endpoint: http://[::]:{}/metrics and /beobachten", + port + ); listener_v6 = Some(listener); } Err(e) => { @@ -109,12 +115,7 @@ pub async fn serve( .await; }); serve_listener( - listener4, - stats, - beobachten, - ip_tracker, - config_rx, - whitelist, + listener4, stats, beobachten, ip_tracker, config_rx, whitelist, ) .await; } @@ -231,7 +232,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp let _ = writeln!(out, "# TYPE telemt_uptime_seconds gauge"); let _ = writeln!(out, "telemt_uptime_seconds {:.1}", stats.uptime_secs()); - let _ = writeln!(out, "# HELP telemt_telemetry_core_enabled Runtime core telemetry switch"); + let _ = writeln!( + out, + "# HELP telemt_telemetry_core_enabled Runtime core telemetry switch" + ); let _ = writeln!(out, "# TYPE telemt_telemetry_core_enabled gauge"); let _ = writeln!( out, @@ -239,7 +243,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp if core_enabled { 1 } else { 0 } ); - let _ = writeln!(out, "# HELP telemt_telemetry_user_enabled Runtime per-user telemetry switch"); + let _ = writeln!( + out, + "# HELP telemt_telemetry_user_enabled Runtime per-user telemetry switch" + ); let _ = writeln!(out, "# TYPE telemt_telemetry_user_enabled gauge"); let _ = writeln!( out, @@ -247,7 +254,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp if user_enabled { 1 } else { 0 } ); - let _ = writeln!(out, "# HELP telemt_telemetry_me_level Runtime ME telemetry level flag"); + let _ = writeln!( + out, + "# HELP telemt_telemetry_me_level Runtime ME telemetry level flag" + ); let _ = writeln!(out, "# TYPE telemt_telemetry_me_level gauge"); let _ = writeln!( out, @@ -277,23 +287,40 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_connections_total Total accepted connections"); + let _ = writeln!( + out, + "# HELP telemt_connections_total Total accepted connections" + ); let _ = writeln!(out, "# TYPE telemt_connections_total counter"); let _ = writeln!( out, "telemt_connections_total {}", - if core_enabled { stats.get_connects_all() } else { 0 } + if core_enabled { + stats.get_connects_all() + } else { + 0 + } ); - let _ = writeln!(out, "# HELP telemt_connections_bad_total Bad/rejected connections"); + let _ = writeln!( + out, + "# HELP telemt_connections_bad_total Bad/rejected connections" + ); let _ = writeln!(out, "# TYPE telemt_connections_bad_total counter"); let _ = writeln!( out, "telemt_connections_bad_total {}", - if core_enabled { stats.get_connects_bad() } else { 0 } + if core_enabled { + stats.get_connects_bad() + } else { + 0 + } ); - let _ = writeln!(out, "# HELP telemt_handshake_timeouts_total Handshake timeouts"); + let _ = writeln!( + out, + "# HELP telemt_handshake_timeouts_total Handshake timeouts" + ); let _ = writeln!(out, "# TYPE telemt_handshake_timeouts_total counter"); let _ = writeln!( out, @@ -372,7 +399,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_upstream_connect_attempts_per_request Histogram-like buckets for attempts per upstream connect request cycle" ); - let _ = writeln!(out, "# TYPE telemt_upstream_connect_attempts_per_request counter"); + let _ = writeln!( + out, + "# TYPE telemt_upstream_connect_attempts_per_request counter" + ); let _ = writeln!( out, "telemt_upstream_connect_attempts_per_request{{bucket=\"1\"}} {}", @@ -414,7 +444,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_upstream_connect_duration_success_total Histogram-like buckets of successful upstream connect cycle duration" ); - let _ = writeln!(out, "# TYPE telemt_upstream_connect_duration_success_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_upstream_connect_duration_success_total counter" + ); let _ = writeln!( out, "telemt_upstream_connect_duration_success_total{{bucket=\"le_100ms\"}} {}", @@ -456,7 +489,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_upstream_connect_duration_fail_total Histogram-like buckets of failed upstream connect cycle duration" ); - let _ = writeln!(out, "# TYPE telemt_upstream_connect_duration_fail_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_upstream_connect_duration_fail_total counter" + ); let _ = writeln!( out, "telemt_upstream_connect_duration_fail_total{{bucket=\"le_100ms\"}} {}", @@ -494,7 +530,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_keepalive_sent_total ME keepalive frames sent"); + let _ = writeln!( + out, + "# HELP telemt_me_keepalive_sent_total ME keepalive frames sent" + ); let _ = writeln!(out, "# TYPE telemt_me_keepalive_sent_total counter"); let _ = writeln!( out, @@ -506,7 +545,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_keepalive_failed_total ME keepalive send failures"); + let _ = writeln!( + out, + "# HELP telemt_me_keepalive_failed_total ME keepalive send failures" + ); let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter"); let _ = writeln!( out, @@ -518,7 +560,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies"); + let _ = writeln!( + out, + "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies" + ); let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter"); let _ = writeln!( out, @@ -530,7 +575,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts"); + let _ = writeln!( + out, + "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts" + ); let _ = writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter"); let _ = writeln!( out, @@ -546,7 +594,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_rpc_proxy_req_signal_sent_total Service RPC_PROXY_REQ activity signals sent" ); - let _ = writeln!(out, "# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter" + ); let _ = writeln!( out, "telemt_me_rpc_proxy_req_signal_sent_total {}", @@ -629,7 +680,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts"); + let _ = writeln!( + out, + "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts" + ); let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter"); let _ = writeln!( out, @@ -641,7 +695,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_reconnect_success_total ME reconnect successes"); + let _ = writeln!( + out, + "# HELP telemt_me_reconnect_success_total ME reconnect successes" + ); let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter"); let _ = writeln!( out, @@ -653,7 +710,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_handshake_reject_total ME handshake rejects from upstream"); + let _ = writeln!( + out, + "# HELP telemt_me_handshake_reject_total ME handshake rejects from upstream" + ); let _ = writeln!(out, "# TYPE telemt_me_handshake_reject_total counter"); let _ = writeln!( out, @@ -665,20 +725,25 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_handshake_error_code_total ME handshake reject errors by code"); + let _ = writeln!( + out, + "# HELP telemt_me_handshake_error_code_total ME handshake reject errors by code" + ); let _ = writeln!(out, "# TYPE telemt_me_handshake_error_code_total counter"); if me_allows_normal { for (error_code, count) in stats.get_me_handshake_error_code_counts() { let _ = writeln!( out, "telemt_me_handshake_error_code_total{{error_code=\"{}\"}} {}", - error_code, - count + error_code, count ); } } - let _ = writeln!(out, "# HELP telemt_me_reader_eof_total ME reader EOF terminations"); + let _ = writeln!( + out, + "# HELP telemt_me_reader_eof_total ME reader EOF terminations" + ); let _ = writeln!(out, "# TYPE telemt_me_reader_eof_total counter"); let _ = writeln!( out, @@ -780,7 +845,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_seq_mismatch_total ME sequence mismatches"); + let _ = writeln!( + out, + "# HELP telemt_me_seq_mismatch_total ME sequence mismatches" + ); let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter"); let _ = writeln!( out, @@ -792,7 +860,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn"); + let _ = writeln!( + out, + "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn" + ); let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter"); let _ = writeln!( out, @@ -804,8 +875,14 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed"); - let _ = writeln!(out, "# TYPE telemt_me_route_drop_channel_closed_total counter"); + let _ = writeln!( + out, + "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_route_drop_channel_closed_total counter" + ); let _ = writeln!( out, "telemt_me_route_drop_channel_closed_total {}", @@ -816,7 +893,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full"); + let _ = writeln!( + out, + "# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full" + ); let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter"); let _ = writeln!( out, @@ -973,7 +1053,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_writer_pick_mode_switch_total Writer-pick mode switches via runtime updates" ); - let _ = writeln!(out, "# TYPE telemt_me_writer_pick_mode_switch_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_writer_pick_mode_switch_total counter" + ); let _ = writeln!( out, "telemt_me_writer_pick_mode_switch_total {}", @@ -1023,7 +1106,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_kdf_drift_total ME KDF input drift detections"); + let _ = writeln!( + out, + "# HELP telemt_me_kdf_drift_total ME KDF input drift detections" + ); let _ = writeln!(out, "# TYPE telemt_me_kdf_drift_total counter"); let _ = writeln!( out, @@ -1069,7 +1155,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_hardswap_pending_ttl_expired_total Pending hardswap generations reset by TTL expiration" ); - let _ = writeln!(out, "# TYPE telemt_me_hardswap_pending_ttl_expired_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_hardswap_pending_ttl_expired_total counter" + ); let _ = writeln!( out, "telemt_me_hardswap_pending_ttl_expired_total {}", @@ -1301,10 +1390,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_adaptive_floor_global_cap_raw Runtime raw global adaptive floor cap" ); - let _ = writeln!( - out, - "# TYPE telemt_me_adaptive_floor_global_cap_raw gauge" - ); + let _ = writeln!(out, "# TYPE telemt_me_adaptive_floor_global_cap_raw gauge"); let _ = writeln!( out, "telemt_me_adaptive_floor_global_cap_raw {}", @@ -1487,7 +1573,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths"); + let _ = writeln!( + out, + "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths" + ); let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter"); let _ = writeln!( out, @@ -1499,7 +1588,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_desync_total Total crypto-desync detections"); + let _ = writeln!( + out, + "# HELP telemt_desync_total Total crypto-desync detections" + ); let _ = writeln!(out, "# TYPE telemt_desync_total counter"); let _ = writeln!( out, @@ -1511,7 +1603,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_desync_full_logged_total Full forensic desync logs emitted"); + let _ = writeln!( + out, + "# HELP telemt_desync_full_logged_total Full forensic desync logs emitted" + ); let _ = writeln!(out, "# TYPE telemt_desync_full_logged_total counter"); let _ = writeln!( out, @@ -1523,7 +1618,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_desync_suppressed_total Suppressed desync forensic events"); + let _ = writeln!( + out, + "# HELP telemt_desync_suppressed_total Suppressed desync forensic events" + ); let _ = writeln!(out, "# TYPE telemt_desync_suppressed_total counter"); let _ = writeln!( out, @@ -1535,7 +1633,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_desync_frames_bucket_total Desync count by frames_ok bucket"); + let _ = writeln!( + out, + "# HELP telemt_desync_frames_bucket_total Desync count by frames_ok bucket" + ); let _ = writeln!(out, "# TYPE telemt_desync_frames_bucket_total counter"); let _ = writeln!( out, @@ -1574,7 +1675,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_pool_swap_total Successful ME pool swaps"); + let _ = writeln!( + out, + "# HELP telemt_pool_swap_total Successful ME pool swaps" + ); let _ = writeln!(out, "# TYPE telemt_pool_swap_total counter"); let _ = writeln!( out, @@ -1586,7 +1690,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_pool_drain_active Active draining ME writers"); + let _ = writeln!( + out, + "# HELP telemt_pool_drain_active Active draining ME writers" + ); let _ = writeln!(out, "# TYPE telemt_pool_drain_active gauge"); let _ = writeln!( out, @@ -1598,7 +1705,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_pool_force_close_total Forced close events for draining writers"); + let _ = writeln!( + out, + "# HELP telemt_pool_force_close_total Forced close events for draining writers" + ); let _ = writeln!(out, "# TYPE telemt_pool_force_close_total counter"); let _ = writeln!( out, @@ -1610,7 +1720,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_pool_stale_pick_total Stale writer fallback picks for new binds"); + let _ = writeln!( + out, + "# HELP telemt_pool_stale_pick_total Stale writer fallback picks for new binds" + ); let _ = writeln!(out, "# TYPE telemt_pool_stale_pick_total counter"); let _ = writeln!( out, @@ -1622,7 +1735,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_writer_removed_total Total ME writer removals"); + let _ = writeln!( + out, + "# HELP telemt_me_writer_removed_total Total ME writer removals" + ); let _ = writeln!(out, "# TYPE telemt_me_writer_removed_total counter"); let _ = writeln!( out, @@ -1638,7 +1754,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_writer_removed_unexpected_total Unexpected ME writer removals that triggered refill" ); - let _ = writeln!(out, "# TYPE telemt_me_writer_removed_unexpected_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_writer_removed_unexpected_total counter" + ); let _ = writeln!( out, "telemt_me_writer_removed_unexpected_total {}", @@ -1649,7 +1768,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_refill_triggered_total Immediate ME refill runs started"); + let _ = writeln!( + out, + "# HELP telemt_me_refill_triggered_total Immediate ME refill runs started" + ); let _ = writeln!(out, "# TYPE telemt_me_refill_triggered_total counter"); let _ = writeln!( out, @@ -1665,7 +1787,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_refill_skipped_inflight_total Immediate ME refill skips due to inflight dedup" ); - let _ = writeln!(out, "# TYPE telemt_me_refill_skipped_inflight_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_refill_skipped_inflight_total counter" + ); let _ = writeln!( out, "telemt_me_refill_skipped_inflight_total {}", @@ -1676,7 +1801,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!(out, "# HELP telemt_me_refill_failed_total Immediate ME refill failures"); + let _ = writeln!( + out, + "# HELP telemt_me_refill_failed_total Immediate ME refill failures" + ); let _ = writeln!(out, "# TYPE telemt_me_refill_failed_total counter"); let _ = writeln!( out, @@ -1692,7 +1820,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_writer_restored_same_endpoint_total Refilled ME writer restored on the same endpoint" ); - let _ = writeln!(out, "# TYPE telemt_me_writer_restored_same_endpoint_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_writer_restored_same_endpoint_total counter" + ); let _ = writeln!( out, "telemt_me_writer_restored_same_endpoint_total {}", @@ -1707,7 +1838,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_writer_restored_fallback_total Refilled ME writer restored via fallback endpoint" ); - let _ = writeln!(out, "# TYPE telemt_me_writer_restored_fallback_total counter"); + let _ = writeln!( + out, + "# TYPE telemt_me_writer_restored_fallback_total counter" + ); let _ = writeln!( out, "telemt_me_writer_restored_fallback_total {}", @@ -1785,17 +1919,35 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp unresolved_writer_losses ); - let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections"); + let _ = writeln!( + out, + "# HELP telemt_user_connections_total Per-user total connections" + ); let _ = writeln!(out, "# TYPE telemt_user_connections_total counter"); - let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections"); + let _ = writeln!( + out, + "# HELP telemt_user_connections_current Per-user active connections" + ); let _ = writeln!(out, "# TYPE telemt_user_connections_current gauge"); - let _ = writeln!(out, "# HELP telemt_user_octets_from_client Per-user bytes received"); + let _ = writeln!( + out, + "# HELP telemt_user_octets_from_client Per-user bytes received" + ); let _ = writeln!(out, "# TYPE telemt_user_octets_from_client counter"); - let _ = writeln!(out, "# HELP telemt_user_octets_to_client Per-user bytes sent"); + let _ = writeln!( + out, + "# HELP telemt_user_octets_to_client Per-user bytes sent" + ); let _ = writeln!(out, "# TYPE telemt_user_octets_to_client counter"); - let _ = writeln!(out, "# HELP telemt_user_msgs_from_client Per-user messages received"); + let _ = writeln!( + out, + "# HELP telemt_user_msgs_from_client Per-user messages received" + ); let _ = writeln!(out, "# TYPE telemt_user_msgs_from_client counter"); - let _ = writeln!(out, "# HELP telemt_user_msgs_to_client Per-user messages sent"); + let _ = writeln!( + out, + "# HELP telemt_user_msgs_to_client Per-user messages sent" + ); let _ = writeln!(out, "# TYPE telemt_user_msgs_to_client counter"); let _ = writeln!( out, @@ -1835,12 +1987,45 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp for entry in stats.iter_user_stats() { let user = entry.key(); let s = entry.value(); - let _ = writeln!(out, "telemt_user_connections_total{{user=\"{}\"}} {}", user, s.connects.load(std::sync::atomic::Ordering::Relaxed)); - let _ = writeln!(out, "telemt_user_connections_current{{user=\"{}\"}} {}", user, s.curr_connects.load(std::sync::atomic::Ordering::Relaxed)); - let _ = writeln!(out, "telemt_user_octets_from_client{{user=\"{}\"}} {}", user, s.octets_from_client.load(std::sync::atomic::Ordering::Relaxed)); - let _ = writeln!(out, "telemt_user_octets_to_client{{user=\"{}\"}} {}", user, s.octets_to_client.load(std::sync::atomic::Ordering::Relaxed)); - let _ = writeln!(out, "telemt_user_msgs_from_client{{user=\"{}\"}} {}", user, s.msgs_from_client.load(std::sync::atomic::Ordering::Relaxed)); - let _ = writeln!(out, "telemt_user_msgs_to_client{{user=\"{}\"}} {}", user, s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed)); + let _ = writeln!( + out, + "telemt_user_connections_total{{user=\"{}\"}} {}", + user, + s.connects.load(std::sync::atomic::Ordering::Relaxed) + ); + let _ = writeln!( + out, + "telemt_user_connections_current{{user=\"{}\"}} {}", + user, + s.curr_connects.load(std::sync::atomic::Ordering::Relaxed) + ); + let _ = writeln!( + out, + "telemt_user_octets_from_client{{user=\"{}\"}} {}", + user, + s.octets_from_client + .load(std::sync::atomic::Ordering::Relaxed) + ); + let _ = writeln!( + out, + "telemt_user_octets_to_client{{user=\"{}\"}} {}", + user, + s.octets_to_client + .load(std::sync::atomic::Ordering::Relaxed) + ); + let _ = writeln!( + out, + "telemt_user_msgs_from_client{{user=\"{}\"}} {}", + user, + s.msgs_from_client + .load(std::sync::atomic::Ordering::Relaxed) + ); + let _ = writeln!( + out, + "telemt_user_msgs_to_client{{user=\"{}\"}} {}", + user, + s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed) + ); } let ip_stats = ip_tracker.get_stats().await; @@ -1858,16 +2043,25 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp .get_recent_counts_for_users(&unique_users_vec) .await; - let _ = writeln!(out, "# HELP telemt_user_unique_ips_current Per-user current number of unique active IPs"); + let _ = writeln!( + out, + "# HELP telemt_user_unique_ips_current Per-user current number of unique active IPs" + ); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_current gauge"); let _ = writeln!( out, "# HELP telemt_user_unique_ips_recent_window Per-user unique IPs seen in configured observation window" ); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_recent_window gauge"); - let _ = writeln!(out, "# HELP telemt_user_unique_ips_limit Effective per-user unique IP limit (0 means unlimited)"); + let _ = writeln!( + out, + "# HELP telemt_user_unique_ips_limit Effective per-user unique IP limit (0 means unlimited)" + ); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_limit gauge"); - let _ = writeln!(out, "# HELP telemt_user_unique_ips_utilization Per-user unique IP usage ratio (0 for unlimited)"); + let _ = writeln!( + out, + "# HELP telemt_user_unique_ips_utilization Per-user unique IP usage ratio (0 for unlimited)" + ); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_utilization gauge"); for user in unique_users { @@ -1878,29 +2072,34 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp .get(&user) .copied() .filter(|limit| *limit > 0) - .or( - (config.access.user_max_unique_ips_global_each > 0) - .then_some(config.access.user_max_unique_ips_global_each), - ) + .or((config.access.user_max_unique_ips_global_each > 0) + .then_some(config.access.user_max_unique_ips_global_each)) .unwrap_or(0); let utilization = if limit > 0 { current as f64 / limit as f64 } else { 0.0 }; - let _ = writeln!(out, "telemt_user_unique_ips_current{{user=\"{}\"}} {}", user, current); + let _ = writeln!( + out, + "telemt_user_unique_ips_current{{user=\"{}\"}} {}", + user, current + ); let _ = writeln!( out, "telemt_user_unique_ips_recent_window{{user=\"{}\"}} {}", user, recent_counts.get(&user).copied().unwrap_or(0) ); - let _ = writeln!(out, "telemt_user_unique_ips_limit{{user=\"{}\"}} {}", user, limit); + let _ = writeln!( + out, + "telemt_user_unique_ips_limit{{user=\"{}\"}} {}", + user, limit + ); let _ = writeln!( out, "telemt_user_unique_ips_utilization{{user=\"{}\"}} {:.6}", - user, - utilization + user, utilization ); } } @@ -1911,8 +2110,8 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp #[cfg(test)] mod tests { use super::*; - use std::net::IpAddr; use http_body_util::BodyExt; + use std::net::IpAddr; #[tokio::test] async fn test_render_metrics_format() { @@ -1967,13 +2166,10 @@ mod tests { assert!(output.contains("telemt_upstream_connect_success_total 1")); assert!(output.contains("telemt_upstream_connect_fail_total 1")); assert!(output.contains("telemt_upstream_connect_failfast_hard_error_total 1")); + assert!(output.contains("telemt_upstream_connect_attempts_per_request{bucket=\"2\"} 1")); assert!( - output.contains("telemt_upstream_connect_attempts_per_request{bucket=\"2\"} 1") - ); - assert!( - output.contains( - "telemt_upstream_connect_duration_success_total{bucket=\"101_500ms\"} 1" - ) + output + .contains("telemt_upstream_connect_duration_success_total{bucket=\"101_500ms\"} 1") ); assert!( output.contains("telemt_upstream_connect_duration_fail_total{bucket=\"gt_1000ms\"} 1") @@ -2050,9 +2246,10 @@ mod tests { assert!(output.contains("# TYPE telemt_relay_pressure_evict_total counter")); assert!(output.contains("# TYPE telemt_relay_protocol_desync_close_total counter")); assert!(output.contains("# TYPE telemt_me_writer_removed_total counter")); - assert!(output.contains( - "# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge" - )); + assert!( + output + .contains("# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge") + ); assert!(output.contains("# TYPE telemt_user_unique_ips_current gauge")); assert!(output.contains("# TYPE telemt_user_unique_ips_recent_window gauge")); assert!(output.contains("# TYPE telemt_user_unique_ips_limit gauge")); @@ -2069,14 +2266,17 @@ mod tests { stats.increment_connects_all(); stats.increment_connects_all(); - let req = Request::builder() - .uri("/metrics") - .body(()) + let req = Request::builder().uri("/metrics").body(()).unwrap(); + let resp = handle(req, &stats, &beobachten, &tracker, &config) + .await .unwrap(); - let resp = handle(req, &stats, &beobachten, &tracker, &config).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let body = resp.into_body().collect().await.unwrap().to_bytes(); - assert!(std::str::from_utf8(body.as_ref()).unwrap().contains("telemt_connections_total 3")); + assert!( + std::str::from_utf8(body.as_ref()) + .unwrap() + .contains("telemt_connections_total 3") + ); config.general.beobachten = true; config.general.beobachten_minutes = 10; @@ -2085,10 +2285,7 @@ mod tests { "203.0.113.10".parse::().unwrap(), Duration::from_secs(600), ); - let req_beob = Request::builder() - .uri("/beobachten") - .body(()) - .unwrap(); + let req_beob = Request::builder().uri("/beobachten").body(()).unwrap(); let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config) .await .unwrap(); @@ -2098,10 +2295,7 @@ mod tests { assert!(beob_text.contains("[TLS-scanner]")); assert!(beob_text.contains("203.0.113.10-1")); - let req404 = Request::builder() - .uri("/other") - .body(()) - .unwrap(); + let req404 = Request::builder().uri("/other").body(()).unwrap(); let resp404 = handle(req404, &stats, &beobachten, &tracker, &config) .await .unwrap(); diff --git a/src/network/dns_overrides.rs b/src/network/dns_overrides.rs index 447863a..86fb325 100644 --- a/src/network/dns_overrides.rs +++ b/src/network/dns_overrides.rs @@ -26,9 +26,7 @@ fn parse_ip_spec(ip_spec: &str) -> Result { } let ip = ip_spec.parse::().map_err(|_| { - ProxyError::Config(format!( - "network.dns_overrides IP is invalid: '{ip_spec}'" - )) + ProxyError::Config(format!("network.dns_overrides IP is invalid: '{ip_spec}'")) })?; if matches!(ip, IpAddr::V6(_)) { return Err(ProxyError::Config(format!( @@ -103,9 +101,9 @@ pub fn validate_entries(entries: &[String]) -> Result<()> { /// Replace runtime DNS overrides with a new validated snapshot. pub fn install_entries(entries: &[String]) -> Result<()> { let parsed = parse_entries(entries)?; - let mut guard = overrides_store() - .write() - .map_err(|_| ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string()))?; + let mut guard = overrides_store().write().map_err(|_| { + ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string()) + })?; *guard = parsed; Ok(()) } diff --git a/src/network/probe.rs b/src/network/probe.rs index a9e369d..098e2eb 100644 --- a/src/network/probe.rs +++ b/src/network/probe.rs @@ -10,7 +10,9 @@ use tracing::{debug, info, warn}; use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType}; use crate::error::Result; -use crate::network::stun::{stun_probe_family_with_bind, DualStunResult, IpFamily, StunProbeResult}; +use crate::network::stun::{ + DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind, +}; use crate::transport::UpstreamManager; #[derive(Debug, Clone, Default)] @@ -78,13 +80,8 @@ pub async fn run_probe( warn!("STUN probe is enabled but network.stun_servers is empty"); DualStunResult::default() } else { - probe_stun_servers_parallel( - &servers, - stun_nat_probe_concurrency.max(1), - None, - None, - ) - .await + probe_stun_servers_parallel(&servers, stun_nat_probe_concurrency.max(1), None, None) + .await } } else if nat_probe { info!("STUN probe is disabled by network.stun_use=false"); @@ -99,7 +96,8 @@ pub async fn run_probe( let UpstreamType::Direct { interface, bind_addresses, - } = &upstream.upstream_type else { + } = &upstream.upstream_type + else { continue; }; if let Some(addrs) = bind_addresses.as_ref().filter(|v| !v.is_empty()) { @@ -217,12 +215,20 @@ pub async fn run_probe( probe.ipv4_usable = config.ipv4 && probe.detected_ipv4.is_some() - && (!probe.ipv4_is_bogon || probe.reflected_ipv4.map(|r| !is_bogon(r.ip())).unwrap_or(false)); + && (!probe.ipv4_is_bogon + || probe + .reflected_ipv4 + .map(|r| !is_bogon(r.ip())) + .unwrap_or(false)); let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()); probe.ipv6_usable = ipv6_enabled && probe.detected_ipv6.is_some() - && (!probe.ipv6_is_bogon || probe.reflected_ipv6.map(|r| !is_bogon(r.ip())).unwrap_or(false)); + && (!probe.ipv6_is_bogon + || probe + .reflected_ipv6 + .map(|r| !is_bogon(r.ip())) + .unwrap_or(false)); Ok(probe) } @@ -300,11 +306,15 @@ async fn probe_stun_servers_parallel( match task { Ok((stun_addr, Ok(Ok(result)))) => { if let Some(v4) = result.v4 { - let entry = best_v4_by_ip.entry(v4.reflected_addr.ip()).or_insert((0, v4)); + let entry = best_v4_by_ip + .entry(v4.reflected_addr.ip()) + .or_insert((0, v4)); entry.0 += 1; } if let Some(v6) = result.v6 { - let entry = best_v6_by_ip.entry(v6.reflected_addr.ip()).or_insert((0, v6)); + let entry = best_v6_by_ip + .entry(v6.reflected_addr.ip()) + .or_insert((0, v6)); entry.0 += 1; } if result.v4.is_some() || result.v6.is_some() { @@ -324,17 +334,11 @@ async fn probe_stun_servers_parallel( } let mut out = DualStunResult::default(); - if let Some((_, best)) = best_v4_by_ip - .into_values() - .max_by_key(|(count, _)| *count) - { + if let Some((_, best)) = best_v4_by_ip.into_values().max_by_key(|(count, _)| *count) { info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); out.v4 = Some(best); } - if let Some((_, best)) = best_v6_by_ip - .into_values() - .max_by_key(|(count, _)| *count) - { + if let Some((_, best)) = best_v6_by_ip.into_values().max_by_key(|(count, _)| *count) { info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); out.v6 = Some(best); } @@ -347,7 +351,8 @@ pub fn decide_network_capabilities( middle_proxy_nat_ip: Option, ) -> NetworkDecision { let ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some(); - let ipv6_dc = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some(); + let ipv6_dc = + config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some(); let nat_ip_v4 = matches!(middle_proxy_nat_ip, Some(IpAddr::V4(_))); let nat_ip_v6 = matches!(middle_proxy_nat_ip, Some(IpAddr::V6(_))); @@ -534,10 +539,26 @@ pub fn is_bogon_v6(ip: Ipv6Addr) -> bool { pub fn log_probe_result(probe: &NetworkProbe, decision: &NetworkDecision) { info!( - ipv4 = probe.detected_ipv4.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()), - ipv6 = probe.detected_ipv6.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()), - reflected_v4 = probe.reflected_ipv4.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()), - reflected_v6 = probe.reflected_ipv6.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()), + ipv4 = probe + .detected_ipv4 + .as_ref() + .map(|v| v.to_string()) + .unwrap_or_else(|| "-".into()), + ipv6 = probe + .detected_ipv6 + .as_ref() + .map(|v| v.to_string()) + .unwrap_or_else(|| "-".into()), + reflected_v4 = probe + .reflected_ipv4 + .as_ref() + .map(|v| v.ip().to_string()) + .unwrap_or_else(|| "-".into()), + reflected_v6 = probe + .reflected_ipv6 + .as_ref() + .map(|v| v.ip().to_string()) + .unwrap_or_else(|| "-".into()), ipv4_bogon = probe.ipv4_is_bogon, ipv6_bogon = probe.ipv6_is_bogon, ipv4_me = decision.ipv4_me, diff --git a/src/network/stun.rs b/src/network/stun.rs index 6c6bd84..d1e088c 100644 --- a/src/network/stun.rs +++ b/src/network/stun.rs @@ -4,8 +4,8 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::OnceLock; -use tokio::net::{lookup_host, UdpSocket}; -use tokio::time::{timeout, Duration, sleep}; +use tokio::net::{UdpSocket, lookup_host}; +use tokio::time::{Duration, sleep, timeout}; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; @@ -41,13 +41,13 @@ pub async fn stun_probe_dual(stun_addr: &str) -> Result { stun_probe_family(stun_addr, IpFamily::V6), ); - Ok(DualStunResult { - v4: v4?, - v6: v6?, - }) + Ok(DualStunResult { v4: v4?, v6: v6? }) } -pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result> { +pub async fn stun_probe_family( + stun_addr: &str, + family: IpFamily, +) -> Result> { stun_probe_family_with_bind(stun_addr, family, None).await } @@ -76,13 +76,18 @@ pub async fn stun_probe_family_with_bind( if let Some(addr) = target_addr { match socket.connect(addr).await { Ok(()) => {} - Err(e) if family == IpFamily::V6 && matches!( - e.kind(), - std::io::ErrorKind::NetworkUnreachable - | std::io::ErrorKind::HostUnreachable - | std::io::ErrorKind::Unsupported - | std::io::ErrorKind::NetworkDown - ) => return Ok(None), + Err(e) + if family == IpFamily::V6 + && matches!( + e.kind(), + std::io::ErrorKind::NetworkUnreachable + | std::io::ErrorKind::HostUnreachable + | std::io::ErrorKind::Unsupported + | std::io::ErrorKind::NetworkDown + ) => + { + return Ok(None); + } Err(e) => return Err(ProxyError::Proxy(format!("STUN connect failed: {e}"))), } } else { @@ -125,16 +130,16 @@ pub async fn stun_probe_family_with_bind( let magic = 0x2112A442u32.to_be_bytes(); let txid = &req[8..20]; - let mut idx = 20; - while idx + 4 <= n { - let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap()); - let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize; - idx += 4; - if idx + alen > n { - break; - } + let mut idx = 20; + while idx + 4 <= n { + let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap()); + let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize; + idx += 4; + if idx + alen > n { + break; + } - match atype { + match atype { 0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => { if alen < 8 { break; @@ -203,9 +208,8 @@ pub async fn stun_probe_family_with_bind( _ => {} } - idx += (alen + 3) & !3; - } - + idx += (alen + 3) & !3; + } } Ok(None) @@ -233,7 +237,11 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result> = LazyLock::new(|| { // ============= Middle Proxies (for advertising) ============= -pub static TG_MIDDLE_PROXIES_V4: LazyLock>> = +pub static TG_MIDDLE_PROXIES_V4: LazyLock>> = LazyLock::new(|| { let mut m = std::collections::HashMap::new(); - m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); - m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); - m.insert(2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]); - m.insert(-2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]); - m.insert(3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]); - m.insert(-3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]); + m.insert( + 1, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)], + ); + m.insert( + -1, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)], + ); + m.insert( + 2, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)], + ); + m.insert( + -2, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)], + ); + m.insert( + 3, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)], + ); + m.insert( + -3, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)], + ); m.insert(4, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888)]); - m.insert(-4, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)]); + m.insert( + -4, + vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)], + ); m.insert(5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]); - m.insert(-5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]); + m.insert( + -5, + vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)], + ); m }); -pub static TG_MIDDLE_PROXIES_V6: LazyLock>> = +pub static TG_MIDDLE_PROXIES_V6: LazyLock>> = LazyLock::new(|| { let mut m = std::collections::HashMap::new(); - m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); - m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); - m.insert(2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]); - m.insert(-2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]); - m.insert(3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]); - m.insert(-3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]); - m.insert(4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]); - m.insert(-4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]); - m.insert(5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]); - m.insert(-5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]); + m.insert( + 1, + vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)], + ); + m.insert( + -1, + vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)], + ); + m.insert( + 2, + vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)], + ); + m.insert( + -2, + vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)], + ); + m.insert( + 3, + vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)], + ); + m.insert( + -3, + vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)], + ); + m.insert( + 4, + vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)], + ); + m.insert( + -4, + vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)], + ); + m.insert( + 5, + vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)], + ); + m.insert( + -5, + vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)], + ); m }); @@ -89,12 +143,12 @@ impl ProtoTag { _ => None, } } - + /// Convert to 4 bytes (little-endian) pub fn to_bytes(self) -> [u8; 4] { (self as u32).to_le_bytes() } - + /// Get protocol tag as bytes slice pub fn as_bytes(&self) -> &'static [u8; 4] { match self { @@ -222,9 +276,7 @@ pub const SMALL_BUFFER_SIZE: usize = 8192; // ============= Statistics ============= /// Duration buckets for histogram metrics -pub static DURATION_BUCKETS: &[f64] = &[ - 0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0, -]; +pub static DURATION_BUCKETS: &[f64] = &[0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0]; // ============= Reserved Nonce Patterns ============= @@ -235,29 +287,27 @@ pub static RESERVED_NONCE_FIRST_BYTES: &[u8] = &[0xef]; pub static RESERVED_NONCE_BEGINNINGS: &[[u8; 4]] = &[ [0x48, 0x45, 0x41, 0x44], // HEAD [0x50, 0x4F, 0x53, 0x54], // POST - [0x47, 0x45, 0x54, 0x20], // GET + [0x47, 0x45, 0x54, 0x20], // GET [0xee, 0xee, 0xee, 0xee], // Intermediate [0xdd, 0xdd, 0xdd, 0xdd], // Secure [0x16, 0x03, 0x01, 0x02], // TLS ]; /// Reserved continuation bytes (bytes 4-7) -pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[ - [0x00, 0x00, 0x00, 0x00], -]; +pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[[0x00, 0x00, 0x00, 0x00]]; // ============= RPC Constants (for Middle Proxy) ============= /// RPC Proxy Request /// RPC Flags (from Erlang mtp_rpc.erl) pub const RPC_FLAG_NOT_ENCRYPTED: u32 = 0x2; -pub const RPC_FLAG_HAS_AD_TAG: u32 = 0x8; -pub const RPC_FLAG_MAGIC: u32 = 0x1000; -pub const RPC_FLAG_EXTMODE2: u32 = 0x20000; -pub const RPC_FLAG_PAD: u32 = 0x8000000; -pub const RPC_FLAG_INTERMEDIATE: u32 = 0x20000000; -pub const RPC_FLAG_ABRIDGED: u32 = 0x40000000; -pub const RPC_FLAG_QUICKACK: u32 = 0x80000000; +pub const RPC_FLAG_HAS_AD_TAG: u32 = 0x8; +pub const RPC_FLAG_MAGIC: u32 = 0x1000; +pub const RPC_FLAG_EXTMODE2: u32 = 0x20000; +pub const RPC_FLAG_PAD: u32 = 0x8000000; +pub const RPC_FLAG_INTERMEDIATE: u32 = 0x20000000; +pub const RPC_FLAG_ABRIDGED: u32 = 0x40000000; +pub const RPC_FLAG_QUICKACK: u32 = 0x80000000; pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36]; /// RPC Proxy Answer @@ -285,67 +335,66 @@ pub mod rpc_flags { pub const FLAG_QUICKACK: u32 = 0x80000000; } +// ============= Middle-End Proxy Servers ============= +pub const ME_PROXY_PORT: u16 = 8888; - // ============= Middle-End Proxy Servers ============= - pub const ME_PROXY_PORT: u16 = 8888; - - pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock> = LazyLock::new(|| { - vec![ - (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888), - (IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888), - (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888), - (IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888), - (IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888), - ] - }); - - // ============= RPC Constants (u32 native endian) ============= - // From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c - - pub const RPC_NONCE_U32: u32 = 0x7acb87aa; - pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5; - pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda; - pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121 - - // mtproto-common.h - pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee; - pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d; - pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d; - pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2; - pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b; - pub const RPC_PING_U32: u32 = 0x5730a2df; - pub const RPC_PONG_U32: u32 = 0x8430eaa7; - - pub const RPC_CRYPTO_NONE_U32: u32 = 0; - pub const RPC_CRYPTO_AES_U32: u32 = 1; - - pub mod proxy_flags { - pub const FLAG_HAS_AD_TAG: u32 = 1; - pub const FLAG_NOT_ENCRYPTED: u32 = 0x2; - pub const FLAG_HAS_AD_TAG2: u32 = 0x8; - pub const FLAG_MAGIC: u32 = 0x1000; - pub const FLAG_EXTMODE2: u32 = 0x20000; - pub const FLAG_PAD: u32 = 0x8000000; - pub const FLAG_INTERMEDIATE: u32 = 0x20000000; - pub const FLAG_ABRIDGED: u32 = 0x40000000; - pub const FLAG_QUICKACK: u32 = 0x80000000; - } +pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock> = LazyLock::new(|| { + vec![ + (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888), + (IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888), + (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888), + (IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888), + (IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888), + ] +}); - pub mod rpc_crypto_flags { - pub const USE_CRC32C: u32 = 0x800; - } - - pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; - pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; +// ============= RPC Constants (u32 native endian) ============= +// From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c - #[cfg(test)] - #[path = "tests/tls_size_constants_security_tests.rs"] - mod tls_size_constants_security_tests; - - #[cfg(test)] +pub const RPC_NONCE_U32: u32 = 0x7acb87aa; +pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5; +pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda; +pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121 + +// mtproto-common.h +pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee; +pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d; +pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d; +pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2; +pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b; +pub const RPC_PING_U32: u32 = 0x5730a2df; +pub const RPC_PONG_U32: u32 = 0x8430eaa7; + +pub const RPC_CRYPTO_NONE_U32: u32 = 0; +pub const RPC_CRYPTO_AES_U32: u32 = 1; + +pub mod proxy_flags { + pub const FLAG_HAS_AD_TAG: u32 = 1; + pub const FLAG_NOT_ENCRYPTED: u32 = 0x2; + pub const FLAG_HAS_AD_TAG2: u32 = 0x8; + pub const FLAG_MAGIC: u32 = 0x1000; + pub const FLAG_EXTMODE2: u32 = 0x20000; + pub const FLAG_PAD: u32 = 0x8000000; + pub const FLAG_INTERMEDIATE: u32 = 0x20000000; + pub const FLAG_ABRIDGED: u32 = 0x40000000; + pub const FLAG_QUICKACK: u32 = 0x80000000; +} + +pub mod rpc_crypto_flags { + pub const USE_CRC32C: u32 = 0x800; +} + +pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; +pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; + +#[cfg(test)] +#[path = "tests/tls_size_constants_security_tests.rs"] +mod tls_size_constants_security_tests; + +#[cfg(test)] mod tests { use super::*; - + #[test] fn test_proto_tag_roundtrip() { for tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] { @@ -354,20 +403,20 @@ mod tests { assert_eq!(tag, parsed); } } - + #[test] fn test_proto_tag_values() { assert_eq!(ProtoTag::Abridged.to_bytes(), PROTO_TAG_ABRIDGED); assert_eq!(ProtoTag::Intermediate.to_bytes(), PROTO_TAG_INTERMEDIATE); assert_eq!(ProtoTag::Secure.to_bytes(), PROTO_TAG_SECURE); } - + #[test] fn test_invalid_proto_tag() { assert!(ProtoTag::from_bytes([0, 0, 0, 0]).is_none()); assert!(ProtoTag::from_bytes([0xff, 0xff, 0xff, 0xff]).is_none()); } - + #[test] fn test_datacenters_count() { assert_eq!(TG_DATACENTERS_V4.len(), 5); diff --git a/src/protocol/frame.rs b/src/protocol/frame.rs index dd59ba9..d8e3d4a 100644 --- a/src/protocol/frame.rs +++ b/src/protocol/frame.rs @@ -22,7 +22,7 @@ impl FrameExtra { pub fn new() -> Self { Self::default() } - + /// Create with quickack flag set pub fn with_quickack() -> Self { Self { @@ -30,7 +30,7 @@ impl FrameExtra { ..Default::default() } } - + /// Create with simple_ack flag set pub fn with_simple_ack() -> Self { Self { @@ -38,7 +38,7 @@ impl FrameExtra { ..Default::default() } } - + /// Check if any flags are set pub fn has_flags(&self) -> bool { self.quickack || self.simple_ack || self.skip_send @@ -76,22 +76,22 @@ impl FrameMode { FrameMode::Abridged => 4, FrameMode::Intermediate => 4, FrameMode::SecureIntermediate => 4 + 3, // length + padding - FrameMode::Full => 12 + 16, // header + max CBC padding + FrameMode::Full => 12 + 16, // header + max CBC padding } } } /// Validate message length for MTProto pub fn validate_message_length(len: usize) -> bool { - use super::constants::{MIN_MSG_LEN, MAX_MSG_LEN, PADDING_FILLER}; - + use super::constants::{MAX_MSG_LEN, MIN_MSG_LEN, PADDING_FILLER}; + (MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len) && len.is_multiple_of(PADDING_FILLER.len()) } #[cfg(test)] mod tests { use super::*; - + #[test] fn test_frame_extra_default() { let extra = FrameExtra::default(); @@ -100,18 +100,18 @@ mod tests { assert!(!extra.skip_send); assert!(!extra.has_flags()); } - + #[test] fn test_frame_extra_flags() { let extra = FrameExtra::with_quickack(); assert!(extra.quickack); assert!(extra.has_flags()); - + let extra = FrameExtra::with_simple_ack(); assert!(extra.simple_ack); assert!(extra.has_flags()); } - + #[test] fn test_validate_message_length() { assert!(validate_message_length(12)); // MIN_MSG_LEN @@ -119,4 +119,4 @@ mod tests { assert!(!validate_message_length(8)); // Too small assert!(!validate_message_length(13)); // Not aligned to 4 } -} \ No newline at end of file +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 5518df2..f0b3a1a 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -12,4 +12,4 @@ pub use frame::*; #[allow(unused_imports)] pub use obfuscation::*; #[allow(unused_imports)] -pub use tls::*; \ No newline at end of file +pub use tls::*; diff --git a/src/protocol/obfuscation.rs b/src/protocol/obfuscation.rs index d9d1c0a..7aff9f3 100644 --- a/src/protocol/obfuscation.rs +++ b/src/protocol/obfuscation.rs @@ -2,9 +2,9 @@ #![allow(dead_code)] -use zeroize::Zeroize; -use crate::crypto::{sha256, AesCtr}; use super::constants::*; +use crate::crypto::{AesCtr, sha256}; +use zeroize::Zeroize; /// Obfuscation parameters from handshake /// @@ -44,41 +44,40 @@ impl ObfuscationParams { let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; - + let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; - + for (username, secret) in secrets { let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(secret); let decrypt_key = sha256(&dec_key_input); - + let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); - + let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv); let decrypted = decryptor.decrypt(handshake); - + let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] .try_into() .unwrap(); - + let proto_tag = match ProtoTag::from_bytes(tag_bytes) { Some(tag) => tag, None => continue, }; - - let dc_idx = i16::from_le_bytes( - decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() - ); - + + let dc_idx = + i16::from_le_bytes(decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()); + let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(secret); let encrypt_key = sha256(&enc_key_input); let encrypt_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); - + return Some(( ObfuscationParams { decrypt_key, @@ -91,20 +90,20 @@ impl ObfuscationParams { username.clone(), )); } - + None } - + /// Create AES-CTR decryptor for client -> proxy direction pub fn create_decryptor(&self) -> AesCtr { AesCtr::new(&self.decrypt_key, self.decrypt_iv) } - + /// Create AES-CTR encryptor for proxy -> client direction pub fn create_encryptor(&self) -> AesCtr { AesCtr::new(&self.encrypt_key, self.encrypt_iv) } - + /// Get the combined encrypt key and IV for fast mode pub fn enc_key_iv(&self) -> Vec { let mut result = Vec::with_capacity(KEY_LEN + IV_LEN); @@ -120,7 +119,7 @@ pub fn generate_nonce Vec>(mut random_bytes: R) -> [u8; H let nonce_vec = random_bytes(HANDSHAKE_LEN); let mut nonce = [0u8; HANDSHAKE_LEN]; nonce.copy_from_slice(&nonce_vec); - + if is_valid_nonce(&nonce) { return nonce; } @@ -132,17 +131,17 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool { if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { return false; } - + let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { return false; } - + let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); if RESERVED_NONCE_CONTINUES.contains(&continue_four) { return false; } - + true } @@ -153,7 +152,7 @@ pub fn prepare_tg_nonce( enc_key_iv: Option<&[u8]>, ) { nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); - + if let Some(key_iv) = enc_key_iv { let reversed: Vec = key_iv.iter().rev().copied().collect(); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed); @@ -171,39 +170,39 @@ pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key = sha256(key_iv); let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap()); - + let mut encryptor = AesCtr::new(&enc_key, enc_iv); - + let mut result = nonce.to_vec(); let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]); result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part); - + result } #[cfg(test)] mod tests { use super::*; - + #[test] fn test_is_valid_nonce() { let mut valid = [0x42u8; HANDSHAKE_LEN]; valid[4..8].copy_from_slice(&[1, 2, 3, 4]); assert!(is_valid_nonce(&valid)); - + let mut invalid = [0x00u8; HANDSHAKE_LEN]; invalid[0] = 0xef; assert!(!is_valid_nonce(&invalid)); - + let mut invalid = [0x00u8; HANDSHAKE_LEN]; invalid[..4].copy_from_slice(b"HEAD"); assert!(!is_valid_nonce(&invalid)); - + let mut invalid = [0x42u8; HANDSHAKE_LEN]; invalid[4..8].copy_from_slice(&[0, 0, 0, 0]); assert!(!is_valid_nonce(&invalid)); } - + #[test] fn test_generate_nonce() { let mut counter = 0u8; @@ -211,7 +210,7 @@ mod tests { counter = counter.wrapping_add(1); vec![counter; n] }); - + assert!(is_valid_nonce(&nonce)); assert_eq!(nonce.len(), HANDSHAKE_LEN); } diff --git a/src/protocol/tests/tls_adversarial_tests.rs b/src/protocol/tests/tls_adversarial_tests.rs index b8df41a..0b36ba3 100644 --- a/src/protocol/tests/tls_adversarial_tests.rs +++ b/src/protocol/tests/tls_adversarial_tests.rs @@ -1,6 +1,6 @@ use super::*; -use std::time::Instant; use crate::crypto::sha256_hmac; +use std::time::Instant; /// Helper to create a byte vector of specific length. fn make_garbage(len: usize) -> Vec { @@ -33,8 +33,7 @@ fn make_valid_tls_handshake_with_session_id( let digest = make_digest(secret, &handshake, timestamp); - handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] - .copy_from_slice(&digest); + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); handshake } @@ -96,15 +95,15 @@ fn extract_sni_with_overlapping_extension_lengths_rejected() { h.push(0); // Session ID length: 0 h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites h.extend_from_slice(&[0x01, 0x00]); // Compression - + // Extensions start h.extend_from_slice(&[0x00, 0x20]); // Total Extensions length: 32 - + // Extension 1: SNI (type 0) - h.extend_from_slice(&[0x00, 0x00]); + h.extend_from_slice(&[0x00, 0x00]); h.extend_from_slice(&[0x00, 0x40]); // Claimed len: 64 (OVERFLOWS total extensions len 32) h.extend_from_slice(&[0u8; 64]); - + assert!(extract_sni_from_client_hello(&h).is_none()); } @@ -118,19 +117,19 @@ fn extract_sni_with_infinite_loop_potential_extension_rejected() { h.push(0); // Session ID length: 0 h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites h.extend_from_slice(&[0x01, 0x00]); // Compression - + // Extensions start h.extend_from_slice(&[0x00, 0x10]); // Total Extensions length: 16 - - // Extension: zero length but claims more? + + // Extension: zero length but claims more? // If our parser didn't advance, it might loop. // Telemt uses `pos += 4 + elen;` so it always advances. h.extend_from_slice(&[0x12, 0x34]); // Unknown type h.extend_from_slice(&[0x00, 0x00]); // Length 0 - + // Fill the rest with garbage h.extend_from_slice(&[0x42; 12]); - + // We expect it to finish without SNI found assert!(extract_sni_from_client_hello(&h).is_none()); } @@ -143,7 +142,7 @@ fn extract_sni_with_invalid_hostname_rejected() { sni.push(0); sni.extend_from_slice(&(host.len() as u16).to_be_bytes()); sni.extend_from_slice(host); - + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header h.push(0x01); // ClientHello h.extend_from_slice(&[0x00, 0x00, 0x5C]); @@ -152,16 +151,19 @@ fn extract_sni_with_invalid_hostname_rejected() { h.push(0); h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); h.extend_from_slice(&[0x01, 0x00]); - + let mut ext = Vec::new(); ext.extend_from_slice(&0x0000u16.to_be_bytes()); ext.extend_from_slice(&(sni.len() as u16).to_be_bytes()); ext.extend_from_slice(&sni); - + h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); h.extend_from_slice(&ext); - - assert!(extract_sni_from_client_hello(&h).is_none(), "Invalid SNI hostname must be rejected"); + + assert!( + extract_sni_from_client_hello(&h).is_none(), + "Invalid SNI hostname must be rejected" + ); } // ------------------------------------------------------------------ @@ -233,7 +235,7 @@ fn is_tls_handshake_robustness_against_probing() { assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); // Valid TLS 1.2/1.3 ClientHello (Legacy Record Layer) assert!(is_tls_handshake(&[0x16, 0x03, 0x03])); - + // Invalid record type but matching version assert!(!is_tls_handshake(&[0x17, 0x03, 0x03])); // Plaintext HTTP request @@ -247,12 +249,12 @@ fn validate_tls_handshake_at_time_strict_boundary() { let secret = b"strict_boundary_secret_32_bytes_"; let secrets = vec![("u".to_string(), secret.to_vec())]; let now: i64 = 1_000_000_000; - + // Boundary: exactly TIME_SKEW_MAX (120s past) let ts_past = (now - TIME_SKEW_MAX) as u32; let h = make_valid_tls_handshake_with_session_id(secret, ts_past, &[0x42; 32]); assert!(validate_tls_handshake_at_time(&h, &secrets, false, now).is_some()); - + // Boundary + 1s: should be rejected let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32; let h2 = make_valid_tls_handshake_with_session_id(secret, ts_too_past, &[0x42; 32]); @@ -268,14 +270,14 @@ fn extract_sni_with_duplicate_extensions_rejected() { sni1.push(0); sni1.extend_from_slice(&(host1.len() as u16).to_be_bytes()); sni1.extend_from_slice(host1); - + let host2 = b"second.com"; let mut sni2 = Vec::new(); sni2.extend_from_slice(&((host2.len() + 3) as u16).to_be_bytes()); sni2.push(0); sni2.extend_from_slice(&(host2.len() as u16).to_be_bytes()); sni2.extend_from_slice(host2); - + let mut ext = Vec::new(); // Ext 1: SNI ext.extend_from_slice(&0x0000u16.to_be_bytes()); @@ -285,7 +287,7 @@ fn extract_sni_with_duplicate_extensions_rejected() { ext.extend_from_slice(&0x0000u16.to_be_bytes()); ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes()); ext.extend_from_slice(&sni2); - + let mut body = Vec::new(); body.extend_from_slice(&[0x03, 0x03]); body.extend_from_slice(&[0u8; 32]); @@ -306,7 +308,7 @@ fn extract_sni_with_duplicate_extensions_rejected() { h.extend_from_slice(&[0x03, 0x03]); h.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); h.extend_from_slice(&handshake); - + // Duplicate SNI extensions are ambiguous and must fail closed. assert!(extract_sni_from_client_hello(&h).is_none()); } @@ -317,21 +319,26 @@ fn extract_alpn_with_malformed_list_rejected() { alpn_payload.extend_from_slice(&0x0005u16.to_be_bytes()); // Total len 5 alpn_payload.push(10); // Labeled len 10 (OVERFLOWS total 5) alpn_payload.extend_from_slice(b"h2"); - + let mut ext = Vec::new(); ext.extend_from_slice(&0x0010u16.to_be_bytes()); // Type: ALPN (16) ext.extend_from_slice(&(alpn_payload.len() as u16).to_be_bytes()); ext.extend_from_slice(&alpn_payload); - - let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03]; + + let mut h = vec![ + 0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03, + ]; h.extend_from_slice(&[0u8; 32]); h.push(0); h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); h.extend_from_slice(&ext); - + let res = extract_alpn_from_client_hello(&h); - assert!(res.is_empty(), "Malformed ALPN list must return empty or fail"); + assert!( + res.is_empty(), + "Malformed ALPN list must return empty or fail" + ); } #[test] @@ -343,9 +350,9 @@ fn extract_sni_with_huge_extension_header_rejected() { h.extend_from_slice(&[0u8; 32]); h.push(0); h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); - + // Extensions start h.extend_from_slice(&[0xFF, 0xFF]); // Total extensions: 65535 (OVERFLOWS everything) - + assert!(extract_sni_from_client_hello(&h).is_none()); } diff --git a/src/protocol/tests/tls_fuzz_security_tests.rs b/src/protocol/tests/tls_fuzz_security_tests.rs index 32d8efe..903adb3 100644 --- a/src/protocol/tests/tls_fuzz_security_tests.rs +++ b/src/protocol/tests/tls_fuzz_security_tests.rs @@ -84,7 +84,10 @@ fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec> 17) as u8).wrapping_add(1); } @@ -171,9 +182,13 @@ fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() { } for (idx, handshake) in corpus.iter().enumerate() { - let result = catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now)); + let result = + catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now)); assert!(result.is_ok(), "corpus item {idx} must not panic"); - assert!(result.unwrap().is_none(), "corpus item {idx} must fail closed"); + assert!( + result.unwrap().is_none(), + "corpus item {idx} must fail closed" + ); } } diff --git a/src/protocol/tests/tls_security_tests.rs b/src/protocol/tests/tls_security_tests.rs index a6e7b2b..3008e57 100644 --- a/src/protocol/tests/tls_security_tests.rs +++ b/src/protocol/tests/tls_security_tests.rs @@ -1,7 +1,9 @@ use super::*; use crate::crypto::sha256_hmac; use crate::tls_front::emulator::build_emulated_server_hello; -use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource}; +use crate::tls_front::types::{ + CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource, +}; use std::time::SystemTime; /// Build a TLS-handshake-like buffer that contains a valid HMAC digest @@ -39,8 +41,7 @@ fn make_valid_tls_handshake_with_session_id( digest[28 + i] ^= ts[i]; } - handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] - .copy_from_slice(&digest); + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); handshake } @@ -180,7 +181,10 @@ fn second_user_in_list_found_when_first_does_not_match() { ("user_b".to_string(), secret_b.to_vec()), ]; let result = validate_tls_handshake(&handshake, &secrets, true); - assert!(result.is_some(), "user_b must be found even though user_a comes first"); + assert!( + result.is_some(), + "user_b must be found even though user_a comes first" + ); assert_eq!(result.unwrap().user, "user_b"); } @@ -428,8 +432,7 @@ fn censor_probe_random_digests_all_rejected() { let mut h = vec![0x42u8; min_len]; h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; let rand_digest = rng.bytes(TLS_DIGEST_LEN); - h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] - .copy_from_slice(&rand_digest); + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&rand_digest); assert!( validate_tls_handshake(&h, &secrets, true).is_none(), "Random digest at attempt {attempt} must not match" @@ -553,8 +556,7 @@ fn system_time_before_unix_epoch_is_rejected_without_panic() { fn system_time_far_future_overflowing_i64_returns_none() { // i64::MAX + 1 seconds past epoch overflows i64 when cast naively with `as`. let overflow_secs = u64::try_from(i64::MAX).unwrap() + 1; - if let Some(far_future) = - UNIX_EPOCH.checked_add(std::time::Duration::from_secs(overflow_secs)) + if let Some(far_future) = UNIX_EPOCH.checked_add(std::time::Duration::from_secs(overflow_secs)) { assert!( system_time_to_unix_secs(far_future).is_none(), @@ -620,7 +622,10 @@ fn appended_trailing_byte_causes_rejection() { let mut h = make_valid_tls_handshake(secret, 0); let secrets = vec![("u".to_string(), secret.to_vec())]; - assert!(validate_tls_handshake(&h, &secrets, true).is_some(), "baseline"); + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "baseline" + ); h.push(0x00); assert!( @@ -647,8 +652,7 @@ fn zero_length_session_id_accepted() { let computed = sha256_hmac(secret, &handshake); // timestamp = 0 → ts XOR bytes are all zero → digest = computed unchanged. - handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] - .copy_from_slice(&computed); + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&computed); let secrets = vec![("u".to_string(), secret.to_vec())]; let result = validate_tls_handshake(&handshake, &secrets, true); @@ -773,10 +777,18 @@ fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() { let secrets = vec![("u".to_string(), secret.to_vec())]; let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0); - let cap_nonzero = - validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_COMPAT_MAX_SECS); + let cap_nonzero = validate_tls_handshake_at_time_with_boot_cap( + &h, + &secrets, + true, + 0, + BOOT_TIME_COMPAT_MAX_SECS, + ); - assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC"); + assert!( + cap_zero.is_some(), + "ignore_time_skew=true must accept valid HMAC" + ); assert!( cap_nonzero.is_some(), "ignore_time_skew path must not depend on boot-time cap" @@ -888,8 +900,8 @@ fn adversarial_skew_boundary_matrix_accepts_only_inclusive_window_when_boot_disa let ts_i64 = now - offset; let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix"); let h = make_valid_tls_handshake(secret, ts); - let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) - .is_some(); + let accepted = + validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some(); let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset); assert_eq!( accepted, expected, @@ -917,8 +929,8 @@ fn light_fuzz_skew_window_rejects_outside_range_when_boot_disabled() { let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test"); let h = make_valid_tls_handshake(secret, ts); - let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) - .is_some(); + let accepted = + validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some(); assert!( !accepted, "offset {offset} must be rejected outside strict skew window" @@ -940,8 +952,8 @@ fn stress_boot_disabled_validation_matches_time_diff_oracle() { let ts = s as u32; let h = make_valid_tls_handshake(secret, ts); - let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) - .is_some(); + let accepted = + validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some(); let time_diff = now - i64::from(ts); let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff); assert_eq!( @@ -960,7 +972,10 @@ fn integration_large_user_list_with_boot_disabled_finds_only_matching_user() { let mut secrets = Vec::new(); for i in 0..512u32 { - secrets.push((format!("noise-{i}"), format!("noise-secret-{i}").into_bytes())); + secrets.push(( + format!("noise-{i}"), + format!("noise-secret-{i}").into_bytes(), + )); } secrets.push(("target-user".to_string(), target_secret.to_vec())); @@ -1018,7 +1033,10 @@ fn u32_max_timestamp_accepted_with_ignore_time_skew() { let secrets = vec![("u".to_string(), secret.to_vec())]; let result = validate_tls_handshake(&h, &secrets, true); - assert!(result.is_some(), "u32::MAX timestamp must be accepted with ignore_time_skew=true"); + assert!( + result.is_some(), + "u32::MAX timestamp must be accepted with ignore_time_skew=true" + ); assert_eq!( result.unwrap().timestamp, u32::MAX, @@ -1150,16 +1168,17 @@ fn first_matching_user_wins_over_later_duplicate_secret() { let secrets = vec![ ("decoy_1".to_string(), b"wrong_1".to_vec()), - ("winner".to_string(), shared.to_vec()), // first match + ("winner".to_string(), shared.to_vec()), // first match ("decoy_2".to_string(), b"wrong_2".to_vec()), - ("loser".to_string(), shared.to_vec()), // second match — must not win + ("loser".to_string(), shared.to_vec()), // second match — must not win ("decoy_3".to_string(), b"wrong_3".to_vec()), ]; let result = validate_tls_handshake(&h, &secrets, true); assert!(result.is_some()); assert_eq!( - result.unwrap().user, "winner", + result.unwrap().user, + "winner", "first matching user must be returned even when a later entry also matches" ); } @@ -1425,7 +1444,8 @@ fn test_build_server_hello_structure() { assert!(response.len() > ccs_start + 6); assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); - let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; + let ccs_len = + 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; let app_start = ccs_start + ccs_len; assert!(response.len() > app_start + 5); assert_eq!(response[app_start], TLS_RECORD_APPLICATION); @@ -1729,7 +1749,10 @@ fn empty_secret_hmac_is_supported() { let handshake = make_valid_tls_handshake(secret, 0); let secrets = vec![("empty".to_string(), secret.to_vec())]; let result = validate_tls_handshake(&handshake, &secrets, true); - assert!(result.is_some(), "Empty HMAC key must not panic and must validate when correct"); + assert!( + result.is_some(), + "Empty HMAC key must not panic and must validate when correct" + ); } #[test] @@ -1802,7 +1825,10 @@ fn server_hello_application_data_payload_varies_across_runs() { let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; let payload = response[app_pos + 5..app_pos + 5 + app_len].to_vec(); - assert!(payload.iter().any(|&b| b != 0), "Payload must not be all-zero deterministic filler"); + assert!( + payload.iter().any(|&b| b != 0), + "Payload must not be all-zero deterministic filler" + ); unique_payloads.insert(payload); } @@ -1846,7 +1872,13 @@ fn large_replay_window_does_not_expand_time_skew_acceptance() { #[test] fn parse_tls_record_header_accepts_tls_version_constant() { - let header = [TLS_RECORD_HANDSHAKE, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x2A]; + let header = [ + TLS_RECORD_HANDSHAKE, + TLS_VERSION[0], + TLS_VERSION[1], + 0x00, + 0x2A, + ]; let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted"); assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE); assert_eq!(parsed.1, 42); @@ -1868,7 +1900,10 @@ fn server_hello_clamps_fake_cert_len_lower_bound() { let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); - assert_eq!(app_len, 64, "fake cert payload must be clamped to minimum 64 bytes"); + assert_eq!( + app_len, 64, + "fake cert payload must be clamped to minimum 64 bytes" + ); } #[test] @@ -1887,7 +1922,10 @@ fn server_hello_clamps_fake_cert_len_upper_bound() { let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); - assert_eq!(app_len, MAX_TLS_CIPHERTEXT_SIZE, "fake cert payload must be clamped to TLS record max bound"); + assert_eq!( + app_len, MAX_TLS_CIPHERTEXT_SIZE, + "fake cert payload must be clamped to TLS record max bound" + ); } #[test] @@ -1898,7 +1936,15 @@ fn server_hello_new_session_ticket_count_matches_configuration() { let rng = crate::crypto::SecureRandom::new(); let tickets: u8 = 3; - let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, tickets); + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 1024, + &rng, + None, + tickets, + ); let mut pos = 0usize; let mut app_records = 0usize; @@ -1906,7 +1952,10 @@ fn server_hello_new_session_ticket_count_matches_configuration() { let rtype = response[pos]; let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; let next = pos + 5 + rlen; - assert!(next <= response.len(), "TLS record must stay inside response bounds"); + assert!( + next <= response.len(), + "TLS record must stay inside response bounds" + ); if rtype == TLS_RECORD_APPLICATION { app_records += 1; } @@ -1927,7 +1976,15 @@ fn server_hello_new_session_ticket_count_is_safely_capped() { let session_id = vec![0x54; 32]; let rng = crate::crypto::SecureRandom::new(); - let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, u8::MAX); + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 1024, + &rng, + None, + u8::MAX, + ); let mut pos = 0usize; let mut app_records = 0usize; @@ -1935,7 +1992,10 @@ fn server_hello_new_session_ticket_count_is_safely_capped() { let rtype = response[pos]; let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; let next = pos + 5 + rlen; - assert!(next <= response.len(), "TLS record must stay inside response bounds"); + assert!( + next <= response.len(), + "TLS record must stay inside response bounds" + ); if rtype == TLS_RECORD_APPLICATION { app_records += 1; } @@ -1943,8 +2003,7 @@ fn server_hello_new_session_ticket_count_is_safely_capped() { } assert_eq!( - app_records, - 5, + app_records, 5, "response must cap ticket-like tail records to four plus one main application record" ); } @@ -1972,10 +2031,14 @@ fn boot_time_handshake_replay_remains_blocked_after_cache_window_expires() { std::thread::sleep(std::time::Duration::from_millis(70)); - let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) - .expect("boot-time handshake must still cryptographically validate after cache expiry"); + let validation_after_expiry = + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) + .expect("boot-time handshake must still cryptographically validate after cache expiry"); let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN]; - assert_eq!(digest_half, digest_half_after_expiry, "replay key must be stable for same handshake"); + assert_eq!( + digest_half, digest_half_after_expiry, + "replay key must be stable for same handshake" + ); assert!( checker.check_and_add_tls_digest(digest_half_after_expiry), @@ -2006,8 +2069,9 @@ fn adversarial_boot_time_handshake_should_not_be_replayable_after_cache_expiry() std::thread::sleep(std::time::Duration::from_millis(70)); - let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) - .expect("boot-time handshake still validates cryptographically after cache expiry"); + let validation_after_expiry = + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) + .expect("boot-time handshake still validates cryptographically after cache expiry"); let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN]; assert_eq!( @@ -2067,11 +2131,14 @@ fn light_fuzz_boot_time_timestamp_matrix_with_short_replay_window_obeys_boot_cap let ts = (s as u32) % 8; let handshake = make_valid_tls_handshake(secret, ts); - let accepted = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) - .is_some(); + let accepted = + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2).is_some(); if ts < 2 { - assert!(accepted, "timestamp {ts} must remain boot-time compatible under 2s cap"); + assert!( + accepted, + "timestamp {ts} must remain boot-time compatible under 2s cap" + ); } else { assert!( !accepted, @@ -2107,7 +2174,9 @@ fn server_hello_application_data_contains_alpn_marker_when_selected() { let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2']; assert!( - app_payload.windows(expected.len()).any(|window| window == expected), + app_payload + .windows(expected.len()) + .any(|window| window == expected), "first application payload must carry ALPN marker for selected protocol" ); } @@ -2137,7 +2206,10 @@ fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() { let rtype = response[pos]; let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; let next = pos + 5 + rlen; - assert!(next <= response.len(), "TLS record must stay inside response bounds"); + assert!( + next <= response.len(), + "TLS record must stay inside response bounds" + ); if rtype == TLS_RECORD_APPLICATION { app_records += 1; if first_app_payload.is_none() { @@ -2146,7 +2218,9 @@ fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() { } pos = next; } - let marker = [0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x']; + let marker = [ + 0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x', + ]; assert_eq!( app_records, 5, @@ -2310,13 +2384,13 @@ fn light_fuzz_tls_header_classifier_and_parser_policy_consistency() { && header[1] == 0x03 && (header[2] == 0x01 || header[2] == 0x03); assert_eq!( - classified, - expected_classified, + classified, expected_classified, "classifier policy mismatch for header {header:02x?}" ); let parsed = parse_tls_record_header(&header); - let expected_parsed = header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]); + let expected_parsed = + header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]); assert_eq!( parsed.is_some(), expected_parsed, diff --git a/src/protocol/tests/tls_size_constants_security_tests.rs b/src/protocol/tests/tls_size_constants_security_tests.rs index 1389ab6..20e24c7 100644 --- a/src/protocol/tests/tls_size_constants_security_tests.rs +++ b/src/protocol/tests/tls_size_constants_security_tests.rs @@ -1,8 +1,4 @@ -use super::{ - MAX_TLS_CIPHERTEXT_SIZE, - MAX_TLS_PLAINTEXT_SIZE, - MIN_TLS_CLIENT_HELLO_SIZE, -}; +use super::{MAX_TLS_CIPHERTEXT_SIZE, MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; #[test] fn tls_size_constants_match_rfc_8446() { diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 9cac85e..82527ca 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -6,10 +6,10 @@ #![allow(dead_code)] -use crate::crypto::{sha256_hmac, SecureRandom}; +use super::constants::*; +use crate::crypto::{SecureRandom, sha256_hmac}; #[cfg(test)] use crate::error::ProxyError; -use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; use subtle::ConstantTimeEq; use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; @@ -31,7 +31,7 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16; /// Operators with known clock-drifted clients should tune deployment config /// (for example replay-window policy) to match their environment. pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before -pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after +pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after /// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced. pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60; /// Hard cap for boot-time compatibility bypass to avoid oversized acceptance @@ -69,7 +69,6 @@ pub struct TlsValidation { /// Client digest for response generation pub digest: [u8; TLS_DIGEST_LEN], /// Timestamp extracted from digest - pub timestamp: u32, } @@ -87,60 +86,63 @@ impl TlsExtensionBuilder { extensions: Vec::with_capacity(128), } } - + /// Add Key Share extension with X25519 key fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self { // Extension type: key_share (0x0033) - self.extensions.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes()); - + self.extensions + .extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes()); + // Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes // Extension data length let entry_len: u16 = 2 + 2 + 32; // curve + length + key self.extensions.extend_from_slice(&entry_len.to_be_bytes()); - + // Named curve: x25519 - self.extensions.extend_from_slice(&named_curve::X25519.to_be_bytes()); - + self.extensions + .extend_from_slice(&named_curve::X25519.to_be_bytes()); + // Key length self.extensions.extend_from_slice(&(32u16).to_be_bytes()); - + // Key data self.extensions.extend_from_slice(public_key); - + self } - + /// Add Supported Versions extension fn add_supported_versions(&mut self, version: u16) -> &mut Self { // Extension type: supported_versions (0x002b) - self.extensions.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes()); - + self.extensions + .extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes()); + // Extension data: length (2) + version (2) self.extensions.extend_from_slice(&(2u16).to_be_bytes()); - + // Selected version self.extensions.extend_from_slice(&version.to_be_bytes()); - + self } /// Build final extensions with length prefix - + fn build(self) -> Vec { let mut result = Vec::with_capacity(2 + self.extensions.len()); - + // Extensions length (2 bytes) let len = self.extensions.len() as u16; result.extend_from_slice(&len.to_be_bytes()); - + // Extensions data result.extend_from_slice(&self.extensions); - + result } - + /// Get current extensions without length prefix (for calculation) - + fn as_bytes(&self) -> &[u8] { &self.extensions } @@ -172,12 +174,12 @@ impl ServerHelloBuilder { extensions: TlsExtensionBuilder::new(), } } - + fn with_x25519_key(mut self, key: &[u8; 32]) -> Self { self.extensions.add_key_share(key); self } - + fn with_tls13_version(mut self) -> Self { // TLS 1.3 = 0x0304 self.extensions.add_supported_versions(0x0304); @@ -188,7 +190,7 @@ impl ServerHelloBuilder { fn build_message(&self) -> Vec { let extensions = self.extensions.extensions.clone(); let extensions_len = extensions.len() as u16; - + // Calculate total length let body_len = 2 + // version 32 + // random @@ -196,55 +198,55 @@ impl ServerHelloBuilder { 2 + // cipher suite 1 + // compression 2 + extensions.len(); // extensions length + data - + let mut message = Vec::with_capacity(4 + body_len); - + // Handshake header message.push(0x02); // ServerHello message type - + // 3-byte length let len_bytes = (body_len as u32).to_be_bytes(); message.extend_from_slice(&len_bytes[1..4]); - + // Server version (TLS 1.2 in header, actual version in extension) message.extend_from_slice(&TLS_VERSION); - + // Random (32 bytes) - placeholder, will be replaced with digest message.extend_from_slice(&self.random); - + // Session ID message.push(self.session_id.len() as u8); message.extend_from_slice(&self.session_id); - + // Cipher suite message.extend_from_slice(&self.cipher_suite); - + // Compression method message.push(self.compression); - + // Extensions length message.extend_from_slice(&extensions_len.to_be_bytes()); - + // Extensions data message.extend_from_slice(&extensions); - + message } - + /// Build complete ServerHello TLS record fn build_record(&self) -> Vec { let message = self.build_message(); - + let mut record = Vec::with_capacity(5 + message.len()); - + // TLS record header record.push(TLS_RECORD_HANDSHAKE); record.extend_from_slice(&TLS_VERSION); record.extend_from_slice(&(message.len() as u16).to_be_bytes()); - + // Message record.extend_from_slice(&message); - + record } } @@ -320,7 +322,6 @@ fn system_time_to_unix_secs(now: SystemTime) -> Option { i64::try_from(d.as_secs()).ok() } - fn validate_tls_handshake_at_time( handshake: &[u8], secrets: &[(String, Vec)], @@ -346,12 +347,12 @@ fn validate_tls_handshake_at_time_with_boot_cap( if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 { return None; } - + // Extract digest let digest: [u8; TLS_DIGEST_LEN] = handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] .try_into() .ok()?; - + // Extract session ID let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; let session_id_len = handshake.get(session_id_len_pos).copied()? as usize; @@ -359,17 +360,17 @@ fn validate_tls_handshake_at_time_with_boot_cap( return None; } let session_id_start = session_id_len_pos + 1; - + if handshake.len() < session_id_start + session_id_len { return None; } - + let session_id = handshake[session_id_start..session_id_start + session_id_len].to_vec(); - + // Build message for HMAC (with zeroed digest) let mut msg = handshake.to_vec(); msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); - + let mut first_match: Option<(&String, u32)> = None; for (user, secret) in secrets { @@ -408,7 +409,7 @@ fn validate_tls_handshake_at_time_with_boot_cap( } } } - + if first_match.is_none() { first_match = Some((user, timestamp)); } @@ -453,25 +454,30 @@ pub fn build_server_hello( const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE; let fake_cert_len = fake_cert_len.clamp(MIN_APP_DATA, MAX_APP_DATA); let x25519_key = gen_fake_x25519_key(rng); - + // Build ServerHello let server_hello = ServerHelloBuilder::new(session_id.to_vec()) .with_x25519_key(&x25519_key) .with_tls13_version() .build_record(); - + // Build Change Cipher Spec record let change_cipher_spec = [ TLS_RECORD_CHANGE_CIPHER, - TLS_VERSION[0], TLS_VERSION[1], - 0x00, 0x01, // length = 1 - 0x01, // CCS byte + TLS_VERSION[0], + TLS_VERSION[1], + 0x00, + 0x01, // length = 1 + 0x01, // CCS byte ]; - + // Build first encrypted flight mimic as opaque ApplicationData bytes. // Embed a compact EncryptedExtensions-like ALPN block when selected. let mut fake_cert = Vec::with_capacity(fake_cert_len); - if let Some(proto) = alpn.as_ref().filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) { + if let Some(proto) = alpn + .as_ref() + .filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) + { let proto_list_len = 1usize + proto.len(); let ext_data_len = 2usize + proto_list_len; let marker_len = 4usize + ext_data_len; @@ -496,7 +502,7 @@ pub fn build_server_hello( // Fill ApplicationData with fully random bytes of desired length to avoid // deterministic DPI fingerprints (fixed inner content type markers). app_data_record.extend_from_slice(&fake_cert); - + // Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted; // here we mimic with opaque ApplicationData records of plausible size). let mut tickets = Vec::new(); @@ -515,7 +521,10 @@ pub fn build_server_hello( // Combine all records let mut response = Vec::with_capacity( - server_hello.len() + change_cipher_spec.len() + app_data_record.len() + tickets.iter().map(|r| r.len()).sum::() + server_hello.len() + + change_cipher_spec.len() + + app_data_record.len() + + tickets.iter().map(|r| r.len()).sum::(), ); response.extend_from_slice(&server_hello); response.extend_from_slice(&change_cipher_spec); @@ -523,18 +532,17 @@ pub fn build_server_hello( for t in &tickets { response.extend_from_slice(t); } - + // Compute HMAC for the response let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len()); hmac_input.extend_from_slice(client_digest); hmac_input.extend_from_slice(&response); let response_digest = sha256_hmac(secret, &hmac_input); - + // Insert computed digest into ServerHello // Position: record header (5) + message type (1) + length (3) + version (2) = 11 - response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] - .copy_from_slice(&response_digest); - + response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&response_digest); + response } @@ -611,12 +619,14 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { let sn_end = std::cmp::min(sn_pos + list_len, pos + elen); while sn_pos + 3 <= sn_end { let name_type = handshake[sn_pos]; - let name_len = u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize; + let name_len = + u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize; sn_pos += 3; if sn_pos + name_len > sn_end { break; } - if name_type == 0 && name_len > 0 + if name_type == 0 + && name_len > 0 && let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) { if is_valid_sni_hostname(host) { @@ -679,35 +689,49 @@ pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec> { } pos += 4; // type + len pos += 2 + 32; // version + random - if pos >= handshake.len() { return Vec::new(); } + if pos >= handshake.len() { + return Vec::new(); + } let session_id_len = *handshake.get(pos).unwrap_or(&0) as usize; pos += 1 + session_id_len; - if pos + 2 > handshake.len() { return Vec::new(); } - let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize; + if pos + 2 > handshake.len() { + return Vec::new(); + } + let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; pos += 2 + cipher_len; - if pos >= handshake.len() { return Vec::new(); } + if pos >= handshake.len() { + return Vec::new(); + } let comp_len = *handshake.get(pos).unwrap_or(&0) as usize; pos += 1 + comp_len; - if pos + 2 > handshake.len() { return Vec::new(); } - let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize; + if pos + 2 > handshake.len() { + return Vec::new(); + } + let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; pos += 2; let ext_end = pos + ext_len; - if ext_end > handshake.len() { return Vec::new(); } + if ext_end > handshake.len() { + return Vec::new(); + } let mut out = Vec::new(); while pos + 4 <= ext_end { - let etype = u16::from_be_bytes([handshake[pos], handshake[pos+1]]); - let elen = u16::from_be_bytes([handshake[pos+2], handshake[pos+3]]) as usize; + let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]); + let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize; pos += 4; - if pos + elen > ext_end { break; } + if pos + elen > ext_end { + break; + } if etype == extension_type::ALPN && elen >= 3 { - let list_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize; + let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; let mut lp = pos + 2; let list_end = (pos + 2).saturating_add(list_len).min(pos + elen); while lp < list_end { let plen = handshake[lp] as usize; lp += 1; - if lp + plen > list_end { break; } - out.push(handshake[lp..lp+plen].to_vec()); + if lp + plen > list_end { + break; + } + out.push(handshake[lp..lp + plen].to_vec()); lp += plen; } break; @@ -717,16 +741,15 @@ pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec> { out } - /// Check if bytes look like a TLS ClientHello pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { if first_bytes.len() < 3 { return false; } - + // TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303. - first_bytes[0] == TLS_RECORD_HANDSHAKE - && first_bytes[1] == 0x03 + first_bytes[0] == TLS_RECORD_HANDSHAKE + && first_bytes[1] == 0x03 && (first_bytes[2] == 0x01 || first_bytes[2] == 0x03) } @@ -735,12 +758,12 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> { let record_type = header[0]; let version = [header[1], header[2]]; - + // We accept both TLS 1.0 header (for ClientHello) and TLS 1.2/1.3 if version != [0x03, 0x01] && version != TLS_VERSION { return None; } - + let length = u16::from_be_bytes([header[3], header[4]]); Some((record_type, length)) } @@ -756,7 +779,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> { version: [0, 0], }); } - + // Check record header if data[0] != TLS_RECORD_HANDSHAKE { return Err(ProxyError::InvalidTlsRecord { @@ -764,7 +787,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> { version: [data[1], data[2]], }); } - + // Check version if data[1..3] != TLS_VERSION { return Err(ProxyError::InvalidTlsRecord { @@ -772,31 +795,34 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> { version: [data[1], data[2]], }); } - + // Check record length let record_len = u16::from_be_bytes([data[3], data[4]]) as usize; if data.len() < 5 + record_len { - return Err(ProxyError::InvalidHandshake( - format!("ServerHello record truncated: expected {}, got {}", - 5 + record_len, data.len()) - )); + return Err(ProxyError::InvalidHandshake(format!( + "ServerHello record truncated: expected {}, got {}", + 5 + record_len, + data.len() + ))); } - + // Check message type if data[5] != 0x02 { - return Err(ProxyError::InvalidHandshake( - format!("Expected ServerHello (0x02), got 0x{:02x}", data[5]) - )); + return Err(ProxyError::InvalidHandshake(format!( + "Expected ServerHello (0x02), got 0x{:02x}", + data[5] + ))); } - + // Parse message length let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize; if msg_len + 4 != record_len { - return Err(ProxyError::InvalidHandshake( - format!("Message length mismatch: {} + 4 != {}", msg_len, record_len) - )); + return Err(ProxyError::InvalidHandshake(format!( + "Message length mismatch: {} + 4 != {}", + msg_len, record_len + ))); } - + Ok(()) } @@ -806,7 +832,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> { /// Using `static_assertions` ensures these can never silently break across /// refactors without a compile error. mod compile_time_security_checks { - use super::{TLS_DIGEST_LEN, TLS_DIGEST_HALF_LEN}; + use super::{TLS_DIGEST_HALF_LEN, TLS_DIGEST_LEN}; use static_assertions::const_assert; // The digest must be exactly one SHA-256 output. diff --git a/src/proxy/adaptive_buffers.rs b/src/proxy/adaptive_buffers.rs index 3b1bce9..bb61858 100644 --- a/src/proxy/adaptive_buffers.rs +++ b/src/proxy/adaptive_buffers.rs @@ -170,7 +170,8 @@ impl SessionAdaptiveController { return self.promote(TierTransitionReason::SoftConfirmed, 0); } - let demote_candidate = self.throughput_ema_bps < THROUGHPUT_DOWN_BPS && !tier2_now && !hard_now; + let demote_candidate = + self.throughput_ema_bps < THROUGHPUT_DOWN_BPS && !tier2_now && !hard_now; if demote_candidate { self.quiet_ticks = self.quiet_ticks.saturating_add(1); if self.quiet_ticks >= QUIET_DEMOTE_TICKS { @@ -253,10 +254,7 @@ pub fn record_user_tier(user: &str, tier: AdaptiveTier) { }; return; } - profiles().insert( - user.to_string(), - UserAdaptiveProfile { tier, seen_at: now }, - ); + profiles().insert(user.to_string(), UserAdaptiveProfile { tier, seen_at: now }); } pub fn direct_copy_buffers_for_tier( @@ -339,10 +337,7 @@ mod tests { sample( 300_000, // ~9.6 Mbps 320_000, // incoming > outgoing to confirm tier2 - 250_000, - 10, - 0, - 0, + 250_000, 10, 0, 0, ), tick_secs, ); @@ -358,10 +353,7 @@ mod tests { fn test_hard_promotion_on_pending_pressure() { let mut ctrl = SessionAdaptiveController::new(AdaptiveTier::Base); let transition = ctrl - .observe( - sample(10_000, 20_000, 10_000, 4, 1, 3), - 0.25, - ) + .observe(sample(10_000, 20_000, 10_000, 4, 1, 3), 0.25) .expect("expected hard promotion"); assert_eq!(transition.reason, TierTransitionReason::HardPressure); assert_eq!(transition.to, AdaptiveTier::Tier1); diff --git a/src/proxy/client.rs b/src/proxy/client.rs index a68a8c2..d71fc36 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1,5 +1,7 @@ //! Client Handler +use ipnetwork::IpNetwork; +use rand::RngExt; use std::future::Future; use std::net::{IpAddr, SocketAddr}; use std::pin::Pin; @@ -7,8 +9,6 @@ use std::sync::Arc; use std::sync::OnceLock; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; -use ipnetwork::IpNetwork; -use rand::RngExt; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::net::TcpStream; use tokio::time::timeout; @@ -75,10 +75,10 @@ use crate::protocol::tls; use crate::stats::beobachten::BeobachtenStore; use crate::stats::{ReplayChecker, Stats}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; -use crate::transport::middle_proxy::MePool; -use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol}; -use crate::transport::socket::normalize_ip; use crate::tls_front::TlsFrontCache; +use crate::transport::middle_proxy::MePool; +use crate::transport::socket::normalize_ip; +use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol}; use crate::proxy::direct_relay::handle_via_direct; use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; @@ -128,7 +128,10 @@ fn tls_clienthello_len_in_bounds(tls_len: usize) -> bool { (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) } -async fn read_with_progress(reader: &mut R, mut buf: &mut [u8]) -> std::io::Result { +async fn read_with_progress( + reader: &mut R, + mut buf: &mut [u8], +) -> std::io::Result { let mut total = 0usize; while !buf.is_empty() { match reader.read(buf).await { @@ -271,10 +274,14 @@ where let mut local_addr = synthetic_local_addr(config.server.port); if proxy_protocol_enabled { - let proxy_header_timeout = Duration::from_millis( - config.server.proxy_protocol_header_timeout_ms.max(1), - ); - match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await { + let proxy_header_timeout = + Duration::from_millis(config.server.proxy_protocol_header_timeout_ms.max(1)); + match timeout( + proxy_header_timeout, + parse_proxy_protocol(&mut stream, peer), + ) + .await + { Ok(Ok(info)) => { if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs) { @@ -674,9 +681,8 @@ impl RunningClientHandler { let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; if self.proxy_protocol_enabled { - let proxy_header_timeout = Duration::from_millis( - self.config.server.proxy_protocol_header_timeout_ms.max(1), - ); + let proxy_header_timeout = + Duration::from_millis(self.config.server.proxy_protocol_header_timeout_ms.max(1)); match timeout( proxy_header_timeout, parse_proxy_protocol(&mut self.stream, self.peer), @@ -761,7 +767,11 @@ impl RunningClientHandler { } } - async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { + async fn handle_tls_client( + mut self, + first_bytes: [u8; 5], + local_addr: SocketAddr, + ) -> Result { let peer = self.peer; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; @@ -895,7 +905,8 @@ impl RunningClientHandler { } else { wrap_tls_application_record(&pending_plaintext) }; - let reader = tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader); + let reader = + tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader); stats.increment_connects_bad(); debug!( peer = %peer, @@ -933,7 +944,11 @@ impl RunningClientHandler { ))) } - async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { + async fn handle_direct_client( + mut self, + first_bytes: [u8; 5], + local_addr: SocketAddr, + ) -> Result { let peer = self.peer; if !self.config.general.modes.classic && !self.config.general.modes.secure { @@ -1035,22 +1050,21 @@ impl RunningClientHandler { { let user = success.user.clone(); - let user_limit_reservation = - match Self::acquire_user_connection_reservation_static( - &user, - &config, - stats.clone(), - peer_addr, - ip_tracker, - ) - .await - { - Ok(reservation) => reservation, - Err(e) => { - warn!(user = %user, error = %e, "User admission check failed"); - return Err(e); - } - }; + let user_limit_reservation = match Self::acquire_user_connection_reservation_static( + &user, + &config, + stats.clone(), + peer_addr, + ip_tracker, + ) + .await + { + Ok(reservation) => reservation, + Err(e) => { + warn!(user = %user, error = %e, "User admission check failed"); + return Err(e); + } + }; let route_snapshot = route_runtime.snapshot(); let session_id = rng.u64(); @@ -1134,7 +1148,11 @@ impl RunningClientHandler { }); } - let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64); + let limit = config + .access + .user_max_tcp_conns + .get(user) + .map(|v| *v as u64); if !stats.try_acquire_user_curr_connects(user, limit) { return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string(), diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 18cbda3..7b2572e 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -1,10 +1,10 @@ +use std::collections::HashSet; use std::ffi::OsString; use std::fs::OpenOptions; use std::io::Write; use std::net::SocketAddr; use std::path::{Component, Path, PathBuf}; use std::sync::Arc; -use std::collections::HashSet; use std::sync::{Mutex, OnceLock}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split}; @@ -25,11 +25,11 @@ use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; -#[cfg(unix)] -use std::os::unix::fs::OpenOptionsExt; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; +#[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd}; const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; @@ -160,7 +160,9 @@ fn open_unknown_dc_log_append(path: &Path) -> std::io::Result { } } -fn open_unknown_dc_log_append_anchored(path: &SanitizedUnknownDcLogPath) -> std::io::Result { +fn open_unknown_dc_log_append_anchored( + path: &SanitizedUnknownDcLogPath, +) -> std::io::Result { #[cfg(unix)] { let parent = OpenOptions::new() @@ -168,14 +170,23 @@ fn open_unknown_dc_log_append_anchored(path: &SanitizedUnknownDcLogPath) -> std: .custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC) .open(&path.allowed_parent)?; - let file_name = std::ffi::CString::new(path.file_name.as_os_str().as_bytes()) - .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "unknown DC log file name contains NUL byte"))?; + let file_name = + std::ffi::CString::new(path.file_name.as_os_str().as_bytes()).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "unknown DC log file name contains NUL byte", + ) + })?; let fd = unsafe { libc::openat( parent.as_raw_fd(), file_name.as_ptr(), - libc::O_CREAT | libc::O_APPEND | libc::O_WRONLY | libc::O_NOFOLLOW | libc::O_CLOEXEC, + libc::O_CREAT + | libc::O_APPEND + | libc::O_WRONLY + | libc::O_NOFOLLOW + | libc::O_CLOEXEC, 0o600, ) }; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 8be9075..f3e3727 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -2,29 +2,29 @@ #![allow(dead_code)] -use std::net::SocketAddr; +use dashmap::DashMap; +use dashmap::mapref::entry::Entry; use std::collections::HashSet; use std::collections::hash_map::RandomState; +use std::hash::{BuildHasher, Hash, Hasher}; +use std::net::SocketAddr; use std::net::{IpAddr, Ipv6Addr}; use std::sync::Arc; use std::sync::{Mutex, OnceLock}; -use std::hash::{BuildHasher, Hash, Hasher}; use std::time::{Duration, Instant}; -use dashmap::DashMap; -use dashmap::mapref::entry::Entry; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tracing::{debug, warn, trace}; +use tracing::{debug, trace, warn}; use zeroize::{Zeroize, Zeroizing}; -use crate::crypto::{sha256, AesCtr, SecureRandom}; -use rand::RngExt; +use crate::config::ProxyConfig; +use crate::crypto::{AesCtr, SecureRandom, sha256}; +use crate::error::{HandshakeResult, ProxyError}; use crate::protocol::constants::*; use crate::protocol::tls; -use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter}; -use crate::error::{ProxyError, HandshakeResult}; use crate::stats::ReplayChecker; -use crate::config::ProxyConfig; +use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::tls_front::{TlsFrontCache, emulator}; +use rand::RngExt; const ACCESS_SECRET_BYTES: usize = 16; static INVALID_SECRET_WARNED: OnceLock>> = OnceLock::new(); @@ -67,7 +67,8 @@ struct AuthProbeSaturationState { } static AUTH_PROBE_STATE: OnceLock> = OnceLock::new(); -static AUTH_PROBE_SATURATION_STATE: OnceLock>> = OnceLock::new(); +static AUTH_PROBE_SATURATION_STATE: OnceLock>> = + OnceLock::new(); static AUTH_PROBE_EVICTION_HASHER: OnceLock = OnceLock::new(); fn auth_probe_state_map() -> &'static DashMap { @@ -78,8 +79,8 @@ fn auth_probe_saturation_state() -> &'static Mutex std::sync::MutexGuard<'static, Option> { +fn auth_probe_saturation_state_lock() +-> std::sync::MutexGuard<'static, Option> { auth_probe_saturation_state() .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()) @@ -252,9 +253,7 @@ fn auth_probe_record_failure_with_state( match eviction_candidate { Some((_, current_fail, current_seen)) if fail_streak > current_fail - || (fail_streak == current_fail && last_seen >= current_seen) => - { - } + || (fail_streak == current_fail && last_seen >= current_seen) => {} _ => eviction_candidate = Some((key, fail_streak, last_seen)), } } @@ -284,9 +283,7 @@ fn auth_probe_record_failure_with_state( match eviction_candidate { Some((_, current_fail, current_seen)) if fail_streak > current_fail - || (fail_streak == current_fail && last_seen >= current_seen) => - { - } + || (fail_streak == current_fail && last_seen >= current_seen) => {} _ => eviction_candidate = Some((key, fail_streak, last_seen)), } if auth_probe_state_expired(entry.value(), now) { @@ -306,9 +303,7 @@ fn auth_probe_record_failure_with_state( match eviction_candidate { Some((_, current_fail, current_seen)) if fail_streak > current_fail - || (fail_streak == current_fail && last_seen >= current_seen) => - { - } + || (fail_streak == current_fail && last_seen >= current_seen) => {} _ => eviction_candidate = Some((key, fail_streak, last_seen)), } if auth_probe_state_expired(entry.value(), now) { @@ -539,13 +534,12 @@ pub struct HandshakeSuccess { /// Decryption key and IV (for reading from client) pub dec_key: [u8; 32], pub dec_iv: u128, - /// Encryption key and IV (for writing to client) + /// Encryption key and IV (for writing to client) pub enc_key: [u8; 32], pub enc_iv: u128, /// Client address pub peer: SocketAddr, /// Whether TLS was used - pub is_tls: bool, } @@ -603,7 +597,7 @@ where auth_probe_record_failure(peer.ip(), Instant::now()); maybe_apply_server_hello_delay(config).await; debug!( - peer = %peer, + peer = %peer, ignore_time_skew = config.access.ignore_time_skew, "TLS handshake validation failed - no matching user or time skew" ); @@ -769,7 +763,6 @@ where let decoded_users = decode_user_secrets(config, preferred_user); for (user, secret) in decoded_users { - let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; @@ -820,12 +813,12 @@ where let encryptor = AesCtr::new(&enc_key, enc_iv); -// Apply replay tracking only after successful authentication. - // - // This ordering prevents an attacker from producing invalid handshakes that - // still collide with a valid handshake's replay slot and thus evict a valid - // entry from the cache. We accept the cost of performing the full - // authentication check first to avoid poisoning the replay cache. + // Apply replay tracking only after successful authentication. + // + // This ordering prevents an attacker from producing invalid handshakes that + // still collide with a valid handshake's replay slot and thus evict a valid + // entry from the cache. We accept the cost of performing the full + // authentication check first to avoid poisoning the replay cache. if replay_checker.check_and_add_handshake(dec_prekey_iv) { auth_probe_record_failure(peer.ip(), Instant::now()); maybe_apply_server_hello_delay(config).await; @@ -872,7 +865,7 @@ where /// Generate nonce for Telegram connection pub fn generate_tg_nonce( - proto_tag: ProtoTag, + proto_tag: ProtoTag, dc_idx: i16, client_enc_key: &[u8; 32], client_enc_iv: u128, @@ -885,13 +878,19 @@ pub fn generate_tg_nonce( continue; }; - if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; } + if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { + continue; + } let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; - if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; } + if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { + continue; + } let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]]; - if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; } + if RESERVED_NONCE_CONTINUES.contains(&continue_four) { + continue; + } nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); // CRITICAL: write dc_idx so upstream DC knows where to route @@ -942,7 +941,7 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, A let dec_iv = u128::from_be_bytes(dec_iv_arr); let mut encryptor = AesCtr::new(&enc_key, enc_iv); - let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4 + let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4 let mut result = nonce[..PROTO_TAG_POS].to_vec(); result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index d647a3a..adbb3ad 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -1,19 +1,19 @@ //! Masking - forward unrecognized traffic to mask host -use std::str; -use std::net::SocketAddr; -use std::time::Duration; -use rand::{Rng, RngExt}; -use tokio::net::TcpStream; -#[cfg(unix)] -use tokio::net::UnixStream; -use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; -use tokio::time::{Instant, timeout}; -use tracing::debug; use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; +use rand::{Rng, RngExt}; +use std::net::SocketAddr; +use std::str; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; +use tokio::time::{Instant, timeout}; +use tracing::debug; #[cfg(not(test))] const MASK_TIMEOUT: Duration = Duration::from_secs(5); @@ -98,8 +98,7 @@ async fn maybe_write_shape_padding( cap: usize, above_cap_blur: bool, above_cap_blur_max_bytes: usize, -) -where +) where W: AsyncWrite + Unpin, { if !enabled { @@ -167,7 +166,10 @@ async fn consume_client_data_with_timeout(reader: R) where R: AsyncRead + Unpin, { - if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)).await.is_err() { + if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)) + .await + .is_err() + { debug!("Timed out while consuming client data on masking fallback path"); } } @@ -213,9 +215,12 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) { fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request if data.len() > 4 - && (data.starts_with(b"GET ") || data.starts_with(b"POST") || - data.starts_with(b"HEAD") || data.starts_with(b"PUT ") || - data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS")) + && (data.starts_with(b"GET ") + || data.starts_with(b"POST") + || data.starts_with(b"HEAD") + || data.starts_with(b"PUT ") + || data.starts_with(b"DELETE") + || data.starts_with(b"OPTIONS")) { return "HTTP"; } @@ -252,16 +257,12 @@ fn build_mask_proxy_header( ), _ => { let header = match (peer, local_addr) { - (SocketAddr::V4(src), SocketAddr::V4(dst)) => { - ProxyProtocolV1Builder::new() - .tcp4(src.into(), dst.into()) - .build() - } - (SocketAddr::V6(src), SocketAddr::V6(dst)) => { - ProxyProtocolV1Builder::new() - .tcp6(src.into(), dst.into()) - .build() - } + (SocketAddr::V4(src), SocketAddr::V4(dst)) => ProxyProtocolV1Builder::new() + .tcp4(src.into(), dst.into()) + .build(), + (SocketAddr::V6(src), SocketAddr::V6(dst)) => ProxyProtocolV1Builder::new() + .tcp6(src.into(), dst.into()) + .build(), _ => ProxyProtocolV1Builder::new().build(), }; Some(header) @@ -278,8 +279,7 @@ pub async fn handle_bad_client( local_addr: SocketAddr, config: &ProxyConfig, beobachten: &BeobachtenStore, -) -where +) where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, { @@ -311,8 +311,11 @@ where match connect_result { Ok(Ok(stream)) => { let (mask_read, mut mask_write) = stream.into_split(); - let proxy_header = - build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr); + let proxy_header = build_mask_proxy_header( + config.censorship.mask_proxy_protocol, + peer, + local_addr, + ); if let Some(header) = proxy_header { if !write_proxy_header_with_timeout(&mut mask_write, &header).await { wait_mask_outcome_budget(outcome_started, config).await; @@ -356,7 +359,10 @@ where return; } - let mask_host = config.censorship.mask_host.as_deref() + let mask_host = config + .censorship + .mask_host + .as_deref() .unwrap_or(&config.censorship.tls_domain); let mask_port = config.censorship.mask_port; @@ -435,8 +441,7 @@ async fn relay_to_mask( shape_bucket_cap_bytes: usize, shape_above_cap_blur: bool, shape_above_cap_blur_max_bytes: usize, -) -where +) where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, MR: AsyncRead + Unpin + Send + 'static, @@ -455,9 +460,8 @@ where let copied = copy_with_idle_timeout(&mut reader, &mut mask_write).await; let total_sent = initial_data.len().saturating_add(copied.total); - let should_shape = shape_hardening_enabled - && copied.ended_by_eof - && !initial_data.is_empty(); + let should_shape = + shape_hardening_enabled && copied.ended_by_eof && !initial_data.is_empty(); maybe_write_shape_padding( &mut mask_write, diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 2000977..21fda15 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -9,17 +9,17 @@ use std::time::{Duration, Instant}; use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex}; +use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch}; use tokio::time::timeout; use tracing::{debug, info, trace, warn}; use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; -use crate::protocol::constants::{*, secure_padding_len}; +use crate::protocol::constants::{secure_padding_len, *}; use crate::proxy::handshake::HandshakeSuccess; use crate::proxy::route_mode::{ - RelayRouteMode, RouteCutoverState, ROUTE_SWITCH_ERROR_MSG, affected_cutover_state, + ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; use crate::stats::Stats; @@ -503,8 +503,7 @@ fn report_desync_frame_too_large( ProxyError::Proxy(format!( "Frame too large: {len} (max {max_frame}), frames_ok={frame_counter}, conn_id={}, trace_id=0x{:016x}", - state.conn_id, - state.trace_id + state.conn_id, state.trace_id )) } @@ -629,11 +628,9 @@ where stats.increment_user_connects(&user); let _me_connection_lease = stats.acquire_me_connection_lease(); - if let Some(cutover) = affected_cutover_state( - &route_rx, - RelayRouteMode::Middle, - route_snapshot.generation, - ) { + if let Some(cutover) = + affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation) + { let delay = cutover_stagger_delay(session_id, cutover.generation); warn!( conn_id, @@ -695,15 +692,17 @@ where while let Some(cmd) = c2me_rx.recv().await { match cmd { C2MeCommand::Data { payload, flags } => { - me_pool_c2me.send_proxy_req( - conn_id, - success.dc_idx, - peer, - translated_local_addr, - payload.as_ref(), - flags, - effective_tag.as_deref(), - ).await?; + me_pool_c2me + .send_proxy_req( + conn_id, + success.dc_idx, + peer, + translated_local_addr, + payload.as_ref(), + flags, + effective_tag.as_deref(), + ) + .await?; sent_since_yield = sent_since_yield.saturating_add(1); if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) { sent_since_yield = 0; @@ -916,7 +915,11 @@ where let mut seen_pressure_seq = relay_pressure_event_seq(); loop { if relay_idle_policy.enabled - && maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen_pressure_seq, stats.as_ref()) + && maybe_evict_idle_candidate_on_pressure( + conn_id, + &mut seen_pressure_seq, + stats.as_ref(), + ) { info!( conn_id, @@ -931,11 +934,9 @@ where break; } - if let Some(cutover) = affected_cutover_state( - &route_rx, - RelayRouteMode::Middle, - route_snapshot.generation, - ) { + if let Some(cutover) = + affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation) + { let delay = cutover_stagger_delay(session_id, cutover.generation); warn!( conn_id, @@ -1102,7 +1103,8 @@ where return deadline; } - let downstream_at = session_started_at + Duration::from_millis(last_downstream_activity_ms); + let downstream_at = + session_started_at + Duration::from_millis(last_downstream_activity_ms); if downstream_at > idle_state.last_client_frame_at { let grace_deadline = downstream_at + idle_policy.grace_after_downstream_activity; if grace_deadline > deadline { @@ -1117,12 +1119,8 @@ where let timeout_window = if idle_policy.enabled { let now = Instant::now(); let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); - let hard_deadline = hard_deadline( - idle_policy, - idle_state, - session_started_at, - downstream_ms, - ); + let hard_deadline = + hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms); if now >= hard_deadline { clear_relay_idle_candidate(forensics.conn_id); stats.increment_relay_idle_hard_close_total(); @@ -1130,7 +1128,9 @@ where .saturating_duration_since(idle_state.last_client_frame_at) .as_secs(); let downstream_idle_secs = now - .saturating_duration_since(session_started_at + Duration::from_millis(downstream_ms)) + .saturating_duration_since( + session_started_at + Duration::from_millis(downstream_ms), + ) .as_secs(); warn!( trace_id = format_args!("0x{:016x}", forensics.trace_id), @@ -1204,7 +1204,9 @@ where Err(_) if !idle_policy.enabled => { return Err(ProxyError::Io(std::io::Error::new( std::io::ErrorKind::TimedOut, - format!("middle-relay client frame read timeout while reading {read_label}"), + format!( + "middle-relay client frame read timeout while reading {read_label}" + ), ))); } Err(_) => {} @@ -1470,15 +1472,8 @@ where user: user.to_string(), }); } - write_client_payload( - client_writer, - proto_tag, - flags, - &data, - rng, - frame_buf, - ) - .await?; + write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await?; bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); stats.add_user_octets_to(user, data.len() as u64); @@ -1489,15 +1484,8 @@ where }); } } else { - write_client_payload( - client_writer, - proto_tag, - flags, - &data, - rng, - frame_buf, - ) - .await?; + write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await?; bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); stats.add_user_octets_to(user, data.len() as u64); diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index ab840f6..3db6000 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -6,8 +6,8 @@ pub mod direct_relay; pub mod handshake; pub mod masking; pub mod middle_relay; -pub mod route_mode; pub mod relay; +pub mod route_mode; pub mod session_eviction; pub use client::ClientHandler; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 88a8bd5..2431ff4 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -51,21 +51,19 @@ //! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to` //! - `SharedCounters` (atomics) let the watchdog read stats without locking -use std::io; -use std::pin::Pin; -use std::sync::{Arc, Mutex, OnceLock}; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::task::{Context, Poll}; -use std::time::Duration; -use dashmap::DashMap; -use tokio::io::{ - AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes, -}; -use tokio::time::Instant; -use tracing::{debug, trace, warn}; use crate::error::{ProxyError, Result}; use crate::stats::Stats; use crate::stream::BufferPool; +use dashmap::DashMap; +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; +use std::task::{Context, Poll}; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; +use tokio::time::Instant; +use tracing::{debug, trace, warn}; // ============= Constants ============= @@ -251,7 +249,8 @@ impl StatsIo { impl Drop for StatsIo { fn drop(&mut self) { self.quota_read_retry_active.store(false, Ordering::Relaxed); - self.quota_write_retry_active.store(false, Ordering::Relaxed); + self.quota_write_retry_active + .store(false, Ordering::Relaxed); } } @@ -428,7 +427,9 @@ impl AsyncRead for StatsIo { } // C→S: client sent data - this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed); + this.counters + .c2s_bytes + .fetch_add(n as u64, Ordering::Relaxed); this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); @@ -467,7 +468,8 @@ impl AsyncWrite for StatsIo { match lock.try_lock() { Ok(guard) => { this.quota_write_wake_scheduled = false; - this.quota_write_retry_active.store(false, Ordering::Relaxed); + this.quota_write_retry_active + .store(false, Ordering::Relaxed); Some(guard) } Err(_) => { @@ -509,7 +511,9 @@ impl AsyncWrite for StatsIo { Poll::Ready(Ok(n)) => { if n > 0 { // S→C: data written to client - this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed); + this.counters + .s2c_bytes + .fetch_add(n as u64, Ordering::Relaxed); this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); @@ -786,4 +790,4 @@ mod relay_quota_waker_storm_adversarial_tests; #[cfg(test)] #[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] -mod relay_quota_wake_liveness_regression_tests; \ No newline at end of file +mod relay_quota_wake_liveness_regression_tests; diff --git a/src/proxy/route_mode.rs b/src/proxy/route_mode.rs index e2232d2..5aa7e91 100644 --- a/src/proxy/route_mode.rs +++ b/src/proxy/route_mode.rs @@ -119,9 +119,7 @@ pub(crate) fn affected_cutover_state( } pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duration { - let mut value = session_id - ^ generation.rotate_left(17) - ^ 0x9e37_79b9_7f4a_7c15; + let mut value = session_id ^ generation.rotate_left(17) ^ 0x9e37_79b9_7f4a_7c15; value ^= value >> 30; value = value.wrapping_mul(0xbf58_476d_1ce4_e5b9); value ^= value >> 27; diff --git a/src/proxy/tests/client_adversarial_tests.rs b/src/proxy/tests/client_adversarial_tests.rs index 0e780e3..5bc90bc 100644 --- a/src/proxy/tests/client_adversarial_tests.rs +++ b/src/proxy/tests/client_adversarial_tests.rs @@ -1,11 +1,11 @@ use super::*; use crate::config::ProxyConfig; -use crate::stats::Stats; -use crate::ip_tracker::UserIpTracker; use crate::error::ProxyError; +use crate::ip_tracker::UserIpTracker; +use crate::stats::Stats; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; // ------------------------------------------------------------------ // Priority 3: Massive Concurrency Stress (OWASP ASVS 5.1.6) @@ -15,13 +15,16 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; async fn client_stress_10k_connections_limit_strict() { let user = "stress-user"; let limit = 512; - + let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); - + let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), limit); - + config + .access + .user_max_tcp_conns + .insert(user.to_string(), limit); + let iterations = 1000; let mut tasks = Vec::new(); @@ -30,20 +33,18 @@ async fn client_stress_10k_connections_limit_strict() { let ip_tracker = Arc::clone(&ip_tracker); let config = config.clone(); let user_str = user.to_string(); - + tasks.push(tokio::spawn(async move { let peer = SocketAddr::new( IpAddr::V4(Ipv4Addr::new(127, 0, 0, (i % 254 + 1) as u8)), 10000 + (i % 1000) as u16, ); - + match RunningClientHandler::acquire_user_connection_reservation_static( - &user_str, - &config, - stats, - peer, - ip_tracker, - ).await { + &user_str, &config, stats, peer, ip_tracker, + ) + .await + { Ok(res) => Ok(res), Err(ProxyError::ConnectionLimitExceeded { .. }) => Err(()), Err(e) => panic!("Unexpected error: {:?}", e), @@ -67,15 +68,27 @@ async fn client_stress_10k_connections_limit_strict() { } assert_eq!(successes, limit, "Should allow exactly 'limit' connections"); - assert_eq!(failures, iterations - limit, "Should fail the rest with LimitExceeded"); + assert_eq!( + failures, + iterations - limit, + "Should fail the rest with LimitExceeded" + ); assert_eq!(stats.get_user_curr_connects(user), limit as u64); drop(reservations); - + ip_tracker.drain_cleanup_queue().await; - - assert_eq!(stats.get_user_curr_connects(user), 0, "Stats must converge to 0 after all drops"); - assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP tracker must converge to 0"); + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "Stats must converge to 0 after all drops" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "IP tracker must converge to 0" + ); } // ------------------------------------------------------------------ @@ -87,14 +100,14 @@ async fn client_ip_tracker_race_condition_stress() { let user = "race-user"; let ip_tracker = Arc::new(UserIpTracker::new()); ip_tracker.set_user_limit(user, 100).await; - + let iterations = 1000; let mut tasks = Vec::new(); for i in 0..iterations { let ip_tracker = Arc::clone(&ip_tracker); let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 254 + 1) as u8)); - + tasks.push(tokio::spawn(async move { for _ in 0..10 { if let Ok(()) = ip_tracker.check_and_add("race-user", ip).await { @@ -105,8 +118,12 @@ async fn client_ip_tracker_race_condition_stress() { } futures::future::join_all(tasks).await; - - assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP count must be zero after balanced add/remove burst"); + + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "IP count must be zero after balanced add/remove burst" + ); } #[tokio::test] @@ -119,7 +136,10 @@ async fn client_limit_burst_peak_never_exceeds_cap() { let ip_tracker = Arc::new(UserIpTracker::new()); let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), limit); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), limit); let peak = Arc::new(AtomicU64::new(0)); let mut tasks = Vec::with_capacity(attempts); @@ -207,10 +227,10 @@ async fn client_expiration_rejection_never_mutates_live_counters() { let ip_tracker = Arc::new(UserIpTracker::new()); let mut config = ProxyConfig::default(); - config - .access - .user_expirations - .insert(user.to_string(), chrono::Utc::now() - chrono::Duration::seconds(1)); + config.access.user_expirations.insert( + user.to_string(), + chrono::Utc::now() - chrono::Duration::seconds(1), + ); let peer: SocketAddr = "198.51.100.202:31112".parse().unwrap(); let res = RunningClientHandler::acquire_user_connection_reservation_static( @@ -235,7 +255,10 @@ async fn client_ip_limit_failure_rolls_back_counter_exactly() { ip_tracker.set_user_limit(user, 1).await; let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), 16); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 16); let first_peer: SocketAddr = "198.51.100.203:31113".parse().unwrap(); let first = RunningClientHandler::acquire_user_connection_reservation_static( @@ -258,7 +281,10 @@ async fn client_ip_limit_failure_rolls_back_counter_exactly() { ) .await; - assert!(matches!(second, Err(ProxyError::ConnectionLimitExceeded { .. }))); + assert!(matches!( + second, + Err(ProxyError::ConnectionLimitExceeded { .. }) + )); assert_eq!(stats.get_user_curr_connects(user), 1); drop(first); @@ -276,7 +302,10 @@ async fn client_parallel_limit_checks_success_path_leaves_no_residue() { ip_tracker.set_user_limit(user, 128).await; let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), 128); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 128); let mut tasks = Vec::new(); for i in 0..128u16 { @@ -310,7 +339,10 @@ async fn client_parallel_limit_checks_failure_path_leaves_no_residue() { ip_tracker.set_user_limit(user, 0).await; let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), 512); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 512); let mut tasks = Vec::new(); for i in 0..64u16 { @@ -319,7 +351,10 @@ async fn client_parallel_limit_checks_failure_path_leaves_no_residue() { let config = config.clone(); tasks.push(tokio::spawn(async move { - let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 16, 0, (i % 250 + 1) as u8)), 33000 + i); + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(172, 16, 0, (i % 250 + 1) as u8)), + 33000 + i, + ); RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker) .await })); @@ -360,11 +395,7 @@ async fn client_churn_mixed_success_failure_converges_to_zero_state() { 34000 + (i % 32), ); let maybe_res = RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await; @@ -401,11 +432,7 @@ async fn client_same_ip_parallel_attempts_allow_at_most_one_when_limit_is_one() let config = config.clone(); tasks.push(tokio::spawn(async move { RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await })); @@ -424,7 +451,10 @@ async fn client_same_ip_parallel_attempts_allow_at_most_one_when_limit_is_one() } } - assert_eq!(granted, 1, "only one reservation may be granted for same IP with limit=1"); + assert_eq!( + granted, 1, + "only one reservation may be granted for same IP with limit=1" + ); drop(reservations); ip_tracker.drain_cleanup_queue().await; assert_eq!(stats.get_user_curr_connects(user), 0); @@ -439,7 +469,10 @@ async fn client_repeat_acquire_release_cycles_never_accumulate_state() { ip_tracker.set_user_limit(user, 32).await; let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), 32); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 32); for i in 0..500u16 { let peer = SocketAddr::new( @@ -484,11 +517,7 @@ async fn client_multi_user_isolation_under_parallel_limit_exhaustion() { 37000 + i, ); RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await })); @@ -497,7 +526,11 @@ async fn client_multi_user_isolation_under_parallel_limit_exhaustion() { let mut u1_success = 0usize; let mut u2_success = 0usize; let mut reservations = Vec::new(); - for (idx, result) in futures::future::join_all(tasks).await.into_iter().enumerate() { + for (idx, result) in futures::future::join_all(tasks) + .await + .into_iter() + .enumerate() + { let user = if idx % 2 == 0 { "u1" } else { "u2" }; match result.unwrap() { Ok(reservation) => { @@ -556,7 +589,10 @@ async fn client_limit_recovery_after_full_rejection_wave() { ip_tracker.clone(), ) .await; - assert!(matches!(denied, Err(ProxyError::ConnectionLimitExceeded { .. }))); + assert!(matches!( + denied, + Err(ProxyError::ConnectionLimitExceeded { .. }) + )); } drop(reservation); @@ -572,7 +608,10 @@ async fn client_limit_recovery_after_full_rejection_wave() { ip_tracker.clone(), ) .await; - assert!(recovered.is_ok(), "capacity must recover after prior holder drops"); + assert!( + recovered.is_ok(), + "capacity must recover after prior holder drops" + ); } #[tokio::test] @@ -619,7 +658,10 @@ async fn client_dual_limit_cross_product_never_leaks_on_reject() { ip_tracker.clone(), ) .await; - assert!(matches!(denied, Err(ProxyError::ConnectionLimitExceeded { .. }))); + assert!(matches!( + denied, + Err(ProxyError::ConnectionLimitExceeded { .. }) + )); } assert_eq!(stats.get_user_curr_connects(user), 2); @@ -637,7 +679,10 @@ async fn client_check_user_limits_concurrent_churn_no_counter_drift() { ip_tracker.set_user_limit(user, 64).await; let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), 64); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 64); let mut tasks = Vec::new(); for i in 0..512u16 { diff --git a/src/proxy/tests/client_masking_blackhat_campaign_tests.rs b/src/proxy/tests/client_masking_blackhat_campaign_tests.rs index 3ea9dae..88d4a58 100644 --- a/src/proxy/tests/client_masking_blackhat_campaign_tests.rs +++ b/src/proxy/tests/client_masking_blackhat_campaign_tests.rs @@ -2,17 +2,14 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; use crate::crypto::sha256_hmac; use crate::protocol::constants::{ - HANDSHAKE_LEN, - MAX_TLS_PLAINTEXT_SIZE, - MIN_TLS_CLIENT_HELLO_SIZE, - TLS_RECORD_APPLICATION, + HANDSHAKE_LEN, MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE, TLS_RECORD_APPLICATION, TLS_VERSION, }; use crate::protocol::tls; use std::collections::HashSet; use std::net::SocketAddr; use std::sync::Arc; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::{Duration, Instant}; @@ -79,7 +76,10 @@ fn build_mask_harness(secret_hex: &str, mask_port: u16) -> CampaignHarness { } fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { - assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); let total_len = 5 + tls_len; let mut handshake = vec![fill; total_len]; @@ -171,7 +171,10 @@ async fn run_tls_success_mtproto_fail_capture( client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; @@ -427,7 +430,10 @@ async fn blackhat_campaign_06_replayed_tls_hello_is_masked_without_serverhello() client_side.read_exact(&mut head).await.unwrap(); assert_eq!(head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, head).await; - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&first_tail).await.unwrap(); } else { let mut one = [0u8; 1]; @@ -697,13 +703,15 @@ async fn blackhat_campaign_12_parallel_tls_success_mtproto_fail_sessions_keep_is let mut tasks = Vec::new(); for i in 0..sessions { - let mut harness = build_mask_harness("abababababababababababababababab", backend_addr.port()); + let mut harness = + build_mask_harness("abababababababababababababababab", backend_addr.port()); let mut cfg = (*harness.config).clone(); cfg.censorship.mask_port = backend_addr.port(); harness.config = Arc::new(cfg); tasks.push(tokio::spawn(async move { let secret = [0xABu8; 16]; - let hello = make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x40 + (i as u8 % 10)); + let hello = + make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x40 + (i as u8 % 10)); let bad = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let tail = wrap_tls_application_data(&vec![i as u8; 8 + i]); let (server_side, mut client_side) = duplex(131072); @@ -843,12 +851,12 @@ async fn blackhat_campaign_15_light_fuzz_tls_lengths_and_fragmentation() { tls_len = MAX_TLS_PLAINTEXT_SIZE + 1 + (tls_len % 1024); } - let body_to_send = if (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) - { - (seed as usize % 29).min(tls_len.saturating_sub(1)) - } else { - 0 - }; + let body_to_send = + if (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) { + (seed as usize % 29).min(tls_len.saturating_sub(1)) + } else { + 0 + }; let mut probe = vec![0u8; 5 + body_to_send]; probe[0] = 0x16; @@ -856,7 +864,9 @@ async fn blackhat_campaign_15_light_fuzz_tls_lengths_and_fragmentation() { probe[2] = 0x01; probe[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); for b in &mut probe[5..] { - seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); + seed = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); *b = (seed >> 24) as u8; } @@ -879,7 +889,8 @@ async fn blackhat_campaign_16_mixed_probe_burst_stress_finishes_without_panics() probe[2] = 0x01; probe[3..5].copy_from_slice(&600u16.to_be_bytes()); probe[5..].fill((0x90 + i as u8) ^ 0x5A); - run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe.clone(), probe).await; + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe.clone(), probe) + .await; } else { let hdr = vec![0x16, 0x03, 0x01, 0xFF, i as u8]; run_invalid_tls_capture(Arc::new(ProxyConfig::default()), hdr.clone(), hdr).await; diff --git a/src/proxy/tests/client_masking_budget_security_tests.rs b/src/proxy/tests/client_masking_budget_security_tests.rs index 8dcf114..d98c780 100644 --- a/src/proxy/tests/client_masking_budget_security_tests.rs +++ b/src/proxy/tests/client_masking_budget_security_tests.rs @@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType}; use crate::crypto::sha256_hmac; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; use crate::protocol::tls; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::{Duration, Instant}; @@ -55,7 +55,10 @@ fn build_harness(config: ProxyConfig) -> PipelineHarness { } fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { - assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); let total_len = 5 + tls_len; let mut handshake = vec![fill; total_len]; @@ -150,7 +153,10 @@ async fn masking_runs_outside_handshake_timeout_budget_with_high_reject_delay() .unwrap() .unwrap(); - assert!(result.is_ok(), "bad-client fallback must not be canceled by handshake timeout"); + assert!( + result.is_ok(), + "bad-client fallback must not be canceled by handshake timeout" + ); assert_eq!( stats.get_handshake_timeouts(), 0, @@ -175,10 +181,10 @@ async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend( config.censorship.mask_port = backend_addr.port(); config.censorship.mask_proxy_protocol = 0; config.access.ignore_time_skew = true; - config - .access - .users - .insert("user".to_string(), "d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0".to_string()); + config.access.users.insert( + "user".to_string(), + "d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0".to_string(), + ); let harness = build_harness(config); @@ -194,8 +200,7 @@ async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend( let mut got = vec![0u8; expected_trailing.len()]; stream.read_exact(&mut got).await.unwrap(); assert_eq!( - got, - expected_trailing, + got, expected_trailing, "mask backend must receive only post-handshake trailing TLS records" ); }); @@ -223,11 +228,17 @@ async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend( client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) diff --git a/src/proxy/tests/client_masking_diagnostics_security_tests.rs b/src/proxy/tests/client_masking_diagnostics_security_tests.rs index 1d069c6..0d9ca99 100644 --- a/src/proxy/tests/client_masking_diagnostics_security_tests.rs +++ b/src/proxy/tests/client_masking_diagnostics_security_tests.rs @@ -1,7 +1,7 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; use std::sync::Arc; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::{Duration, Instant}; @@ -163,21 +163,36 @@ async fn diagnostic_timing_profiles_are_within_realistic_guardrails() { ); assert!(p50 >= 650, "p50 too low for delayed reject class={}", class); - assert!(p95 <= 1200, "p95 too high for delayed reject class={}", class); - assert!(max <= 1500, "max too high for delayed reject class={}", class); + assert!( + p95 <= 1200, + "p95 too high for delayed reject class={}", + class + ); + assert!( + max <= 1500, + "max too high for delayed reject class={}", + class + ); } } #[tokio::test] async fn diagnostic_forwarded_size_profiles_by_probe_class() { - let classes = [0usize, 1usize, 7usize, 17usize, 63usize, 511usize, 1023usize, 2047usize]; + let classes = [ + 0usize, 1usize, 7usize, 17usize, 63usize, 511usize, 1023usize, 2047usize, + ]; let mut observed = Vec::new(); for class in classes { let len = capture_forwarded_len(class).await; println!("diagnostic_shape class={} forwarded_len={}", class, len); observed.push(len as u128); - assert_eq!(len, 5 + class, "unexpected forwarded len for class={}", class); + assert_eq!( + len, + 5 + class, + "unexpected forwarded len for class={}", + class + ); } let p50 = percentile_ms(observed.clone(), 50, 100); diff --git a/src/proxy/tests/client_masking_hard_adversarial_tests.rs b/src/proxy/tests/client_masking_hard_adversarial_tests.rs index cdaede5..65e66d3 100644 --- a/src/proxy/tests/client_masking_hard_adversarial_tests.rs +++ b/src/proxy/tests/client_masking_hard_adversarial_tests.rs @@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType}; use crate::crypto::sha256_hmac; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION}; use crate::protocol::tls; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::{Duration, Instant}; @@ -70,7 +70,10 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> Harness { } fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { - assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); let total_len = 5 + tls_len; let mut handshake = vec![fill; total_len]; @@ -158,11 +161,17 @@ async fn run_tls_success_mtproto_fail_capture( client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); read_tls_record_body(&mut client_side, tls_response_head).await; - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); for record in trailing_records { client_side.write_all(&record).await.unwrap(); } @@ -330,7 +339,10 @@ async fn replayed_tls_hello_gets_no_serverhello_and_is_masked() { client_side.read_exact(&mut head).await.unwrap(); assert_eq!(head[0], 0x16); read_tls_record_body(&mut client_side, head).await; - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&first_tail).await.unwrap(); } else { let mut one = [0u8; 1]; @@ -402,7 +414,10 @@ async fn connects_bad_increments_once_per_invalid_mtproto() { let mut head = [0u8; 5]; client_side.read_exact(&mut head).await.unwrap(); read_tls_record_body(&mut client_side, head).await; - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&tail).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) @@ -625,7 +640,8 @@ async fn concurrent_tls_mtproto_fail_sessions_are_isolated() { for idx in 0..sessions { let secret_hex = "c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4"; let harness = build_harness(secret_hex, backend_addr.port()); - let hello = make_valid_tls_client_hello(&[0xC4; 16], 20 + idx as u32, 600, 0x40 + idx as u8); + let hello = + make_valid_tls_client_hello(&[0xC4; 16], 20 + idx as u32, 600, 0x40 + idx as u8); let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let trailing = wrap_tls_application_data(&vec![idx as u8; 32 + idx]); let peer: SocketAddr = format!("198.51.100.217:{}", 56100 + idx as u16) @@ -685,17 +701,67 @@ macro_rules! tail_length_case { *b = (i as u8).wrapping_mul(17).wrapping_add(5); } let record = wrap_tls_application_data(&payload); - let got = run_tls_success_mtproto_fail_capture($hex, $secret, $ts, vec![record.clone()]).await; + let got = + run_tls_success_mtproto_fail_capture($hex, $secret, $ts, vec![record.clone()]) + .await; assert_eq!(got, record); } }; } -tail_length_case!(tail_len_1_preserved, "d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1", [0xD1; 16], 30, 1); -tail_length_case!(tail_len_2_preserved, "d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2", [0xD2; 16], 31, 2); -tail_length_case!(tail_len_3_preserved, "d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", [0xD3; 16], 32, 3); -tail_length_case!(tail_len_7_preserved, "d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4", [0xD4; 16], 33, 7); -tail_length_case!(tail_len_31_preserved, "d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5", [0xD5; 16], 34, 31); -tail_length_case!(tail_len_127_preserved, "d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6", [0xD6; 16], 35, 127); -tail_length_case!(tail_len_511_preserved, "d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7", [0xD7; 16], 36, 511); -tail_length_case!(tail_len_1023_preserved, "d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8", [0xD8; 16], 37, 1023); +tail_length_case!( + tail_len_1_preserved, + "d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1", + [0xD1; 16], + 30, + 1 +); +tail_length_case!( + tail_len_2_preserved, + "d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2", + [0xD2; 16], + 31, + 2 +); +tail_length_case!( + tail_len_3_preserved, + "d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", + [0xD3; 16], + 32, + 3 +); +tail_length_case!( + tail_len_7_preserved, + "d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4", + [0xD4; 16], + 33, + 7 +); +tail_length_case!( + tail_len_31_preserved, + "d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5", + [0xD5; 16], + 34, + 31 +); +tail_length_case!( + tail_len_127_preserved, + "d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6", + [0xD6; 16], + 35, + 127 +); +tail_length_case!( + tail_len_511_preserved, + "d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7", + [0xD7; 16], + 36, + 511 +); +tail_length_case!( + tail_len_1023_preserved, + "d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8", + [0xD8; 16], + 37, + 1023 +); diff --git a/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs b/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs index 1208071..f7229ce 100644 --- a/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs +++ b/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs @@ -5,7 +5,7 @@ use rand::{Rng, SeedableRng}; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::{TcpListener, TcpStream}; const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; @@ -92,10 +92,13 @@ async fn run_generic_probe_and_capture_prefix(payload: Vec, expected_prefix: client_side.shutdown().await.unwrap(); let mut observed = vec![0u8; REPLY_404.len()]; - tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) - .await - .unwrap() - .unwrap(); + tokio::time::timeout( + Duration::from_secs(2), + client_side.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); assert_eq!(observed, REPLY_404); let got = tokio::time::timeout(Duration::from_secs(2), accept_task) @@ -264,7 +267,8 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() { let mut expected = std::collections::HashSet::new(); for idx in 0..session_count { - let probe = format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); + let probe = + format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); expected.insert(probe); } @@ -274,9 +278,15 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() { let (mut stream, _) = listener.accept().await.unwrap(); let head = read_http_probe_header(&mut stream).await; stream.write_all(REPLY_404).await.unwrap(); - assert!(remaining.remove(&head), "backend received unexpected or duplicated probe prefix"); + assert!( + remaining.remove(&head), + "backend received unexpected or duplicated probe prefix" + ); } - assert!(remaining.is_empty(), "all session prefixes must be observed exactly once"); + assert!( + remaining.is_empty(), + "all session prefixes must be observed exactly once" + ); }); let mut tasks = Vec::with_capacity(session_count); @@ -291,7 +301,8 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() { let ip_tracker = Arc::new(UserIpTracker::new()); let beobachten = Arc::new(BeobachtenStore::new()); - let probe = format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); + let probe = + format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); let peer: SocketAddr = format!("203.0.113.{}:{}", 30 + idx, 56000 + idx) .parse() .unwrap(); @@ -319,10 +330,13 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() { client_side.shutdown().await.unwrap(); let mut observed = vec![0u8; REPLY_404.len()]; - tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) - .await - .unwrap() - .unwrap(); + tokio::time::timeout( + Duration::from_secs(2), + client_side.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); assert_eq!(observed, REPLY_404); let result = tokio::time::timeout(Duration::from_secs(2), handler) diff --git a/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs index 08d276d..50aa44c 100644 --- a/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs +++ b/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs @@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType}; use crate::crypto::sha256_hmac; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; use crate::protocol::tls; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::{Duration, Instant}; @@ -67,7 +67,10 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> RedTeamHarness { } fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { - assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); let total_len = 5 + tls_len; let mut handshake = vec![fill; total_len]; @@ -148,8 +151,14 @@ async fn run_tls_success_mtproto_fail_session( let mut body = vec![0u8; body_len]; client_side.read_exact(&mut body).await.unwrap(); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); - client_side.write_all(&wrap_tls_application_data(&tail)).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side + .write_all(&wrap_tls_application_data(&tail)) + .await + .unwrap(); let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -175,7 +184,10 @@ async fn redteam_01_backend_receives_no_data_after_mtproto_fail() { b"probe-a".to_vec(), ) .await; - assert!(forwarded.is_empty(), "backend unexpectedly received fallback bytes"); + assert!( + forwarded.is_empty(), + "backend unexpectedly received fallback bytes" + ); } #[tokio::test] @@ -188,7 +200,10 @@ async fn redteam_02_backend_must_never_receive_tls_records_after_mtproto_fail() b"probe-b".to_vec(), ) .await; - assert_ne!(forwarded[0], 0x17, "received TLS application record despite strict policy"); + assert_ne!( + forwarded[0], 0x17, + "received TLS application record despite strict policy" + ); } #[tokio::test] @@ -200,9 +215,10 @@ async fn redteam_03_masking_duration_must_be_less_than_1ms_when_backend_down() { cfg.censorship.mask_host = Some("127.0.0.1".to_string()); cfg.censorship.mask_port = 1; cfg.access.ignore_time_skew = true; - cfg.access - .users - .insert("user".to_string(), "acacacacacacacacacacacacacacacac".to_string()); + cfg.access.users.insert( + "user".to_string(), + "acacacacacacacacacacacacacacacac".to_string(), + ); let harness = RedTeamHarness { config: Arc::new(cfg), @@ -261,7 +277,10 @@ async fn redteam_03_masking_duration_must_be_less_than_1ms_when_backend_down() { .unwrap() .unwrap(); - assert!(started.elapsed() < Duration::from_millis(1), "fallback path took longer than 1ms"); + assert!( + started.elapsed() < Duration::from_millis(1), + "fallback path took longer than 1ms" + ); } macro_rules! redteam_tail_must_not_forward_case { @@ -283,18 +302,90 @@ macro_rules! redteam_tail_must_not_forward_case { }; } -redteam_tail_must_not_forward_case!(redteam_04_tail_len_1_not_forwarded, "adadadadadadadadadadadadadadadad", [0xAD; 16], 4, 1); -redteam_tail_must_not_forward_case!(redteam_05_tail_len_2_not_forwarded, "aeaeaeaeaeaeaeaeaeaeaeaeaeaeaeae", [0xAE; 16], 5, 2); -redteam_tail_must_not_forward_case!(redteam_06_tail_len_3_not_forwarded, "afafafafafafafafafafafafafafafaf", [0xAF; 16], 6, 3); -redteam_tail_must_not_forward_case!(redteam_07_tail_len_7_not_forwarded, "b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0", [0xB0; 16], 7, 7); -redteam_tail_must_not_forward_case!(redteam_08_tail_len_15_not_forwarded, "b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1", [0xB1; 16], 8, 15); -redteam_tail_must_not_forward_case!(redteam_09_tail_len_63_not_forwarded, "b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2", [0xB2; 16], 9, 63); -redteam_tail_must_not_forward_case!(redteam_10_tail_len_127_not_forwarded, "b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3", [0xB3; 16], 10, 127); -redteam_tail_must_not_forward_case!(redteam_11_tail_len_255_not_forwarded, "b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4", [0xB4; 16], 11, 255); -redteam_tail_must_not_forward_case!(redteam_12_tail_len_511_not_forwarded, "b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5", [0xB5; 16], 12, 511); -redteam_tail_must_not_forward_case!(redteam_13_tail_len_1023_not_forwarded, "b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6", [0xB6; 16], 13, 1023); -redteam_tail_must_not_forward_case!(redteam_14_tail_len_2047_not_forwarded, "b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7", [0xB7; 16], 14, 2047); -redteam_tail_must_not_forward_case!(redteam_15_tail_len_4095_not_forwarded, "b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8", [0xB8; 16], 15, 4095); +redteam_tail_must_not_forward_case!( + redteam_04_tail_len_1_not_forwarded, + "adadadadadadadadadadadadadadadad", + [0xAD; 16], + 4, + 1 +); +redteam_tail_must_not_forward_case!( + redteam_05_tail_len_2_not_forwarded, + "aeaeaeaeaeaeaeaeaeaeaeaeaeaeaeae", + [0xAE; 16], + 5, + 2 +); +redteam_tail_must_not_forward_case!( + redteam_06_tail_len_3_not_forwarded, + "afafafafafafafafafafafafafafafaf", + [0xAF; 16], + 6, + 3 +); +redteam_tail_must_not_forward_case!( + redteam_07_tail_len_7_not_forwarded, + "b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0", + [0xB0; 16], + 7, + 7 +); +redteam_tail_must_not_forward_case!( + redteam_08_tail_len_15_not_forwarded, + "b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1", + [0xB1; 16], + 8, + 15 +); +redteam_tail_must_not_forward_case!( + redteam_09_tail_len_63_not_forwarded, + "b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2", + [0xB2; 16], + 9, + 63 +); +redteam_tail_must_not_forward_case!( + redteam_10_tail_len_127_not_forwarded, + "b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3", + [0xB3; 16], + 10, + 127 +); +redteam_tail_must_not_forward_case!( + redteam_11_tail_len_255_not_forwarded, + "b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4", + [0xB4; 16], + 11, + 255 +); +redteam_tail_must_not_forward_case!( + redteam_12_tail_len_511_not_forwarded, + "b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5", + [0xB5; 16], + 12, + 511 +); +redteam_tail_must_not_forward_case!( + redteam_13_tail_len_1023_not_forwarded, + "b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6", + [0xB6; 16], + 13, + 1023 +); +redteam_tail_must_not_forward_case!( + redteam_14_tail_len_2047_not_forwarded, + "b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7", + [0xB7; 16], + 14, + 2047 +); +redteam_tail_must_not_forward_case!( + redteam_15_tail_len_4095_not_forwarded, + "b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8", + [0xB8; 16], + 15, + 4095 +); #[tokio::test] #[ignore = "red-team expected-fail: impossible indistinguishability envelope"] @@ -349,14 +440,13 @@ async fn redteam_16_timing_delta_between_paths_must_be_sub_1ms_under_concurrency let min = durations.iter().copied().min().unwrap(); let max = durations.iter().copied().max().unwrap(); - assert!(max - min <= Duration::from_millis(1), "timing spread too wide for strict anti-probing envelope"); + assert!( + max - min <= Duration::from_millis(1), + "timing spread too wide for strict anti-probing envelope" + ); } -async fn measure_invalid_probe_duration_ms( - delay_ms: u64, - tls_len: u16, - body_sent: usize, -) -> u128 { +async fn measure_invalid_probe_duration_ms(delay_ms: u64, tls_len: u16, body_sent: usize) -> u128 { let mut cfg = ProxyConfig::default(); cfg.general.beobachten = false; cfg.censorship.mask = true; @@ -501,7 +591,8 @@ macro_rules! redteam_timing_envelope_case { #[tokio::test] #[ignore = "red-team expected-fail: unrealistically tight reject timing envelope"] async fn $name() { - let elapsed_ms = measure_invalid_probe_duration_ms($delay_ms, $tls_len, $body_sent).await; + let elapsed_ms = + measure_invalid_probe_duration_ms($delay_ms, $tls_len, $body_sent).await; assert!( elapsed_ms <= $max_ms, "timing envelope violated: elapsed={}ms, max={}ms", @@ -519,11 +610,9 @@ macro_rules! redteam_constant_shape_case { async fn $name() { let got = capture_forwarded_probe_len($tls_len, $body_sent).await; assert_eq!( - got, - $expected_len, + got, $expected_len, "fingerprint shape mismatch: got={} expected={} (strict constant-shape model)", - got, - $expected_len + got, $expected_len ); } }; diff --git a/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs index 5b5344d..3a01a69 100644 --- a/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs +++ b/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs @@ -1,7 +1,7 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; use std::sync::Arc; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::Duration; @@ -172,7 +172,10 @@ async fn redteam_fuzz_01_hardened_output_length_correlation_should_be_below_0_2( let y_hard: Vec = hardened.iter().map(|v| *v as f64).collect(); let corr_hard = pearson_corr(&x, &y_hard).abs(); - println!("redteam_fuzz corr_hardened={corr_hard:.4} samples={}", sizes.len()); + println!( + "redteam_fuzz corr_hardened={corr_hard:.4} samples={}", + sizes.len() + ); assert!( corr_hard < 0.2, @@ -234,9 +237,7 @@ async fn redteam_fuzz_03_hardened_signal_must_be_10x_lower_than_plain() { let corr_plain = pearson_corr(&x, &y_plain).abs(); let corr_hard = pearson_corr(&x, &y_hard).abs(); - println!( - "redteam_fuzz corr_plain={corr_plain:.4} corr_hardened={corr_hard:.4}" - ); + println!("redteam_fuzz corr_plain={corr_plain:.4} corr_hardened={corr_hard:.4}"); assert!( corr_hard <= corr_plain * 0.1, diff --git a/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs index 6ce57b3..48e94a5 100644 --- a/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs +++ b/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs @@ -1,7 +1,7 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; use std::sync::Arc; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::Duration; diff --git a/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs index a835d00..f91e687 100644 --- a/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs +++ b/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs @@ -1,7 +1,7 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; use std::sync::Arc; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::{Duration, Instant}; @@ -164,10 +164,7 @@ async fn redteam_shape_02_padding_tail_must_be_non_deterministic() { let cap = 4096usize; let got = run_probe_capture(17, 600, true, floor, cap).await; - assert!( - got.len() > 22, - "test requires padding tail to exist" - ); + assert!(got.len() > 22, "test requires padding tail to exist"); let tail = &got[22..]; assert!( @@ -194,7 +191,9 @@ async fn redteam_shape_03_exact_floor_input_should_not_be_fixed_point() { async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() { let floor = 512usize; let cap = 4096usize; - let classes = [17usize, 63usize, 255usize, 511usize, 1023usize, 2047usize, 3071usize]; + let classes = [ + 17usize, 63usize, 255usize, 511usize, 1023usize, 2047usize, 3071usize, + ]; let mut observed = Vec::new(); for body in classes { @@ -203,7 +202,10 @@ async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() { let first = observed[0]; for v in observed { - assert_eq!(v, first, "strict model expects one collapsed class across all sub-cap probes"); + assert_eq!( + v, first, + "strict model expects one collapsed class across all sub-cap probes" + ); } } diff --git a/src/proxy/tests/client_masking_shape_hardening_security_tests.rs b/src/proxy/tests/client_masking_shape_hardening_security_tests.rs index f9c0f17..f2bec42 100644 --- a/src/proxy/tests/client_masking_shape_hardening_security_tests.rs +++ b/src/proxy/tests/client_masking_shape_hardening_security_tests.rs @@ -1,7 +1,7 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; use std::sync::Arc; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::Duration; diff --git a/src/proxy/tests/client_masking_stress_adversarial_tests.rs b/src/proxy/tests/client_masking_stress_adversarial_tests.rs index 52e7da1..5c00c63 100644 --- a/src/proxy/tests/client_masking_stress_adversarial_tests.rs +++ b/src/proxy/tests/client_masking_stress_adversarial_tests.rs @@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType}; use crate::crypto::sha256_hmac; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION}; use crate::protocol::tls; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::Duration; @@ -70,7 +70,10 @@ fn build_harness(mask_port: u16, secret_hex: &str) -> StressHarness { } fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { - assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); let total_len = 5 + tls_len; let mut handshake = vec![fill; total_len]; @@ -150,12 +153,8 @@ async fn run_parallel_tail_fallback_case( for idx in 0..sessions { let harness = build_harness(backend_addr.port(), "e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0"); - let hello = make_valid_tls_client_hello( - &[0xE0; 16], - ts_base + idx as u32, - 600, - 0x40 + (idx as u8), - ); + let hello = + make_valid_tls_client_hello(&[0xE0; 16], ts_base + idx as u32, 600, 0x40 + (idx as u8)); let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3]; @@ -170,8 +169,8 @@ async fn run_parallel_tail_fallback_case( peer_ip_fourth, peer_port_base + idx as u16 ) - .parse() - .unwrap(); + .parse() + .unwrap(); tasks.push(tokio::spawn(async move { let (server_side, mut client_side) = duplex(262144); @@ -194,7 +193,10 @@ async fn run_parallel_tail_fallback_case( client_side.write_all(&hello).await.unwrap(); let mut server_hello_head = [0u8; 5]; - client_side.read_exact(&mut server_hello_head).await.unwrap(); + client_side + .read_exact(&mut server_hello_head) + .await + .unwrap(); assert_eq!(server_hello_head[0], 0x16); read_tls_record_body(&mut client_side, server_hello_head).await; diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 98e3cd1..aed6bc4 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -8,7 +8,7 @@ use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; use std::net::Ipv4Addr; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::{TcpListener, TcpStream}; #[test] @@ -49,25 +49,33 @@ async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() { let stats = Arc::new(crate::stats::Stats::new()); let user = "sync-drop-user".to_string(); let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap(); - + ip_tracker.set_user_limit(&user, 1).await; ip_tracker.check_and_add(&user, ip).await.unwrap(); stats.increment_user_curr_connects(&user); - + assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1); assert_eq!(stats.get_user_curr_connects(&user), 1); - - let reservation = UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip); - + + let reservation = + UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip); + // Drop the reservation synchronously without any tokio::spawn/await yielding! drop(reservation); - + // The IP is now inside the cleanup_queue, check that the queue has length 1 let queue_len = ip_tracker.cleanup_queue.lock().unwrap().len(); - assert_eq!(queue_len, 1, "Reservation drop must push directly to synchronized IP queue"); - - assert_eq!(stats.get_user_curr_connects(&user), 0, "Stats must decrement immediately"); - + assert_eq!( + queue_len, 1, + "Reservation drop must push directly to synchronized IP queue" + ); + + assert_eq!( + stats.get_user_curr_connects(&user), + 0, + "Stats must decrement immediately" + ); + ip_tracker.drain_cleanup_queue().await; assert_eq!(ip_tracker.get_active_ip_count(&user).await, 0); } @@ -286,7 +294,10 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() { .await .expect("relay must terminate after cutover") .expect("relay task must not panic"); - assert!(relay_result.is_err(), "cutover must terminate direct relay session"); + assert!( + relay_result.is_err(), + "cutover must terminate direct relay session" + ); assert_eq!( stats.get_user_curr_connects(user), @@ -447,7 +458,12 @@ async fn stress_drop_without_release_converges_to_zero_user_and_ip_state() { let mut reservations = Vec::new(); for idx in 0..512u16 { let peer = std::net::SocketAddr::new( - std::net::IpAddr::V4(std::net::Ipv4Addr::new(198, 51, (idx >> 8) as u8, (idx & 0xff) as u8)), + std::net::IpAddr::V4(std::net::Ipv4Addr::new( + 198, + 51, + (idx >> 8) as u8, + (idx & 0xff) as u8, + )), 30_000 + idx, ); let reservation = RunningClientHandler::acquire_user_connection_reservation_static( @@ -510,10 +526,15 @@ async fn proxy_protocol_header_is_rejected_when_trust_list_is_empty() { false, stats.clone(), )); - let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(128, std::time::Duration::from_secs(60))); + let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new( + 128, + std::time::Duration::from_secs(60), + )); let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); - let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct)); + let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new( + crate::proxy::route_mode::RelayRouteMode::Direct, + )); let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); @@ -581,10 +602,16 @@ async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load( false, stats.clone(), )); - let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(64, std::time::Duration::from_secs(60))); + let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new( + 64, + std::time::Duration::from_secs(60), + )); let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); - let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct)); + let route_runtime = + std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new( + crate::proxy::route_mode::RelayRouteMode::Direct, + )); let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); @@ -669,8 +696,16 @@ async fn reservation_limit_failure_does_not_leak_curr_connects_counter() { matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user), "second reservation must be rejected at the configured tcp-conns limit" ); - assert_eq!(stats.get_user_curr_connects(user), 1, "failed acquisition must not leak a counter increment"); - assert_eq!(ip_tracker.get_active_ip_count(user).await, 1, "failed acquisition must not mutate IP tracker state"); + assert_eq!( + stats.get_user_curr_connects(user), + 1, + "failed acquisition must not leak a counter increment" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "failed acquisition must not mutate IP tracker state" + ); first.release().await; ip_tracker.drain_cleanup_queue().await; @@ -1119,7 +1154,10 @@ async fn partial_tls_header_stall_triggers_handshake_timeout() { } fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec { - assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); let total_len = 5 + tls_len; let mut handshake = vec![0x42u8; total_len]; @@ -1140,7 +1178,8 @@ fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: digest[28 + i] ^= ts[i]; } - handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); handshake } @@ -1203,8 +1242,7 @@ fn make_valid_tls_client_hello_with_alpn( digest[28 + i] ^= ts[i]; } - record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] - .copy_from_slice(&digest); + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); record } @@ -1233,9 +1271,10 @@ async fn valid_tls_path_does_not_fall_back_to_mask_backend() { cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_proxy_protocol = 0; cfg.access.ignore_time_skew = true; - cfg.access - .users - .insert("user".to_string(), "11111111111111111111111111111111".to_string()); + cfg.access.users.insert( + "user".to_string(), + "11111111111111111111111111111111".to_string(), + ); let config = Arc::new(cfg); let stats = Arc::new(Stats::new()); @@ -1307,8 +1346,7 @@ async fn valid_tls_path_does_not_fall_back_to_mask_backend() { let bad_after = stats_for_assert.get_connects_bad(); assert_eq!( - bad_before, - bad_after, + bad_before, bad_after, "Authenticated TLS path must not increment connects_bad" ); } @@ -1341,9 +1379,10 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() { cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_proxy_protocol = 0; cfg.access.ignore_time_skew = true; - cfg.access - .users - .insert("user".to_string(), "33333333333333333333333333333333".to_string()); + cfg.access.users.insert( + "user".to_string(), + "33333333333333333333333333333333".to_string(), + ); let config = Arc::new(cfg); let stats = Arc::new(Stats::new()); @@ -1394,7 +1433,10 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&tls_app_record).await.unwrap(); @@ -1443,9 +1485,10 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_proxy_protocol = 0; cfg.access.ignore_time_skew = true; - cfg.access - .users - .insert("user".to_string(), "44444444444444444444444444444444".to_string()); + cfg.access.users.insert( + "user".to_string(), + "44444444444444444444444444444444".to_string(), + ); let config = Arc::new(cfg); let stats = Arc::new(Stats::new()); @@ -1563,9 +1606,10 @@ async fn alpn_mismatch_tls_probe_is_masked_through_client_pipeline() { cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.alpn_enforce = true; cfg.access.ignore_time_skew = true; - cfg.access - .users - .insert("user".to_string(), "66666666666666666666666666666666".to_string()); + cfg.access.users.insert( + "user".to_string(), + "66666666666666666666666666666666".to_string(), + ); let config = Arc::new(cfg); let stats = Arc::new(Stats::new()); @@ -1654,9 +1698,10 @@ async fn invalid_hmac_tls_probe_is_masked_through_client_pipeline() { cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_proxy_protocol = 0; cfg.access.ignore_time_skew = true; - cfg.access - .users - .insert("user".to_string(), "77777777777777777777777777777777".to_string()); + cfg.access.users.insert( + "user".to_string(), + "77777777777777777777777777777777".to_string(), + ); let config = Arc::new(cfg); let stats = Arc::new(Stats::new()); @@ -1751,9 +1796,10 @@ async fn burst_invalid_tls_probes_are_masked_verbatim() { cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_proxy_protocol = 0; cfg.access.ignore_time_skew = true; - cfg.access - .users - .insert("user".to_string(), "88888888888888888888888888888888".to_string()); + cfg.access.users.insert( + "user".to_string(), + "88888888888888888888888888888888".to_string(), + ); let config = Arc::new(cfg); let stats = Arc::new(Stats::new()); @@ -1981,10 +2027,7 @@ async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() { let user = "check-helper-user"; let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 1); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); let stats = Stats::new(); let ip_tracker = UserIpTracker::new(); @@ -1998,7 +2041,10 @@ async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservatio &ip_tracker, ) .await; - assert!(first.is_ok(), "first check-only limit validation must succeed"); + assert!( + first.is_ok(), + "first check-only limit validation must succeed" + ); let second = RunningClientHandler::check_user_limits_static( user, @@ -2008,7 +2054,10 @@ async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservatio &ip_tracker, ) .await; - assert!(second.is_ok(), "second check-only validation must not fail from leaked state"); + assert!( + second.is_ok(), + "second check-only validation must not fail from leaked state" + ); assert_eq!(stats.get_user_curr_connects(user), 0); assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); } @@ -2017,10 +2066,7 @@ async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservatio async fn stress_check_user_limits_static_success_never_leaks_state() { let user = "check-helper-stress-user"; let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 1); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); let stats = Stats::new(); let ip_tracker = UserIpTracker::new(); @@ -2039,7 +2085,10 @@ async fn stress_check_user_limits_static_success_never_leaks_state() { &ip_tracker, ) .await; - assert!(result.is_ok(), "check-only helper must remain leak-free under stress"); + assert!( + result.is_ok(), + "check-only helper must remain leak-free under stress" + ); } assert_eq!( @@ -2090,11 +2139,7 @@ async fn concurrent_distinct_ip_rejections_rollback_user_counter_without_leak() 41000 + i as u16, ); let result = RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await; assert!(matches!( @@ -2130,10 +2175,7 @@ async fn explicit_reservation_release_cleans_user_and_ip_immediately() { let peer_addr: SocketAddr = "198.51.100.240:50002".parse().unwrap(); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 4); + config.access.user_max_tcp_conns.insert(user.to_string(), 4); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -2171,10 +2213,7 @@ async fn explicit_reservation_release_does_not_double_decrement_on_drop() { let peer_addr: SocketAddr = "198.51.100.241:50003".parse().unwrap(); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 4); + config.access.user_max_tcp_conns.insert(user.to_string(), 4); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -2204,10 +2243,7 @@ async fn drop_fallback_eventually_cleans_user_and_ip_reservation() { let peer_addr: SocketAddr = "198.51.100.242:50004".parse().unwrap(); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 4); + config.access.user_max_tcp_conns.insert(user.to_string(), 4); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -2248,10 +2284,7 @@ async fn explicit_release_allows_immediate_cross_ip_reacquire_under_limit() { let peer2: SocketAddr = "198.51.100.244:50006".parse().unwrap(); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 4); + config.access.user_max_tcp_conns.insert(user.to_string(), 4); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -2473,8 +2506,14 @@ async fn parallel_users_abort_release_isolation_preserves_independent_cleanup() let user_b = "abort-isolation-b"; let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user_a.to_string(), 64); - config.access.user_max_tcp_conns.insert(user_b.to_string(), 64); + config + .access + .user_max_tcp_conns + .insert(user_a.to_string(), 64); + config + .access + .user_max_tcp_conns + .insert(user_b.to_string(), 64); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -2595,10 +2634,7 @@ async fn relay_connect_error_releases_user_and_ip_before_return() { let ip_tracker = Arc::new(UserIpTracker::new()); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 1); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); config .dc_overrides .insert("2".to_string(), vec![format!("127.0.0.1:{dead_port}")]); @@ -2661,7 +2697,10 @@ async fn relay_connect_error_releases_user_and_ip_before_return() { ) .await; - assert!(result.is_err(), "relay must fail when upstream DC is unreachable"); + assert!( + result.is_err(), + "relay must fail when upstream DC is unreachable" + ); assert_eq!( stats.get_user_curr_connects(user), 0, @@ -2680,10 +2719,7 @@ async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() { let peer_addr: SocketAddr = "198.51.100.246:50008".parse().unwrap(); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 8); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -2743,10 +2779,7 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() { let peer_addr: SocketAddr = "198.51.100.247:50009".parse().unwrap(); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 8); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -2802,7 +2835,10 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() { #[tokio::test] async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { let mut config = ProxyConfig::default(); - config.access.user_data_quota.insert("user".to_string(), 1024); + config + .access + .user_data_quota + .insert("user".to_string(), 1024); let stats = Stats::new(); stats.add_user_octets_from("user", 1024); @@ -2838,10 +2874,10 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { #[tokio::test] async fn expired_user_rejection_does_not_reserve_ip_or_increment_curr_connects() { let mut config = ProxyConfig::default(); - config - .access - .user_expirations - .insert("user".to_string(), chrono::Utc::now() - chrono::Duration::seconds(1)); + config.access.user_expirations.insert( + "user".to_string(), + chrono::Utc::now() - chrono::Duration::seconds(1), + ); let stats = Stats::new(); let ip_tracker = UserIpTracker::new(); @@ -2870,10 +2906,7 @@ async fn same_ip_second_reservation_succeeds_under_unique_ip_limit_one() { let peer_addr: SocketAddr = "198.51.100.248:50010".parse().unwrap(); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 8); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -2914,10 +2947,7 @@ async fn second_distinct_ip_is_rejected_under_unique_ip_limit_one() { let peer2: SocketAddr = "198.51.100.250:50012".parse().unwrap(); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 8); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -2958,10 +2988,7 @@ async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() { let peer_addr: SocketAddr = "198.51.100.251:50013".parse().unwrap(); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 8); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -3005,10 +3032,7 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() { let peer_addr: SocketAddr = "198.51.100.252:50014".parse().unwrap(); let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 1); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -3043,11 +3067,7 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() { .expect("cross-thread cleanup must settle before reacquire check"); let reacquire = RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer_addr, - ip_tracker, + user, &config, stats, peer_addr, ip_tracker, ) .await; assert!( @@ -3113,10 +3133,7 @@ async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() { .get_recent_ips_for_users(&["user".to_string()]) .await; assert!( - recent - .get("user") - .map(|ips| ips.is_empty()) - .unwrap_or(true), + recent.get("user").map(|ips| ips.is_empty()).unwrap_or(true), "Concurrent rejected attempts must not leave recent IP footprint" ); @@ -3150,11 +3167,7 @@ async fn atomic_limit_gate_allows_only_one_concurrent_acquire() { 30000 + i, ); RunningClientHandler::acquire_user_connection_reservation_static( - "user", - &config, - stats, - peer, - ip_tracker, + "user", &config, stats, peer, ip_tracker, ) .await .ok() @@ -3769,9 +3782,10 @@ async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() { cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_proxy_protocol = 0; cfg.access.ignore_time_skew = true; - cfg.access - .users - .insert("user".to_string(), "55555555555555555555555555555555".to_string()); + cfg.access.users.insert( + "user".to_string(), + "55555555555555555555555555555555".to_string(), + ); let config = Arc::new(cfg); let stats = Arc::new(Stats::new()); @@ -3824,7 +3838,10 @@ async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() { client_side.write_all(&client_hello).await.unwrap(); let mut record_header = [0u8; 5]; client_side.read_exact(&mut record_header).await.unwrap(); - assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted"); + assert_eq!( + record_header[0], 0x16, + "Valid max-length ClientHello must be accepted" + ); drop(client_side); let handler_result = tokio::time::timeout(Duration::from_secs(3), handler) @@ -3865,9 +3882,10 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() { cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_proxy_protocol = 0; cfg.access.ignore_time_skew = true; - cfg.access - .users - .insert("user".to_string(), "66666666666666666666666666666666".to_string()); + cfg.access.users.insert( + "user".to_string(), + "66666666666666666666666666666666".to_string(), + ); let config = Arc::new(cfg); let stats = Arc::new(Stats::new()); @@ -3938,7 +3956,10 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() { let mut record_header = [0u8; 5]; client.read_exact(&mut record_header).await.unwrap(); - assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted"); + assert_eq!( + record_header[0], 0x16, + "Valid max-length ClientHello must be accepted" + ); drop(client); @@ -3947,7 +3968,8 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() { .unwrap() .unwrap(); - let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), mask_listener.accept()).await; + let no_mask_connect = + tokio::time::timeout(Duration::from_millis(250), mask_listener.accept()).await; assert!( no_mask_connect.is_err(), "Valid max-length ClientHello must not trigger mask fallback in ClientHandler path" @@ -4004,11 +4026,7 @@ async fn burst_acquire_distinct_ips( 55000 + i, ); RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await }); @@ -4190,11 +4208,7 @@ async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() { 54000 + i, ); RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await }); @@ -4228,10 +4242,7 @@ async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() { async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() { let user: &'static str = "scheduled-attack-user"; let mut config = ProxyConfig::default(); - config - .access - .user_max_tcp_conns - .insert(user.to_string(), 6); + config.access.user_max_tcp_conns.insert(user.to_string(), 6); let config = Arc::new(config); let stats = Arc::new(Stats::new()); @@ -4240,7 +4251,10 @@ async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() let mut base = Vec::new(); for i in 0..5u16 { - let peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)), 56000 + i); + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)), + 56000 + i, + ); let reservation = RunningClientHandler::acquire_user_connection_reservation_static( user, &config, @@ -4288,15 +4302,8 @@ async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() .await .expect("window cleanup must settle to expected occupancy"); - let (wave2_success, wave2_fail) = burst_acquire_distinct_ips( - user, - config, - stats.clone(), - ip_tracker.clone(), - 132, - 32, - ) - .await; + let (wave2_success, wave2_fail) = + burst_acquire_distinct_ips(user, config, stats.clone(), ip_tracker.clone(), 132, 32).await; assert_eq!(wave2_success.len(), 1); assert_eq!(wave2_fail, 31); assert_eq!(stats.get_user_curr_connects(user), 5); diff --git a/src/proxy/tests/client_timing_profile_adversarial_tests.rs b/src/proxy/tests/client_timing_profile_adversarial_tests.rs index 134990e..69a9ff4 100644 --- a/src/proxy/tests/client_timing_profile_adversarial_tests.rs +++ b/src/proxy/tests/client_timing_profile_adversarial_tests.rs @@ -7,7 +7,7 @@ use crate::config::{UpstreamConfig, UpstreamType}; use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; use std::net::SocketAddr; use std::time::{Duration, Instant}; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::{TcpListener, TcpStream}; const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; @@ -135,10 +135,13 @@ async fn run_generic_once(class: ProbeClass) -> u128 { client_side.shutdown().await.unwrap(); let mut observed = vec![0u8; REPLY_404.len()]; - tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) - .await - .unwrap() - .unwrap(); + tokio::time::timeout( + Duration::from_secs(2), + client_side.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); assert_eq!(observed, REPLY_404); tokio::time::timeout(Duration::from_secs(2), accept_task) diff --git a/src/proxy/tests/client_tls_clienthello_size_security_tests.rs b/src/proxy/tests/client_tls_clienthello_size_security_tests.rs index e54791f..0c864e7 100644 --- a/src/proxy/tests/client_tls_clienthello_size_security_tests.rs +++ b/src/proxy/tests/client_tls_clienthello_size_security_tests.rs @@ -7,7 +7,7 @@ use crate::config::{UpstreamConfig, UpstreamType}; use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; use std::net::SocketAddr; use std::time::Duration; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; fn test_probe_for_len(len: usize) -> [u8; 5] { @@ -100,7 +100,10 @@ async fn run_probe_and_assert_masking(len: usize, expect_bad_increment: bool) { client_side.write_all(&probe).await.unwrap(); let mut observed = vec![0u8; backend_reply.len()]; client_side.read_exact(&mut observed).await.unwrap(); - assert_eq!(observed, backend_reply, "invalid TLS path must be masked as a real site"); + assert_eq!( + observed, backend_reply, + "invalid TLS path must be masked as a real site" + ); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) @@ -109,7 +112,11 @@ async fn run_probe_and_assert_masking(len: usize, expect_bad_increment: bool) { .unwrap(); accept_task.await.unwrap(); - let expected_bad = if expect_bad_increment { bad_before + 1 } else { bad_before }; + let expected_bad = if expect_bad_increment { + bad_before + 1 + } else { + bad_before + }; assert_eq!( stats.get_connects_bad(), expected_bad, @@ -187,7 +194,9 @@ fn tls_client_hello_len_bounds_stress_many_evaluations() { for _ in 0..100_000 { assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); - assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1)); + assert!(!tls_clienthello_len_in_bounds( + MIN_TLS_CLIENT_HELLO_SIZE - 1 + )); assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); } } diff --git a/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs b/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs index 6ac02dd..79a8640 100644 --- a/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs +++ b/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs @@ -7,7 +7,7 @@ use crate::config::{UpstreamConfig, UpstreamType}; use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; use std::net::SocketAddr; use std::time::Duration; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::{TcpListener, TcpStream}; use tokio::time::sleep; @@ -48,7 +48,12 @@ fn truncated_in_range_record(actual_body_len: usize) -> Vec { out } -async fn write_fragmented(writer: &mut W, bytes: &[u8], chunks: &[usize], delay_ms: u64) { +async fn write_fragmented( + writer: &mut W, + bytes: &[u8], + chunks: &[usize], + delay_ms: u64, +) { let mut offset = 0usize; for &chunk in chunks { if offset >= bytes.len() { @@ -130,10 +135,13 @@ async fn run_blackhat_generic_fragmented_probe_should_mask( client_side.shutdown().await.unwrap(); let mut observed = vec![0u8; backend_reply.len()]; - tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) - .await - .unwrap() - .unwrap(); + tokio::time::timeout( + Duration::from_secs(2), + client_side.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); assert_eq!(observed, backend_reply); tokio::time::timeout(Duration::from_secs(2), mask_accept_task) @@ -311,10 +319,13 @@ async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask() { // Security expectation: even malformed in-range TLS should be masked. // This invariant must hold to avoid probe-distinguishable EOF/timeout behavior. let mut observed = vec![0u8; backend_reply.len()]; - tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) - .await - .unwrap() - .unwrap(); + tokio::time::timeout( + Duration::from_secs(2), + client_side.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); assert_eq!(observed, backend_reply); tokio::time::timeout(Duration::from_secs(2), mask_accept_task) diff --git a/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs index 920c013..95e49f7 100644 --- a/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs +++ b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs @@ -2,16 +2,11 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; use crate::crypto::sha256_hmac; use crate::protocol::constants::{ - HANDSHAKE_LEN, - MAX_TLS_CIPHERTEXT_SIZE, - TLS_RECORD_ALERT, - TLS_RECORD_APPLICATION, - TLS_RECORD_CHANGE_CIPHER, - TLS_RECORD_HANDSHAKE, - TLS_VERSION, + HANDSHAKE_LEN, MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_ALERT, TLS_RECORD_APPLICATION, + TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION, }; use crate::protocol::tls; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; struct PipelineHarness { @@ -74,7 +69,10 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { } fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { - assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + assert!( + tls_len <= u16::MAX as usize, + "TLS length must fit into record header" + ); let total_len = 5 + tls_len; let mut handshake = vec![fill; total_len]; @@ -181,11 +179,17 @@ async fn tls_bad_mtproto_fallback_preserves_wire_and_backend_response() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) @@ -246,10 +250,16 @@ async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) @@ -264,7 +274,11 @@ async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() { .unwrap(); let bad_after = stats_for_assert.get_connects_bad(); - assert_eq!(bad_after, bad_before + 1, "connects_bad must increase exactly once for invalid MTProto after valid TLS"); + assert_eq!( + bad_after, + bad_before + 1, + "connects_bad must increase exactly once for invalid MTProto after valid TLS" + ); } #[tokio::test] @@ -310,10 +324,16 @@ async fn tls_bad_mtproto_fallback_forwards_zero_length_tls_record_verbatim() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) @@ -372,10 +392,16 @@ async fn tls_bad_mtproto_fallback_forwards_max_tls_record_verbatim() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) @@ -399,7 +425,8 @@ async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() { let backend_addr = listener.local_addr().unwrap(); let secret = [0x85u8; 16]; - let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 4, 600, 0x46 + idx as u8); + let client_hello = + make_valid_tls_client_hello(&secret, idx as u32 + 4, 600, 0x46 + idx as u8); let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); @@ -443,10 +470,16 @@ async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) @@ -498,7 +531,10 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() { ); } - assert!(remaining.is_empty(), "all expected client sessions must be matched exactly once"); + assert!( + remaining.is_empty(), + "all expected client sessions must be matched exactly once" + ); }); let mut client_tasks = Vec::with_capacity(sessions); @@ -506,7 +542,8 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() { for idx in 0..sessions { let harness = build_harness("86868686868686868686868686868686", backend_addr.port()); let secret = [0x86u8; 16]; - let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); + let client_hello = + make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let trailing_payload = vec![idx as u8; 64 + idx]; @@ -538,10 +575,16 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); drop(client_side); @@ -606,10 +649,16 @@ async fn tls_bad_mtproto_fallback_forwards_fragmented_client_writes_verbatim() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); for chunk in trailing_record.chunks(3) { client_side.write_all(chunk).await.unwrap(); @@ -669,10 +718,16 @@ async fn tls_bad_mtproto_fallback_header_fragmentation_bytewise_is_verbatim() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); for b in trailing_record.iter().copied() { client_side.write_all(&[b]).await.unwrap(); } @@ -736,10 +791,16 @@ async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17]; let mut pos = 0usize; @@ -747,7 +808,10 @@ async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() { while pos < trailing_record.len() { let step = chaos[idx % chaos.len()]; let end = (pos + step).min(trailing_record.len()); - client_side.write_all(&trailing_record[pos..end]).await.unwrap(); + client_side + .write_all(&trailing_record[pos..end]) + .await + .unwrap(); pos = end; idx += 1; } @@ -809,10 +873,16 @@ async fn tls_bad_mtproto_fallback_multiple_tls_records_are_forwarded_in_order() client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&r1).await.unwrap(); client_side.write_all(&r2).await.unwrap(); client_side.write_all(&r3).await.unwrap(); @@ -848,7 +918,10 @@ async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend() let mut tail = [0u8; 1]; let n = stream.read(&mut tail).await.unwrap(); - assert_eq!(n, 0, "backend must observe EOF after client write half-close"); + assert_eq!( + n, 0, + "backend must observe EOF after client write half-close" + ); }); let harness = build_harness("8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b", backend_addr.port()); @@ -874,10 +947,16 @@ async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend() client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); client_side.shutdown().await.unwrap(); @@ -938,11 +1017,17 @@ async fn tls_bad_mtproto_fallback_backend_half_close_after_response_is_tolerated client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) @@ -994,10 +1079,16 @@ async fn tls_bad_mtproto_fallback_backend_reset_after_clienthello_is_handled() { client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); let write_res = client_side.write_all(&trailing_record).await; assert!( write_res.is_ok() || write_res.is_err(), @@ -1068,10 +1159,16 @@ async fn tls_bad_mtproto_fallback_backend_slow_reader_preserves_byte_identity() client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(5), accept_task) @@ -1152,7 +1249,10 @@ async fn tls_bad_mtproto_fallback_replay_pressure_masks_replay_without_serverhel let mut head = [0u8; 5]; client_side.read_exact(&mut head).await.unwrap(); assert_eq!(head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); } else { let mut one = [0u8; 1]; @@ -1241,10 +1341,16 @@ async fn tls_bad_mtproto_fallback_large_multi_record_chaos_under_backpressure() client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31]; for record in [&a, &b, &c] { @@ -1316,10 +1422,16 @@ async fn tls_bad_mtproto_fallback_interleaved_control_and_application_records_ve client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; - client_side.read_exact(&mut tls_response_head).await.unwrap(); + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); assert_eq!(tls_response_head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&ccs).await.unwrap(); client_side.write_all(&app).await.unwrap(); client_side.write_all(&alert).await.unwrap(); @@ -1372,7 +1484,10 @@ async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak() ); } - assert!(remaining.is_empty(), "all expected sessions must be consumed exactly once"); + assert!( + remaining.is_empty(), + "all expected sessions must be consumed exactly once" + ); }); let mut tasks = Vec::with_capacity(sessions); @@ -1413,7 +1528,10 @@ async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak() client_side.read_exact(&mut head).await.unwrap(); assert_eq!(head[0], 0x16); - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); for chunk in record.chunks((idx % 9) + 1) { client_side.write_all(chunk).await.unwrap(); } @@ -2520,7 +2638,10 @@ async fn blackhat_coalesced_tail_parallel_32_sessions_no_cross_bleed() { "session mixup detected in parallel-32 blackhat test" ); } - assert!(remaining.is_empty(), "all expected sessions must be consumed"); + assert!( + remaining.is_empty(), + "all expected sessions must be consumed" + ); }); let mut tasks = Vec::with_capacity(sessions); diff --git a/src/proxy/tests/direct_relay_business_logic_tests.rs b/src/proxy/tests/direct_relay_business_logic_tests.rs index 166518e..37f9897 100644 --- a/src/proxy/tests/direct_relay_business_logic_tests.rs +++ b/src/proxy/tests/direct_relay_business_logic_tests.rs @@ -5,7 +5,10 @@ use std::net::SocketAddr; #[test] fn business_scope_hint_accepts_exact_boundary_length() { let value = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN)); - assert_eq!(validated_scope_hint(&value), Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")); + assert_eq!( + validated_scope_hint(&value), + Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + ); } #[test] @@ -24,7 +27,8 @@ fn business_known_dc_uses_ipv4_table_by_default() { #[test] fn business_negative_dc_maps_by_absolute_value() { let cfg = ProxyConfig::default(); - let resolved = get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value"); + let resolved = + get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value"); let expected = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT); assert_eq!(resolved, expected); } @@ -45,7 +49,8 @@ fn business_unknown_dc_uses_configured_default_dc_when_in_range() { let mut cfg = ProxyConfig::default(); cfg.default_dc = Some(4); - let resolved = get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default"); + let resolved = + get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default"); let expected = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT); assert_eq!(resolved, expected); } diff --git a/src/proxy/tests/direct_relay_common_mistakes_tests.rs b/src/proxy/tests/direct_relay_common_mistakes_tests.rs index ef40f37..8429449 100644 --- a/src/proxy/tests/direct_relay_common_mistakes_tests.rs +++ b/src/proxy/tests/direct_relay_common_mistakes_tests.rs @@ -12,7 +12,8 @@ fn common_invalid_override_entries_fallback_to_static_table() { vec!["bad-address".to_string(), "still-bad".to_string()], ); - let resolved = get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve"); + let resolved = + get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve"); let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT); assert_eq!(resolved, expected); } @@ -25,7 +26,8 @@ fn common_prefer_v6_with_only_ipv4_override_uses_override_instead_of_ignoring_it cfg.dc_overrides .insert("3".to_string(), vec!["203.0.113.203:443".to_string()]); - let resolved = get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists"); + let resolved = + get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists"); assert_eq!(resolved, "203.0.113.203:443".parse::().unwrap()); } diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs index 7c3a51e..3a5ba78 100644 --- a/src/proxy/tests/direct_relay_security_tests.rs +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -15,7 +15,7 @@ use std::time::Duration; use tokio::io::AsyncReadExt; use tokio::io::duplex; use tokio::net::TcpListener; -use tokio::time::{timeout, Duration as TokioDuration}; +use tokio::time::{Duration as TokioDuration, timeout}; fn make_crypto_reader(reader: R) -> CryptoReader where @@ -79,7 +79,9 @@ fn unknown_dc_log_respects_distinct_limit() { #[test] fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() { - let poisoned = Arc::new(std::sync::Mutex::new(std::collections::HashSet::::new())); + let poisoned = Arc::new(std::sync::Mutex::new( + std::collections::HashSet::::new(), + )); let poisoned_for_thread = poisoned.clone(); let _ = std::thread::spawn(move || { @@ -243,7 +245,10 @@ fn unknown_dc_log_path_sanitizer_accepts_safe_relative_path() { fs::create_dir_all(&base).expect("temp test directory must be creatable"); let candidate = base.join("unknown-dc.txt"); - let candidate_relative = format!("target/telemt-unknown-dc-log-{}/unknown-dc.txt", std::process::id()); + let candidate_relative = format!( + "target/telemt-unknown-dc-log-{}/unknown-dc.txt", + std::process::id() + ); let sanitized = sanitize_unknown_dc_log_path(&candidate_relative) .expect("safe relative path with existing parent must be accepted"); @@ -325,7 +330,10 @@ fn unknown_dc_log_path_sanitizer_accepts_symlinked_parent_inside_workspace() { let base = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-log-symlink-internal-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-log-symlink-internal-{}", + std::process::id() + )); let real_parent = base.join("real_parent"); fs::create_dir_all(&real_parent).expect("real parent dir must be creatable"); @@ -354,7 +362,10 @@ fn unknown_dc_log_path_sanitizer_accepts_symlink_parent_escape_as_canonical_path let base = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-log-symlink-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-log-symlink-{}", + std::process::id() + )); fs::create_dir_all(&base).expect("symlink test directory must be creatable"); let symlink_parent = base.join("escape_link"); @@ -382,7 +393,10 @@ fn unknown_dc_log_path_revalidation_rejects_symlinked_target_escape() { let base = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-target-link-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-target-link-{}", + std::process::id() + )); fs::create_dir_all(&base).expect("target-link base must be creatable"); let outside = std::env::temp_dir().join(format!("telemt-outside-{}", std::process::id())); @@ -445,7 +459,10 @@ fn unknown_dc_open_append_rejects_broken_symlink_target_with_nofollow() { let base = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-broken-link-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-broken-link-{}", + std::process::id() + )); fs::create_dir_all(&base).expect("broken-link base must be creatable"); let linked_target = base.join("unknown-dc.log"); @@ -470,7 +487,10 @@ fn adversarial_unknown_dc_open_append_symlink_flip_never_writes_outside_file() { let base = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-symlink-flip-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-symlink-flip-{}", + std::process::id() + )); fs::create_dir_all(&base).expect("symlink-flip base must be creatable"); let outside = std::env::temp_dir().join(format!( @@ -530,7 +550,10 @@ fn stress_unknown_dc_open_append_regular_file_preserves_line_integrity() { let base = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-open-stress-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-open-stress-{}", + std::process::id() + )); fs::create_dir_all(&base).expect("stress open base must be creatable"); let target = base.join("unknown-dc.log"); @@ -556,7 +579,10 @@ fn unknown_dc_log_path_revalidation_accepts_regular_existing_target() { let base = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-safe-target-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-safe-target-{}", + std::process::id() + )); fs::create_dir_all(&base).expect("safe target base must be creatable"); let target = base.join("unknown-dc.log"); @@ -566,8 +592,8 @@ fn unknown_dc_log_path_revalidation_accepts_regular_existing_target() { "target/telemt-unknown-dc-safe-target-{}/unknown-dc.log", std::process::id() ); - let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) - .expect("safe candidate must sanitize"); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("safe candidate must sanitize"); assert!( unknown_dc_log_path_is_still_safe(&sanitized), "revalidation must allow safe existing regular files" @@ -579,7 +605,10 @@ fn unknown_dc_log_path_revalidation_rejects_deleted_parent_after_sanitize() { let base = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-vanish-parent-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-vanish-parent-{}", + std::process::id() + )); fs::create_dir_all(&base).expect("vanish-parent base must be creatable"); let rel_candidate = format!( @@ -604,7 +633,10 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() { let parent = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-parent-swap-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-parent-swap-{}", + std::process::id() + )); fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable"); let rel_candidate = format!( @@ -633,7 +665,10 @@ fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() { let parent = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-check-open-race-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-check-open-race-{}", + std::process::id() + )); fs::create_dir_all(&parent).expect("check-open-race parent must be creatable"); let target = parent.join("unknown-dc.log"); @@ -642,8 +677,7 @@ fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() { "target/telemt-unknown-dc-check-open-race-{}/unknown-dc.log", std::process::id() ); - let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) - .expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); assert!( unknown_dc_log_path_is_still_safe(&sanitized), @@ -675,7 +709,10 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() { let base = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-parent-swap-openat-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-parent-swap-openat-{}", + std::process::id() + )); fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable"); let rel_candidate = format!( @@ -708,7 +745,10 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() { .expect_err("anchored open must fail when parent is swapped to symlink"); let raw = err.raw_os_error(); assert!( - matches!(raw, Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT)), + matches!( + raw, + Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT) + ), "anchored open must fail closed on parent swap race, got raw_os_error={raw:?}" ); assert!( @@ -896,7 +936,10 @@ async fn unknown_dc_symlinked_target_escape_is_not_written_integration() { let base = std::env::current_dir() .expect("cwd must be available") .join("target") - .join(format!("telemt-unknown-dc-no-write-link-{}", std::process::id())); + .join(format!( + "telemt-unknown-dc-no-write-link-{}", + std::process::id() + )); fs::create_dir_all(&base).expect("integration symlink base must be creatable"); let outside = std::env::temp_dir().join(format!( @@ -1024,11 +1067,17 @@ async fn direct_relay_abort_midflight_releases_route_gauge() { } }) .await; - assert!(started.is_ok(), "direct relay must increment route gauge before abort"); + assert!( + started.is_ok(), + "direct relay must increment route gauge before abort" + ); relay_task.abort(); let joined = relay_task.await; - assert!(joined.is_err(), "aborted direct relay task must return join error"); + assert!( + joined.is_err(), + "aborted direct relay task must return join error" + ); tokio::time::sleep(Duration::from_millis(20)).await; assert_eq!( @@ -1313,15 +1362,22 @@ fn prefer_v6_override_matrix_prefers_matching_family_then_degrades_safely() { ], ); let a = get_dc_addr_static(dc_idx, &cfg_a).expect("v6+v4 override set must resolve"); - assert!(a.is_ipv6(), "prefer_v6 should choose v6 override when present"); + assert!( + a.is_ipv6(), + "prefer_v6 should choose v6 override when present" + ); let mut cfg_b = ProxyConfig::default(); cfg_b.network.prefer = 6; cfg_b.network.ipv6 = Some(true); - cfg_b.dc_overrides + cfg_b + .dc_overrides .insert(dc_idx.to_string(), vec!["203.0.113.91:443".to_string()]); let b = get_dc_addr_static(dc_idx, &cfg_b).expect("v4-only override must still resolve"); - assert!(b.is_ipv4(), "when no v6 override exists, v4 override must be used"); + assert!( + b.is_ipv4(), + "when no v6 override exists, v4 override must be used" + ); let mut cfg_c = ProxyConfig::default(); cfg_c.network.prefer = 6; @@ -1350,7 +1406,8 @@ fn prefer_v6_override_matrix_ignores_invalid_entries_and_keeps_fail_closed_fallb ], ); - let addr = get_dc_addr_static(dc_idx, &cfg).expect("at least one valid override must keep resolution alive"); + let addr = get_dc_addr_static(dc_idx, &cfg) + .expect("at least one valid override must keep resolution alive"); assert_eq!(addr, "203.0.113.55:443".parse::().unwrap()); } @@ -1370,7 +1427,10 @@ fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() { let first = get_dc_addr_static(idx, &cfg).expect("first lookup must resolve"); let second = get_dc_addr_static(idx, &cfg).expect("second lookup must resolve"); - assert_eq!(first, second, "override resolution must stay deterministic for dc {idx}"); + assert_eq!( + first, second, + "override resolution must stay deterministic for dc {idx}" + ); assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred"); } } @@ -1379,12 +1439,12 @@ fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() { async fn negative_direct_relay_dc_connection_refused_fails_fast() { let (client_reader_side, _client_writer_side) = duplex(1024); let (_client_reader_relay, client_writer_side) = duplex(1024); - + let key = [0u8; 32]; let iv = 0u128; let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); - + let stats = Arc::new(Stats::new()); let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); let rng = Arc::new(SecureRandom::new()); @@ -1397,9 +1457,11 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() { drop(listener); let mut config_with_override = ProxyConfig::default(); - config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]); + config_with_override + .dc_overrides + .insert("1".to_string(), vec![dc_addr.to_string()]); let config = Arc::new(config_with_override); - + let upstream_manager = Arc::new(UpstreamManager::new( vec![UpstreamConfig { enabled: true, @@ -1418,7 +1480,7 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() { false, stats.clone(), )); - + let success = HandshakeSuccess { user: "test-user".to_string(), peer: "127.0.0.1:12345".parse().unwrap(), @@ -1460,21 +1522,21 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() { async fn adversarial_direct_relay_cutover_integrity() { let (client_reader_side, _client_writer_side) = duplex(1024); let (_client_reader_relay, client_writer_side) = duplex(1024); - + let key = [0u8; 32]; let iv = 0u128; let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); - + let stats = Arc::new(Stats::new()); let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); let rng = Arc::new(SecureRandom::new()); let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct); - + // Mock upstream server. let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let dc_addr = listener.local_addr().unwrap(); - + tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); // Read handshake nonce. @@ -1485,9 +1547,11 @@ async fn adversarial_direct_relay_cutover_integrity() { }); let mut config_with_override = ProxyConfig::default(); - config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]); + config_with_override + .dc_overrides + .insert("1".to_string(), vec![dc_addr.to_string()]); let config = Arc::new(config_with_override); - + let upstream_manager = Arc::new(UpstreamManager::new( vec![UpstreamConfig { enabled: true, @@ -1506,7 +1570,7 @@ async fn adversarial_direct_relay_cutover_integrity() { false, stats.clone(), )); - + let success = HandshakeSuccess { user: "test-user".to_string(), peer: "127.0.0.1:12345".parse().unwrap(), @@ -1534,7 +1598,8 @@ async fn adversarial_direct_relay_cutover_integrity() { runtime_clone.subscribe(), runtime_clone.snapshot(), 0xABCD_1234, - ).await + ) + .await }); timeout(TokioDuration::from_secs(2), async { @@ -1547,10 +1612,10 @@ async fn adversarial_direct_relay_cutover_integrity() { }) .await .expect("direct relay session must start before cutover"); - + // Trigger cutover. route_runtime.set_mode(RelayRouteMode::Middle).unwrap(); - + // The session should terminate after the staggered delay (1000-2000ms). let result = timeout(TokioDuration::from_secs(5), session_task) .await diff --git a/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs b/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs index 5cbbc68..325cffd 100644 --- a/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs +++ b/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs @@ -40,9 +40,7 @@ fn subtle_light_fuzz_scope_hint_matches_oracle() { }; !rest.is_empty() && rest.len() <= MAX_SCOPE_HINT_LEN - && rest - .bytes() - .all(|b| b.is_ascii_alphanumeric() || b == b'-') + && rest.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-') } let mut state: u64 = 0xC0FF_EE11_D15C_AFE5; @@ -94,7 +92,10 @@ fn subtle_light_fuzz_dc_resolution_never_panics_and_preserves_port() { let dc_idx = (state as i16).wrapping_sub(16_384); let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail"); - assert_eq!(resolved.port(), crate::protocol::constants::TG_DATACENTER_PORT); + assert_eq!( + resolved.port(), + crate::protocol::constants::TG_DATACENTER_PORT + ); let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true); assert_eq!(resolved.is_ipv6(), expect_v6); } @@ -166,7 +167,9 @@ async fn subtle_integration_parallel_unique_dcs_log_unique_lines() { cfg.general.unknown_dc_log_path = Some(rel_file); let cfg = Arc::new(cfg); - let dcs = [31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908]; + let dcs = [ + 31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908, + ]; let mut tasks = Vec::new(); for dc in dcs { diff --git a/src/proxy/tests/handshake_adversarial_tests.rs b/src/proxy/tests/handshake_adversarial_tests.rs index da93ef4..93832f7 100644 --- a/src/proxy/tests/handshake_adversarial_tests.rs +++ b/src/proxy/tests/handshake_adversarial_tests.rs @@ -1,10 +1,14 @@ use super::*; -use std::sync::Arc; -use std::net::{IpAddr, Ipv4Addr}; -use std::time::{Duration, Instant}; use crate::crypto::sha256; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; -fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { let secret = hex::decode(secret_hex).expect("secret hex must decode"); let mut handshake = [0x5Au8; HANDSHAKE_LEN]; for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] @@ -49,7 +53,9 @@ fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { let mut cfg = ProxyConfig::default(); cfg.access.users.clear(); - cfg.access.users.insert("user".to_string(), secret_hex.to_string()); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); cfg.access.ignore_time_skew = true; cfg.general.modes.secure = true; cfg @@ -71,9 +77,19 @@ async fn mtproto_handshake_bit_flip_anywhere_rejected() { let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); // Baseline check - let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + let res = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; match res { - HandshakeResult::Success(_) => {}, + HandshakeResult::Success(_) => {} _ => panic!("Baseline failed: expected Success"), } @@ -81,8 +97,21 @@ async fn mtproto_handshake_bit_flip_anywhere_rejected() { for byte_pos in SKIP_LEN..HANDSHAKE_LEN { let mut h = base; h[byte_pos] ^= 0x01; // Flip 1 bit - let res = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; - assert!(matches!(res, HandshakeResult::BadClient { .. }), "Flip at byte {byte_pos} bit 0 must be rejected"); + let res = handle_mtproto_handshake( + &h, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "Flip at byte {byte_pos} bit 0 must be rejected" + ); } } @@ -99,25 +128,51 @@ async fn mtproto_handshake_timing_neutrality_mocked() { let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap(); const ITER: usize = 50; - + let mut start = Instant::now(); for _ in 0..ITER { - let _ = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + let _ = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; } let duration_success = start.elapsed(); start = Instant::now(); for i in 0..ITER { let mut h = base; - h[SKIP_LEN + (i % 48)] ^= 0xFF; - let _ = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + h[SKIP_LEN + (i % 48)] ^= 0xFF; + let _ = handle_mtproto_handshake( + &h, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; } let duration_fail = start.elapsed(); - let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64).abs() / ITER as f64; - + let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64) + .abs() + / ITER as f64; + // Threshold (loose for CI) - assert!(avg_diff_ms < 100.0, "Timing difference too large: {} ms/iter", avg_diff_ms); + assert!( + avg_diff_ms < 100.0, + "Timing difference too large: {} ms/iter", + avg_diff_ms + ); } // ------------------------------------------------------------------ @@ -130,13 +185,13 @@ async fn auth_probe_throttle_saturation_stress() { clear_auth_probe_state_for_testing(); let now = Instant::now(); - + // Record enough failures for one IP to trigger backoff let target_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)); for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { auth_probe_record_failure(target_ip, now); } - + assert!(auth_probe_is_throttled(target_ip, now)); // Stress test with many unique IPs @@ -145,10 +200,7 @@ async fn auth_probe_throttle_saturation_stress() { auth_probe_record_failure(ip, now); } - let tracked = AUTH_PROBE_STATE - .get() - .map(|state| state.len()) - .unwrap_or(0); + let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0); assert!( tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, "auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}" @@ -166,7 +218,17 @@ async fn mtproto_handshake_abridged_prefix_rejected() { let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); - let res = handle_mtproto_handshake(&handshake, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; // MTProxy stops immediately on 0xef assert!(matches!(res, HandshakeResult::BadClient { .. })); } @@ -178,11 +240,17 @@ async fn mtproto_handshake_preferred_user_mismatch_continues() { let secret1_hex = "11111111111111111111111111111111"; let secret2_hex = "22222222222222222222222222222222"; - + let base = make_valid_mtproto_handshake(secret2_hex, ProtoTag::Secure, 1); let mut config = ProxyConfig::default(); - config.access.users.insert("user1".to_string(), secret1_hex.to_string()); - config.access.users.insert("user2".to_string(), secret2_hex.to_string()); + config + .access + .users + .insert("user1".to_string(), secret1_hex.to_string()); + config + .access + .users + .insert("user2".to_string(), secret2_hex.to_string()); config.access.ignore_time_skew = true; config.general.modes.secure = true; @@ -190,7 +258,17 @@ async fn mtproto_handshake_preferred_user_mismatch_continues() { let peer: SocketAddr = "192.0.2.4:12345".parse().unwrap(); // Even if we prefer user1, if user2 matches, it should succeed. - let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, Some("user1")).await; + let res = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + Some("user1"), + ) + .await; if let HandshakeResult::Success((_, _, success)) = res { assert_eq!(success.user, "user2"); } else { @@ -209,20 +287,30 @@ async fn mtproto_handshake_concurrent_flood_stability() { config.access.ignore_time_skew = true; let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); let config = Arc::new(config); - + let mut tasks = Vec::new(); for i in 0..50 { let base = base; let config = Arc::clone(&config); let replay_checker = Arc::clone(&replay_checker); let peer: SocketAddr = format!("192.0.2.{}:12345", (i % 254) + 1).parse().unwrap(); - + tasks.push(tokio::spawn(async move { - let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + let res = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; matches!(res, HandshakeResult::Success(_)) })); } - + // We don't necessarily care if they all succeed (some might fail due to replay if they hit the same chunk), // but the system must not panic or hang. for task in tasks { @@ -306,7 +394,10 @@ async fn mtproto_blackhat_mutation_corpus_never_panics_and_stays_fail_closed() { .expect("fuzzed mutation must complete in bounded time"); assert!( - matches!(res, HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)), + matches!( + res, + HandshakeResult::BadClient { .. } | HandshakeResult::Success(_) + ), "mutation corpus must stay within explicit handshake outcomes" ); } @@ -345,7 +436,12 @@ async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() { for i in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 512) { let peer: SocketAddr = SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(10, (i / 65535) as u8, ((i / 255) % 255) as u8, (i % 255 + 1) as u8)), + IpAddr::V4(Ipv4Addr::new( + 10, + (i / 65535) as u8, + ((i / 255) % 255) as u8, + (i % 255 + 1) as u8, + )), 43000 + (i % 20000) as u16, ); let res = handle_mtproto_handshake( @@ -362,10 +458,7 @@ async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() { assert!(matches!(res, HandshakeResult::BadClient { .. })); } - let tracked = AUTH_PROBE_STATE - .get() - .map(|state| state.len()) - .unwrap_or(0); + let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0); assert!( tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, "probe map must remain bounded under invalid storm: {tracked}" @@ -415,7 +508,10 @@ async fn mtproto_property_style_multi_bit_mutations_fail_closed_or_auth_only() { .expect("mutation iteration must complete in bounded time"); assert!( - matches!(outcome, HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)), + matches!( + outcome, + HandshakeResult::BadClient { .. } | HandshakeResult::Success(_) + ), "mutations must remain fail-closed/auth-only" ); } diff --git a/src/proxy/tests/handshake_fuzz_security_tests.rs b/src/proxy/tests/handshake_fuzz_security_tests.rs index d72c9cd..efb596b 100644 --- a/src/proxy/tests/handshake_fuzz_security_tests.rs +++ b/src/proxy/tests/handshake_fuzz_security_tests.rs @@ -6,7 +6,7 @@ use crate::protocol::constants::ProtoTag; use crate::stats::ReplayChecker; use std::net::SocketAddr; use std::sync::MutexGuard; -use tokio::time::{timeout, Duration as TokioDuration}; +use tokio::time::{Duration as TokioDuration, timeout}; fn make_mtproto_handshake_with_proto_bytes( secret_hex: &str, @@ -48,14 +48,20 @@ fn make_mtproto_handshake_with_proto_bytes( handshake } -fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx) } fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { let mut cfg = ProxyConfig::default(); cfg.access.users.clear(); - cfg.access.users.insert("user".to_string(), secret_hex.to_string()); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); cfg.access.ignore_time_skew = true; cfg.general.modes.secure = true; cfg @@ -140,7 +146,9 @@ async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() { for _ in 0..32 { let mut mutated = base; for _ in 0..4 { - seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); + seed = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1); } @@ -267,4 +275,4 @@ async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_re } clear_auth_probe_state_for_testing(); -} \ No newline at end of file +} diff --git a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs index b646d1f..d06f63e 100644 --- a/src/proxy/tests/handshake_security_tests.rs +++ b/src/proxy/tests/handshake_security_tests.rs @@ -1,8 +1,8 @@ use super::*; use crate::crypto::{sha256, sha256_hmac}; use dashmap::DashMap; -use rand::{RngExt, SeedableRng}; use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -80,8 +80,7 @@ fn make_valid_tls_client_hello_with_alpn( for i in 0..4 { digest[28 + i] ^= ts[i]; } - record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] - .copy_from_slice(&digest); + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); record } @@ -151,8 +150,7 @@ fn make_valid_tls_client_hello_with_sni_and_alpn( for i in 0..4 { digest[28 + i] ^= ts[i]; } - record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] - .copy_from_slice(&digest); + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); record } @@ -167,7 +165,11 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { cfg } -fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { let secret = hex::decode(secret_hex).expect("secret hex must decode for mtproto test helper"); let mut handshake = [0x5Au8; HANDSHAKE_LEN]; @@ -328,7 +330,10 @@ fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() { expected.extend_from_slice(&client_enc_iv.to_be_bytes()); expected.reverse(); - assert_eq!(&nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN], expected.as_slice()); + assert_eq!( + &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN], + expected.as_slice() + ); } #[test] @@ -445,7 +450,9 @@ async fn tls_replay_with_ignore_time_skew_and_small_boot_timestamp_is_still_bloc #[tokio::test] async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() { let secret = [0x77u8; 16]; - let config = Arc::new(test_config_with_secret_hex("77777777777777777777777777777777")); + let config = Arc::new(test_config_with_secret_hex( + "77777777777777777777777777777777", + )); let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); let rng = Arc::new(SecureRandom::new()); let handshake = Arc::new(make_valid_tls_handshake(&secret, 0)); @@ -785,10 +792,10 @@ async fn mixed_secret_lengths_keep_valid_user_authenticating() { .access .users .insert("broken_user".to_string(), "aa".to_string()); - config - .access - .users - .insert("valid_user".to_string(), "22222222222222222222222222222222".to_string()); + config.access.users.insert( + "valid_user".to_string(), + "22222222222222222222222222222222".to_string(), + ); config.access.ignore_time_skew = true; let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); @@ -829,12 +836,8 @@ async fn tls_sni_preferred_user_hint_selects_matching_identity_first() { let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let rng = SecureRandom::new(); let peer: SocketAddr = "198.51.100.188:44326".parse().unwrap(); - let handshake = make_valid_tls_client_hello_with_sni_and_alpn( - &shared_secret, - 0, - "user-b", - &[b"h2"], - ); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&shared_secret, 0, "user-b", &[b"h2"]); let result = handle_tls_handshake( &handshake, @@ -868,10 +871,10 @@ fn stress_decode_user_secrets_keeps_preferred_user_first_in_large_set() { let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); for i in 0..4096usize { - config.access.users.insert( - format!("decoy-{i:04}.example"), - secret_hex.clone(), - ); + config + .access + .users + .insert(format!("decoy-{i:04}.example"), secret_hex.clone()); } config .access @@ -910,10 +913,10 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() { let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); for i in 0..4096usize { - config.access.users.insert( - format!("decoy-{i:04}.example"), - secret_hex.clone(), - ); + config + .access + .users + .insert(format!("decoy-{i:04}.example"), secret_hex.clone()); } config .access @@ -945,8 +948,7 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() { match result { HandshakeResult::Success((_, _, user)) => { assert_eq!( - user, - preferred_user, + user, preferred_user, "SNI preferred-user hint must remain stable under large user cardinality" ); } @@ -1880,11 +1882,15 @@ fn auth_probe_ipv6_different_prefixes_use_distinct_buckets() { "different IPv6 /64 prefixes must not share throttle buckets" ); assert_eq!( - state.get(&normalize_auth_probe_ip(ip_a)).map(|entry| entry.fail_streak), + state + .get(&normalize_auth_probe_ip(ip_a)) + .map(|entry| entry.fail_streak), Some(1) ); assert_eq!( - state.get(&normalize_auth_probe_ip(ip_b)).map(|entry| entry.fail_streak), + state + .get(&normalize_auth_probe_ip(ip_b)) + .map(|entry| entry.fail_streak), Some(1) ); } @@ -1944,7 +1950,6 @@ fn auth_probe_eviction_offset_changes_with_time_component() { ); } - #[test] fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer_trackable() { let _guard = auth_probe_test_lock() @@ -1986,7 +1991,10 @@ fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 40)); auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(1)); - assert!(state.get(&newcomer).is_some(), "newcomer must still be tracked under over-cap pressure"); + assert!( + state.get(&newcomer).is_some(), + "newcomer must still be tracked under over-cap pressure" + ); assert!( state.get(&sentinel).is_some(), "high fail-streak sentinel must survive round-limited eviction" @@ -2077,13 +2085,20 @@ fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() ((step >> 8) & 0xff) as u8, (step & 0xff) as u8, )); - auth_probe_record_failure_with_state(&state, newcomer, base_now + Duration::from_millis(step as u64 + 1)); + auth_probe_record_failure_with_state( + &state, + newcomer, + base_now + Duration::from_millis(step as u64 + 1), + ); assert!( state.get(&sentinel).is_some(), "step {step}: high-threat sentinel must not be starved by newcomer churn" ); - assert!(state.get(&newcomer).is_some(), "step {step}: newcomer must be tracked"); + assert!( + state.get(&newcomer).is_some(), + "step {step}: newcomer must be tracked" + ); } } @@ -2129,10 +2144,22 @@ fn light_fuzz_auth_probe_overcap_eviction_prefers_less_threatening_entries() { ); } - let newcomer = IpAddr::V4(Ipv4Addr::new(203, 10, ((round >> 8) & 0xff) as u8, (round & 0xff) as u8)); - auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(round as u64 + 1)); + let newcomer = IpAddr::V4(Ipv4Addr::new( + 203, + 10, + ((round >> 8) & 0xff) as u8, + (round & 0xff) as u8, + )); + auth_probe_record_failure_with_state( + &state, + newcomer, + now + Duration::from_millis(round as u64 + 1), + ); - assert!(state.get(&newcomer).is_some(), "round {round}: newcomer should be tracked"); + assert!( + state.get(&newcomer).is_some(), + "round {round}: newcomer should be tracked" + ); assert!( state.get(&sentinel).is_some(), "round {round}: high fail-streak sentinel should survive mixed low-threat pool" @@ -2145,7 +2172,12 @@ fn light_fuzz_auth_probe_eviction_offset_is_deterministic_per_input_pair() { let base = Instant::now(); for _ in 0..4096usize { - let ip = IpAddr::V4(Ipv4Addr::new(rng.random(), rng.random(), rng.random(), rng.random())); + let ip = IpAddr::V4(Ipv4Addr::new( + rng.random(), + rng.random(), + rng.random(), + rng.random(), + )); let offset_ns = rng.random_range(0_u64..2_000_000); let when = base + Duration::from_nanos(offset_ns); @@ -2244,8 +2276,7 @@ async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() { let streak = auth_probe_fail_streak_for_testing(peer_ip) .expect("tracked peer must exist after concurrent failure burst"); assert_eq!( - streak as usize, - tasks, + streak as usize, tasks, "concurrent failures for one source must account every attempt" ); } @@ -2258,7 +2289,9 @@ async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake() clear_auth_probe_state_for_testing(); let secret = [0x31u8; 16]; - let config = Arc::new(test_config_with_secret_hex("31313131313131313131313131313131")); + let config = Arc::new(test_config_with_secret_hex( + "31313131313131313131313131313131", + )); let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); let rng = Arc::new(SecureRandom::new()); let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap(); @@ -2845,7 +2878,10 @@ async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing() ) .await; assert!(matches!(result, HandshakeResult::BadClient { .. })); - assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(expected) + ); } { @@ -2924,7 +2960,10 @@ async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementin ) .await; assert!(matches!(result, HandshakeResult::BadClient { .. })); - assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(expected) + ); } { @@ -3148,7 +3187,9 @@ async fn adversarial_same_peer_invalid_tls_storm_does_not_bypass_saturation_grac .unwrap_or_else(|poisoned| poisoned.into_inner()); clear_auth_probe_state_for_testing(); - let config = Arc::new(test_config_with_secret_hex("75757575757575757575757575757575")); + let config = Arc::new(test_config_with_secret_hex( + "75757575757575757575757575757575", + )); let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); let rng = Arc::new(SecureRandom::new()); let peer: SocketAddr = "198.51.100.212:45212".parse().unwrap(); @@ -3296,7 +3337,11 @@ async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshak } let valid_tls = Arc::new(make_valid_tls_handshake(&secret, 0)); - let valid_mtproto = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 3)); + let valid_mtproto = Arc::new(make_valid_mtproto_handshake( + secret_hex, + ProtoTag::Secure, + 3, + )); let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; let invalid_tls = Arc::new(invalid_tls); @@ -3368,7 +3413,9 @@ async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshak match task.await.unwrap() { HandshakeResult::BadClient { .. } => bad_clients += 1, HandshakeResult::Success(_) => panic!("invalid TLS probe unexpectedly authenticated"), - HandshakeResult::Error(err) => panic!("unexpected error in invalid TLS saturation burst test: {err}"), + HandshakeResult::Error(err) => { + panic!("unexpected error in invalid TLS saturation burst test: {err}") + } } } @@ -3385,8 +3432,7 @@ async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshak ); assert_eq!( - bad_clients, - 48, + bad_clients, 48, "all invalid TLS probes in mixed saturation burst must be rejected" ); } diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs index 1b30067..014ce4e 100644 --- a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs +++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs @@ -277,7 +277,10 @@ async fn integration_ab_harness_envelope_and_blur_improve_obfuscation_vs_baselin hardened_b.len() ); - assert_eq!(baseline_overlap, 0, "baseline above-cap classes should be disjoint"); + assert_eq!( + baseline_overlap, 0, + "baseline above-cap classes should be disjoint" + ); assert!( hardened_overlap > baseline_overlap, "above-cap blur should increase cross-class overlap: baseline={} hardened={}", @@ -314,7 +317,10 @@ fn timing_classifier_helper_threshold_accuracy_drops_for_identical_sets() { let a = vec![10u128, 11, 12, 13, 14]; let b = vec![10u128, 11, 12, 13, 14]; let acc = best_threshold_accuracy_u128(&a, &b); - assert!(acc <= 0.6, "identical sets should not be strongly separable"); + assert!( + acc <= 0.6, + "identical sets should not be strongly separable" + ); } #[test] @@ -336,7 +342,10 @@ async fn timing_classifier_baseline_connect_fail_vs_slow_backend_is_highly_separ let slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await; let acc = best_threshold_accuracy_u128(&fail, &slow); - assert!(acc >= 0.80, "baseline timing classes should be separable enough"); + assert!( + acc >= 0.80, + "baseline timing classes should be separable enough" + ); } #[tokio::test] @@ -408,7 +417,10 @@ async fn timing_classifier_normalized_mean_bucket_delta_connect_fail_vs_connect_ let fail_mean = mean_ms(&fail); let success_mean = mean_ms(&success); let delta_bucket = ((fail_mean as i128 - success_mean as i128).abs()) / 20; - assert!(delta_bucket <= 3, "mean bucket delta too large: {delta_bucket}"); + assert!( + delta_bucket <= 3, + "mean bucket delta too large: {delta_bucket}" + ); } #[tokio::test] @@ -418,7 +430,10 @@ async fn timing_classifier_normalized_p95_bucket_delta_connect_success_vs_slow_i let p95_success = percentile_ms(success, 95, 100); let p95_slow = percentile_ms(slow, 95, 100); let delta_bucket = ((p95_success as i128 - p95_slow as i128).abs()) / 20; - assert!(delta_bucket <= 4, "p95 bucket delta too large: {delta_bucket}"); + assert!( + delta_bucket <= 4, + "p95 bucket delta too large: {delta_bucket}" + ); } #[tokio::test] @@ -434,7 +449,8 @@ async fn timing_classifier_normalized_spread_is_not_worse_than_baseline_for_conn } #[tokio::test] -async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_under_normalization() { +async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_under_normalization() +{ let pairs = [ (PathClass::ConnectFail, PathClass::ConnectSuccess), (PathClass::ConnectFail, PathClass::SlowBackend), @@ -504,7 +520,10 @@ async fn timing_classifier_stress_parallel_sampling_finishes_and_stays_bounded() _ => PathClass::SlowBackend, }; let sample = measure_masking_duration_ms(class, true).await; - assert!((100..=1600).contains(&sample), "stress sample out of bounds: {sample}"); + assert!( + (100..=1600).contains(&sample), + "stress sample out of bounds: {sample}" + ); })); } diff --git a/src/proxy/tests/masking_adversarial_tests.rs b/src/proxy/tests/masking_adversarial_tests.rs index 955e8ec..ce2807a 100644 --- a/src/proxy/tests/masking_adversarial_tests.rs +++ b/src/proxy/tests/masking_adversarial_tests.rs @@ -1,13 +1,13 @@ use super::*; -use std::sync::Arc; -use tokio::io::duplex; -use tokio::net::TcpListener; -use tokio::time::{Instant, Duration}; use crate::config::ProxyConfig; use crate::proxy::relay::relay_bidirectional; use crate::stats::Stats; use crate::stats::beobachten::BeobachtenStore; use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; // ------------------------------------------------------------------ // Probing Indistinguishability (OWASP ASVS 5.1.7) @@ -19,7 +19,7 @@ async fn masking_probes_indistinguishable_timing() { config.censorship.mask = true; config.censorship.mask_host = Some("127.0.0.1".to_string()); config.censorship.mask_port = 80; // Should timeout/refuse - + let peer: SocketAddr = "192.0.2.10:443".parse().unwrap(); let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); let beobachten = BeobachtenStore::new(); @@ -28,14 +28,17 @@ async fn masking_probes_indistinguishable_timing() { let probes = vec![ (b"GET / HTTP/1.1\r\nHost: x\r\n\r\n".to_vec(), "HTTP"), (b"SSH-2.0-probe".to_vec(), "SSH"), - (vec![0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00], "TLS-scanner"), + ( + vec![0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00], + "TLS-scanner", + ), (vec![0x42; 5], "port-scanner"), ]; for (probe, type_name) in probes { let (client_reader, _client_writer) = duplex(256); let (_client_visible_reader, client_visible_writer) = duplex(256); - + let start = Instant::now(); handle_bad_client( client_reader, @@ -45,13 +48,17 @@ async fn masking_probes_indistinguishable_timing() { local_addr, &config, &beobachten, - ).await; - + ) + .await; + let elapsed = start.elapsed(); - + // We expect any outcome to take roughly MASK_TIMEOUT (50ms in tests) // to mask whether the backend was reachable or refused. - assert!(elapsed >= Duration::from_millis(30), "Probe {type_name} finished too fast: {elapsed:?}"); + assert!( + elapsed >= Duration::from_millis(30), + "Probe {type_name} finished too fast: {elapsed:?}" + ); } } @@ -76,7 +83,7 @@ async fn masking_budget_stress_under_load() { let (_client_visible_reader, client_visible_writer) = duplex(256); let config = config.clone(); let beobachten = Arc::clone(&beobachten); - + tasks.push(tokio::spawn(async move { let start = Instant::now(); handle_bad_client( @@ -87,14 +94,18 @@ async fn masking_budget_stress_under_load() { local_addr, &config, &beobachten, - ).await; + ) + .await; start.elapsed() })); } for task in tasks { let elapsed = task.await.unwrap(); - assert!(elapsed >= Duration::from_millis(30), "Stress probe finished too fast: {elapsed:?}"); + assert!( + elapsed >= Duration::from_millis(30), + "Stress probe finished too fast: {elapsed:?}" + ); } } @@ -108,10 +119,10 @@ fn test_detect_client_type_boundary_cases() { assert_eq!(detect_client_type(&[0x42; 9]), "port-scanner"); // 10 bytes = unknown assert_eq!(detect_client_type(&[0x42; 10]), "unknown"); - + // HTTP verbs without trailing space assert_eq!(detect_client_type(b"GET/"), "port-scanner"); // because len < 10 - assert_eq!(detect_client_type(b"GET /path"), "HTTP"); + assert_eq!(detect_client_type(b"GET /path"), "HTTP"); } // ------------------------------------------------------------------ @@ -133,7 +144,9 @@ async fn masking_slowloris_client_idle_timeout_rejected() { assert_eq!(observed, initial); let mut drip = [0u8; 1]; - let drip_read = tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip)).await; + let drip_read = + tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip)) + .await; assert!( drip_read.is_err() || drip_read.unwrap().is_err(), "backend must not receive post-timeout slowloris drip bytes" @@ -183,18 +196,31 @@ async fn masking_fallback_down_mimics_timeout() { config.censorship.mask = true; config.censorship.mask_host = Some("127.0.0.1".to_string()); config.censorship.mask_port = 1; // Unlikely port - + let (server_reader, server_writer) = duplex(1024); let beobachten = BeobachtenStore::new(); let peer: SocketAddr = "192.0.2.12:12345".parse().unwrap(); let local: SocketAddr = "192.0.2.1:443".parse().unwrap(); let start = Instant::now(); - handle_bad_client(server_reader, server_writer, b"GET / HTTP/1.1\r\n", peer, local, &config, &beobachten).await; - + handle_bad_client( + server_reader, + server_writer, + b"GET / HTTP/1.1\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + let elapsed = start.elapsed(); // It should wait for MASK_TIMEOUT (50ms in tests) even if connection was refused immediately - assert!(elapsed >= Duration::from_millis(40), "Must respect connect budget even on failure: {:?}", elapsed); + assert!( + elapsed >= Duration::from_millis(40), + "Must respect connect budget even on failure: {:?}", + elapsed + ); } // ------------------------------------------------------------------ @@ -205,7 +231,13 @@ async fn masking_fallback_down_mimics_timeout() { async fn masking_ssrf_resolve_internal_ranges_blocked() { use crate::network::dns_overrides::resolve_socket_addr; - let blocked_ips = ["127.0.0.1", "169.254.169.254", "10.0.0.1", "192.168.1.1", "0.0.0.0"]; + let blocked_ips = [ + "127.0.0.1", + "169.254.169.254", + "10.0.0.1", + "192.168.1.1", + "0.0.0.0", + ]; for ip in blocked_ips { assert!( @@ -270,7 +302,10 @@ async fn masking_zero_length_initial_data_does_not_hang_or_panic() { .await .unwrap() .unwrap(); - assert_eq!(n, 0, "backend must observe clean EOF for empty initial payload"); + assert_eq!( + n, 0, + "backend must observe clean EOF for empty initial payload" + ); }); let mut config = ProxyConfig::default(); @@ -312,7 +347,10 @@ async fn masking_oversized_initial_payload_is_forwarded_verbatim() { let (mut stream, _) = listener.accept().await.unwrap(); let mut observed = vec![0u8; payload.len()]; stream.read_exact(&mut observed).await.unwrap(); - assert_eq!(observed, payload, "large initial payload must stay byte-for-byte"); + assert_eq!( + observed, payload, + "large initial payload must stay byte-for-byte" + ); } }); @@ -491,7 +529,10 @@ async fn chaos_burst_reconnect_storm_for_masking_and_relay_concurrently() { }); let mut observed = vec![0u8; expected_reply.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, expected_reply); timeout(Duration::from_secs(2), handle) @@ -646,7 +687,10 @@ async fn chaos_burst_reconnect_storm_for_masking_and_relay_multiwave_soak() { }); let mut observed = vec![0u8; expected_reply.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, expected_reply); timeout(Duration::from_secs(3), handle) diff --git a/src/proxy/tests/masking_security_tests.rs b/src/proxy/tests/masking_security_tests.rs index 9107ca9..d829bca 100644 --- a/src/proxy/tests/masking_security_tests.rs +++ b/src/proxy/tests/masking_security_tests.rs @@ -1,14 +1,14 @@ use super::*; use crate::config::ProxyConfig; +use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; -use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{duplex, AsyncBufReadExt, BufReader}; +use tokio::io::{AsyncBufReadExt, BufReader, duplex}; use tokio::net::TcpListener; #[cfg(unix)] use tokio::net::UnixListener; -use tokio::time::{Instant, sleep, timeout, Duration}; +use tokio::time::{Duration, Instant, sleep, timeout}; #[tokio::test] async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { @@ -56,7 +56,10 @@ async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { .await; let mut observed = vec![0u8; backend_reply.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, backend_reply); accept_task.await.unwrap(); } @@ -108,7 +111,10 @@ async fn tls_scanner_probe_keeps_http_like_fallback_surface() { .await; let mut observed = vec![0u8; backend_reply.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, backend_reply); let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); @@ -147,8 +153,8 @@ fn build_mask_proxy_header_v2_matches_builder_output() { let expected = ProxyProtocolV2Builder::new() .with_addrs(peer, local_addr) .build(); - let actual = build_mask_proxy_header(2, peer, local_addr) - .expect("v2 mode must produce a header"); + let actual = + build_mask_proxy_header(2, peer, local_addr).expect("v2 mode must produce a header"); assert_eq!(actual, expected, "v2 header bytes must be deterministic"); } @@ -159,8 +165,8 @@ fn build_mask_proxy_header_v1_mixed_ip_family_uses_generic_unknown_form() { let local_addr: SocketAddr = "[2001:db8::1]:443".parse().unwrap(); let expected = ProxyProtocolV1Builder::new().build(); - let actual = build_mask_proxy_header(1, peer, local_addr) - .expect("v1 mode must produce a header"); + let actual = + build_mask_proxy_header(1, peer, local_addr).expect("v1 mode must produce a header"); assert_eq!(actual, expected, "mixed-family v1 must use UNKNOWN form"); } @@ -197,7 +203,10 @@ async fn beobachten_records_scanner_class_when_mask_is_disabled() { client_reader_side.write_all(b"noise").await.unwrap(); drop(client_reader_side); - let beobachten = timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + let beobachten = timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); assert!(snapshot.contains("[SSH]")); assert!(snapshot.contains("203.0.113.99-1")); @@ -241,7 +250,10 @@ async fn backend_unavailable_falls_back_to_silent_consume() { client_reader_side.write_all(b"noise").await.unwrap(); drop(client_reader_side); - timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); let mut buf = [0u8; 1]; let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) @@ -393,9 +405,9 @@ async fn proxy_header_write_error_on_tcp_path_still_honors_coarse_outcome_budget .await; }); - timeout(Duration::from_millis(35), task) - .await - .expect_err("proxy-header write error path should remain inside coarse masking budget window"); + timeout(Duration::from_millis(35), task).await.expect_err( + "proxy-header write error path should remain inside coarse masking budget window", + ); assert!( started.elapsed() >= Duration::from_millis(35), "proxy-header write error path should avoid immediate-return timing signature" @@ -450,9 +462,9 @@ async fn proxy_header_write_error_on_unix_path_still_honors_coarse_outcome_budge .await; }); - timeout(Duration::from_millis(35), task) - .await - .expect_err("unix proxy-header write error path should remain inside coarse masking budget window"); + timeout(Duration::from_millis(35), task).await.expect_err( + "unix proxy-header write error path should remain inside coarse masking budget window", + ); assert!( started.elapsed() >= Duration::from_millis(35), "unix proxy-header write error path should avoid immediate-return timing signature" @@ -486,8 +498,14 @@ async fn unix_socket_proxy_protocol_v1_header_is_sent_before_probe() { let mut header_line = Vec::new(); reader.read_until(b'\n', &mut header_line).await.unwrap(); let header_text = String::from_utf8(header_line).unwrap(); - assert!(header_text.starts_with("PROXY "), "must start with PROXY prefix"); - assert!(header_text.ends_with("\r\n"), "v1 header must end with CRLF"); + assert!( + header_text.starts_with("PROXY "), + "must start with PROXY prefix" + ); + assert!( + header_text.ends_with("\r\n"), + "v1 header must end with CRLF" + ); let mut received_probe = vec![0u8; probe.len()]; reader.read_exact(&mut received_probe).await.unwrap(); @@ -523,7 +541,10 @@ async fn unix_socket_proxy_protocol_v1_header_is_sent_before_probe() { .await; let mut observed = vec![0u8; backend_reply.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, backend_reply); accept_task.await.unwrap(); @@ -552,7 +573,10 @@ async fn unix_socket_proxy_protocol_v2_header_is_sent_before_probe() { let mut sig = [0u8; 12]; stream.read_exact(&mut sig).await.unwrap(); - assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n", "v2 signature must match spec"); + assert_eq!( + &sig, b"\r\n\r\n\0\r\nQUIT\n", + "v2 signature must match spec" + ); let mut fixed = [0u8; 4]; stream.read_exact(&mut fixed).await.unwrap(); @@ -593,7 +617,10 @@ async fn unix_socket_proxy_protocol_v2_header_is_sent_before_probe() { .await; let mut observed = vec![0u8; backend_reply.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, backend_reply); accept_task.await.unwrap(); @@ -893,10 +920,16 @@ async fn mask_disabled_consumes_client_data_without_response() { .await; }); - client_reader_side.write_all(b"untrusted payload").await.unwrap(); + client_reader_side + .write_all(b"untrusted payload") + .await + .unwrap(); drop(client_reader_side); - timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); let mut buf = [0u8; 1]; let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) @@ -962,7 +995,10 @@ async fn proxy_protocol_v1_header_is_sent_before_probe() { .await; let mut observed = vec![0u8; backend_reply.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, backend_reply); accept_task.await.unwrap(); } @@ -1026,7 +1062,10 @@ async fn proxy_protocol_v2_header_is_sent_before_probe() { .await; let mut observed = vec![0u8; backend_reply.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, backend_reply); accept_task.await.unwrap(); } @@ -1086,7 +1125,10 @@ async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() { .await; let mut observed = vec![0u8; backend_reply.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, backend_reply); accept_task.await.unwrap(); } @@ -1094,7 +1136,11 @@ async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() { #[cfg(unix)] #[tokio::test] async fn unix_socket_mask_path_forwards_probe_and_response() { - let sock_path = format!("/tmp/telemt-mask-test-{}-{}.sock", std::process::id(), rand::random::()); + let sock_path = format!( + "/tmp/telemt-mask-test-{}-{}.sock", + std::process::id(), + rand::random::() + ); let _ = std::fs::remove_file(&sock_path); let listener = UnixListener::bind(&sock_path).unwrap(); @@ -1138,7 +1184,10 @@ async fn unix_socket_mask_path_forwards_probe_and_response() { .await; let mut observed = vec![0u8; backend_reply.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, backend_reply); accept_task.await.unwrap(); @@ -1171,7 +1220,10 @@ async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() { .await; }); - timeout(Duration::from_secs(1), task).await.unwrap().unwrap(); + timeout(Duration::from_secs(1), task) + .await + .unwrap() + .unwrap(); } #[tokio::test] @@ -1329,14 +1381,20 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall // Allow relay tasks to start, then emulate mask backend response. sleep(Duration::from_millis(20)).await; - backend_feed_writer.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap(); + backend_feed_writer + .write_all(b"HTTP/1.1 200 OK\r\n\r\n") + .await + .unwrap(); backend_feed_writer.shutdown().await.unwrap(); let mut observed = vec![0u8; 19]; - timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed)) - .await - .unwrap() - .unwrap(); + timeout( + Duration::from_secs(1), + client_visible_reader.read_exact(&mut observed), + ) + .await + .unwrap() + .unwrap(); assert_eq!(observed, b"HTTP/1.1 200 OK\r\n\r\n"); relay.abort(); @@ -1394,14 +1452,23 @@ async fn relay_to_mask_preserves_backend_response_after_client_half_close() { client_write.shutdown().await.unwrap(); let mut observed_resp = vec![0u8; response.len()]; - timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed_resp)) + timeout( + Duration::from_secs(1), + client_visible_reader.read_exact(&mut observed_resp), + ) + .await + .unwrap() + .unwrap(); + assert_eq!(observed_resp, response); + + timeout(Duration::from_secs(1), fallback_task) + .await + .unwrap() + .unwrap(); + timeout(Duration::from_secs(1), backend_task) .await .unwrap() .unwrap(); - assert_eq!(observed_resp, response); - - timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap(); - timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap(); } #[tokio::test] @@ -1427,16 +1494,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { let timed = timeout( Duration::from_millis(40), relay_to_mask( - reader, - writer, - mask_read, - mask_write, - b"", - false, - 0, - 0, - false, - 0, + reader, writer, mask_read, mask_write, b"", false, 0, 0, false, 0, ), ) .await; @@ -1574,9 +1632,11 @@ async fn timing_matrix_masking_classes_under_controlled_inputs() { (mean, min, p95, max) } - let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples); + let (disabled_mean, disabled_min, disabled_p95, disabled_max) = + summarize(&mut disabled_samples); let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples); - let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples); + let (reachable_mean, reachable_min, reachable_p95, reachable_max) = + summarize(&mut reachable_samples); println!( "TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", @@ -1698,7 +1758,10 @@ async fn reachable_backend_one_response_then_silence_is_cut_by_idle_timeout() { let elapsed = started.elapsed(); let mut observed = vec![0u8; response.len()]; - client_visible_reader.read_exact(&mut observed).await.unwrap(); + client_visible_reader + .read_exact(&mut observed) + .await + .unwrap(); assert_eq!(observed, response); assert!( elapsed < Duration::from_millis(190), @@ -1763,6 +1826,9 @@ async fn adversarial_client_drip_feed_longer_than_idle_timeout_is_cut_off() { let _ = client_writer_side.write_all(b"X").await; drop(client_writer_side); - timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap(); + timeout(Duration::from_secs(1), relay_task) + .await + .unwrap() + .unwrap(); accept_task.await.unwrap(); } diff --git a/src/proxy/tests/masking_shape_above_cap_blur_security_tests.rs b/src/proxy/tests/masking_shape_above_cap_blur_security_tests.rs index d2d522f..3f581e2 100644 --- a/src/proxy/tests/masking_shape_above_cap_blur_security_tests.rs +++ b/src/proxy/tests/masking_shape_above_cap_blur_security_tests.rs @@ -1,5 +1,5 @@ use super::*; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::Duration; diff --git a/src/proxy/tests/masking_shape_classifier_resistance_adversarial_tests.rs b/src/proxy/tests/masking_shape_classifier_resistance_adversarial_tests.rs index 9e8c5b7..5d494b8 100644 --- a/src/proxy/tests/masking_shape_classifier_resistance_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_classifier_resistance_adversarial_tests.rs @@ -1,5 +1,5 @@ use super::*; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::Duration; @@ -90,9 +90,7 @@ fn nearest_centroid_classifier_accuracy( samples_b: &[usize], samples_c: &[usize], ) -> f64 { - let mean = |xs: &[usize]| -> f64 { - xs.iter().copied().sum::() as f64 / xs.len() as f64 - }; + let mean = |xs: &[usize]| -> f64 { xs.iter().copied().sum::() as f64 / xs.len() as f64 }; let ca = mean(samples_a); let cb = mean(samples_b); @@ -104,11 +102,7 @@ fn nearest_centroid_classifier_accuracy( for &x in samples_a { total += 1; let xf = x as f64; - let d = [ - (xf - ca).abs(), - (xf - cb).abs(), - (xf - cc).abs(), - ]; + let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()]; if d[0] <= d[1] && d[0] <= d[2] { correct += 1; } @@ -117,11 +111,7 @@ fn nearest_centroid_classifier_accuracy( for &x in samples_b { total += 1; let xf = x as f64; - let d = [ - (xf - ca).abs(), - (xf - cb).abs(), - (xf - cc).abs(), - ]; + let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()]; if d[1] <= d[0] && d[1] <= d[2] { correct += 1; } @@ -130,11 +120,7 @@ fn nearest_centroid_classifier_accuracy( for &x in samples_c { total += 1; let xf = x as f64; - let d = [ - (xf - ca).abs(), - (xf - cb).abs(), - (xf - cc).abs(), - ]; + let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()]; if d[2] <= d[0] && d[2] <= d[1] { correct += 1; } @@ -166,7 +152,10 @@ async fn masking_shape_classifier_resistance_blur_reduces_threshold_attack_accur let hardened_acc = best_threshold_accuracy(&hardened_a, &hardened_b); // Baseline classes are deterministic/non-overlapping -> near-perfect threshold attack. - assert!(baseline_acc >= 0.99, "baseline separability unexpectedly low: {baseline_acc:.3}"); + assert!( + baseline_acc >= 0.99, + "baseline separability unexpectedly low: {baseline_acc:.3}" + ); // Blur must materially reduce the best one-dimensional length classifier. assert!( hardened_acc <= 0.90, @@ -247,7 +236,11 @@ async fn masking_shape_classifier_resistance_edge_max_extra_one_has_two_point_su seen.insert(observed); } - assert_eq!(seen.len(), 2, "both support points should appear under repeated sampling"); + assert_eq!( + seen.len(), + 2, + "both support points should appear under repeated sampling" + ); } #[tokio::test] @@ -262,13 +255,25 @@ async fn masking_shape_classifier_resistance_negative_blur_without_shape_hardeni bs_observed.insert(capture_forwarded_len(BODY_B, false, true, 96).await); } - assert_eq!(as_observed.len(), 1, "without shape hardening class A must stay deterministic"); - assert_eq!(bs_observed.len(), 1, "without shape hardening class B must stay deterministic"); - assert_ne!(as_observed, bs_observed, "distinct classes should remain separable without shaping"); + assert_eq!( + as_observed.len(), + 1, + "without shape hardening class A must stay deterministic" + ); + assert_eq!( + bs_observed.len(), + 1, + "without shape hardening class B must stay deterministic" + ); + assert_ne!( + as_observed, bs_observed, + "distinct classes should remain separable without shaping" + ); } #[tokio::test] -async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_attack_degrades_with_blur() { +async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_attack_degrades_with_blur() + { const SAMPLES: usize = 80; const MAX_EXTRA: usize = 96; const C1: usize = 5000; @@ -295,13 +300,23 @@ async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_at let base_acc = nearest_centroid_classifier_accuracy(&base1, &base2, &base3); let hard_acc = nearest_centroid_classifier_accuracy(&hard1, &hard2, &hard3); - assert!(base_acc >= 0.99, "baseline centroid separability should be near-perfect"); - assert!(hard_acc <= 0.88, "blur should materially degrade 3-class centroid attack"); - assert!(hard_acc <= base_acc - 0.1, "accuracy drop should be meaningful"); + assert!( + base_acc >= 0.99, + "baseline centroid separability should be near-perfect" + ); + assert!( + hard_acc <= 0.88, + "blur should materially degrade 3-class centroid attack" + ); + assert!( + hard_acc <= base_acc - 0.1, + "accuracy drop should be meaningful" + ); } #[tokio::test] -async fn masking_shape_classifier_resistance_light_fuzz_bounds_hold_for_randomized_above_cap_campaign() { +async fn masking_shape_classifier_resistance_light_fuzz_bounds_hold_for_randomized_above_cap_campaign() + { let mut s: u64 = 0xDEAD_BEEF_CAFE_BABE; for _ in 0..96 { s ^= s << 7; diff --git a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs index fc0b0b8..b7c884b 100644 --- a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs @@ -1,6 +1,6 @@ use super::*; -use tokio::io::{duplex, empty, sink, AsyncReadExt, AsyncWriteExt}; -use tokio::time::{sleep, timeout, Duration}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, empty, sink}; +use tokio::time::{Duration, sleep, timeout}; fn oracle_len( total_sent: usize, @@ -54,17 +54,23 @@ async fn run_relay_case( client_writer.shutdown().await.unwrap(); } - timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); if !close_client { drop(client_writer); } let mut observed = Vec::new(); - timeout(Duration::from_secs(2), mask_observer.read_to_end(&mut observed)) - .await - .unwrap() - .unwrap(); + timeout( + Duration::from_secs(2), + mask_observer.read_to_end(&mut observed), + ) + .await + .unwrap() + .unwrap(); observed } @@ -97,12 +103,29 @@ async fn masking_shape_guard_positive_clean_eof_path_shapes_and_preserves_prefix let extra = vec![0x55; 300]; let total = initial.len() + extra.len(); - let observed = run_relay_case(initial.clone(), extra.clone(), true, true, 512, 4096, false, 0).await; + let observed = run_relay_case( + initial.clone(), + extra.clone(), + true, + true, + 512, + 4096, + false, + 0, + ) + .await; let expected_len = oracle_len(total, true, true, initial.len(), 512, 4096); - assert_eq!(observed.len(), expected_len, "clean EOF path must be bucket-shaped"); + assert_eq!( + observed.len(), + expected_len, + "clean EOF path must be bucket-shaped" + ); assert_eq!(&observed[..initial.len()], initial.as_slice()); - assert_eq!(&observed[initial.len()..(initial.len() + extra.len())], extra.as_slice()); + assert_eq!( + &observed[initial.len()..(initial.len() + extra.len())], + extra.as_slice() + ); } #[tokio::test] @@ -112,7 +135,11 @@ async fn masking_shape_guard_edge_empty_initial_remains_transparent_under_clean_ let observed = run_relay_case(initial, extra.clone(), true, true, 512, 4096, false, 0).await; - assert_eq!(observed.len(), extra.len(), "empty initial_data must never trigger shaping"); + assert_eq!( + observed.len(), + extra.len(), + "empty initial_data must never trigger shaping" + ); assert_eq!(observed, extra); } @@ -212,13 +239,19 @@ async fn masking_shape_guard_stress_parallel_mixed_sessions_keep_oracle_and_no_h assert_eq!(&observed[..initial_len], initial.as_slice()); } if extra_len > 0 { - assert_eq!(&observed[initial_len..(initial_len + extra_len)], extra.as_slice()); + assert_eq!( + &observed[initial_len..(initial_len + extra_len)], + extra.as_slice() + ); } })); } for task in tasks { - timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); } } @@ -238,7 +271,10 @@ async fn masking_shape_guard_integration_slow_drip_timeout_is_cut_without_tail_l let mut one = [0u8; 1]; let r = timeout(Duration::from_millis(220), stream.read_exact(&mut one)).await; - assert!(r.is_err() || r.unwrap().is_err(), "no post-timeout drip/tail may reach backend"); + assert!( + r.is_err() || r.unwrap().is_err(), + "no post-timeout drip/tail may reach backend" + ); } }); @@ -274,8 +310,14 @@ async fn masking_shape_guard_integration_slow_drip_timeout_is_cut_without_tail_l sleep(Duration::from_millis(160)).await; let _ = client_writer.write_all(b"X").await; - timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); - timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); } #[tokio::test] @@ -352,7 +394,10 @@ async fn masking_shape_guard_above_cap_blur_parallel_stress_keeps_bounds() { } for task in tasks { - timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); } } diff --git a/src/proxy/tests/masking_shape_guard_security_tests.rs b/src/proxy/tests/masking_shape_guard_security_tests.rs index 72c208f..34a89c4 100644 --- a/src/proxy/tests/masking_shape_guard_security_tests.rs +++ b/src/proxy/tests/masking_shape_guard_security_tests.rs @@ -1,7 +1,7 @@ use super::*; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; -use tokio::time::{timeout, Duration}; +use tokio::time::{Duration, timeout}; #[tokio::test] async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof() { @@ -15,7 +15,10 @@ async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof() let (mut stream, _) = listener.accept().await.unwrap(); let mut got = Vec::new(); stream.read_to_end(&mut got).await.unwrap(); - assert_eq!(got, expected, "empty initial_data path must not inject shape padding"); + assert_eq!( + got, expected, + "empty initial_data path must not inject shape padding" + ); } }); @@ -51,8 +54,14 @@ async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof() client_writer.write_all(&client_payload).await.unwrap(); client_writer.shutdown().await.unwrap(); - timeout(Duration::from_secs(2), relay_task).await.unwrap().unwrap(); - timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), relay_task) + .await + .unwrap() + .unwrap(); + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); } #[tokio::test] @@ -105,7 +114,10 @@ async fn shape_guard_timeout_exit_does_not_append_padding_after_initial_probe() ) .await; - timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); } #[tokio::test] @@ -126,7 +138,11 @@ async fn shape_guard_clean_eof_with_nonempty_initial_still_applies_bucket_paddin let expected_prefix_len = initial.len() + extra.len(); assert_eq!(&got[..initial.len()], initial.as_slice()); assert_eq!(&got[initial.len()..expected_prefix_len], extra.as_slice()); - assert_eq!(got.len(), 512, "clean EOF path should still shape to floor bucket"); + assert_eq!( + got.len(), + 512, + "clean EOF path should still shape to floor bucket" + ); } }); @@ -162,6 +178,12 @@ async fn shape_guard_clean_eof_with_nonempty_initial_still_applies_bucket_paddin client_writer.write_all(&extra).await.unwrap(); client_writer.shutdown().await.unwrap(); - timeout(Duration::from_secs(2), relay_task).await.unwrap().unwrap(); - timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), relay_task) + .await + .unwrap() + .unwrap(); + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); } diff --git a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs index eade371..8174a3d 100644 --- a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs @@ -1,5 +1,5 @@ use super::*; -use tokio::io::{duplex, empty, sink, AsyncReadExt, AsyncWrite}; +use tokio::io::{AsyncReadExt, AsyncWrite, duplex, empty, sink}; struct CountingWriter { written: usize, @@ -46,7 +46,10 @@ fn shape_bucket_clamps_to_cap_when_next_power_of_two_exceeds_cap() { fn shape_bucket_never_drops_below_total_for_valid_ranges() { for total in [1usize, 32, 127, 512, 999, 1000, 1001, 1499, 1500, 1501] { let bucket = next_mask_shape_bucket(total, 1000, 1500); - assert!(bucket >= total || total >= 1500, "bucket={bucket} total={total}"); + assert!( + bucket >= total || total >= 1500, + "bucket={bucket} total={total}" + ); } } diff --git a/src/proxy/tests/masking_timing_normalization_security_tests.rs b/src/proxy/tests/masking_timing_normalization_security_tests.rs index a5959b4..327ba6a 100644 --- a/src/proxy/tests/masking_timing_normalization_security_tests.rs +++ b/src/proxy/tests/masking_timing_normalization_security_tests.rs @@ -115,6 +115,12 @@ async fn timing_normalization_does_not_sleep_if_path_already_exceeds_ceiling() { let slow = measure_bad_client_duration_ms(MaskPath::SlowBackend, floor, ceiling).await; - assert!(slow >= 280, "slow backend path should remain slow (got {slow}ms)"); - assert!(slow <= 520, "slow backend path should remain bounded in tests (got {slow}ms)"); + assert!( + slow >= 280, + "slow backend path should remain slow (got {slow}ms)" + ); + assert!( + slow <= 520, + "slow backend path should remain bounded in tests (got {slow}ms)" + ); } diff --git a/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs b/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs index 574a3f9..dab0dff 100644 --- a/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs +++ b/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs @@ -47,7 +47,11 @@ fn desync_all_full_bypass_keeps_existing_dedup_entries_unchanged() { ); } - assert_eq!(dedup.len(), 2, "bypass path must not mutate dedup cardinality"); + assert_eq!( + dedup.len(), + 2, + "bypass path must not mutate dedup cardinality" + ); assert_eq!( *dedup .get(&0xAAAABBBBCCCCDDDD) @@ -73,7 +77,11 @@ fn edge_all_full_burst_does_not_poison_later_false_path_tracking() { let now = Instant::now(); for i in 0..8192u64 { - assert!(should_emit_full_desync(0xABCD_0000_0000_0000 ^ i, true, now)); + assert!(should_emit_full_desync( + 0xABCD_0000_0000_0000 ^ i, + true, + now + )); } let tracked_key = 0xDEAD_BEEF_0000_0001u64; @@ -175,5 +183,9 @@ fn stress_parallel_all_full_storm_does_not_grow_or_mutate_cache() { } assert_eq!(emits.load(Ordering::Relaxed), 16 * 4096); - assert_eq!(dedup.len(), before_len, "parallel all_full storm must not mutate cache len"); + assert_eq!( + dedup.len(), + before_len, + "parallel all_full storm must not mutate cache len" + ); } diff --git a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs index 0efc904..3e0b30f 100644 --- a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs +++ b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs @@ -2,8 +2,8 @@ use super::*; use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; -use std::sync::{Arc, Mutex, OnceLock}; use std::sync::atomic::AtomicU64; +use std::sync::{Arc, Mutex, OnceLock}; use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; @@ -93,7 +93,9 @@ async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() { .await .expect("idle test must complete"); - assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut) + ); let err_text = match result { Err(ProxyError::Io(ref e)) => e.to_string(), _ => String::new(), @@ -143,7 +145,9 @@ async fn idle_policy_downstream_activity_grace_extends_hard_deadline() { .await .expect("grace test must complete"); - assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut) + ); assert!( start.elapsed() >= TokioDuration::from_millis(100), "recent downstream activity must extend hard idle deadline" @@ -171,7 +175,9 @@ async fn relay_idle_policy_disabled_keeps_legacy_timeout_behavior() { ) .await; - assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut) + ); let err_text = match result { Err(ProxyError::Io(ref e)) => e.to_string(), _ => String::new(), @@ -225,8 +231,13 @@ async fn adversarial_partial_frame_trickle_cannot_bypass_hard_idle_close() { .await .expect("partial frame trickle test must complete"); - assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); - assert_eq!(frame_counter, 0, "partial trickle must not count as a valid frame"); + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut) + ); + assert_eq!( + frame_counter, 0, + "partial trickle must not count as a valid frame" + ); } #[tokio::test] @@ -291,7 +302,10 @@ async fn protocol_desync_small_frame_updates_reason_counter() { plaintext.extend_from_slice(&3u32.to_le_bytes()); plaintext.extend_from_slice(&[1u8, 2, 3]); let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.expect("must write frame"); + writer + .write_all(&encrypted) + .await + .expect("must write frame"); let result = read_client_payload( &mut crypto_reader, @@ -657,7 +671,8 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() { +async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() +{ let _guard = acquire_idle_pressure_test_lock(); clear_relay_idle_pressure_state_for_testing(); @@ -680,7 +695,8 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde let conn_id = *conn_id; let stats = stats.clone(); joins.push(tokio::spawn(async move { - let evicted = maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); + let evicted = + maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); (idx, conn_id, seen, evicted) })); } @@ -753,7 +769,8 @@ async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalida let conn_id = *conn_id; let stats = stats.clone(); joins.push(tokio::spawn(async move { - let evicted = maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); + let evicted = + maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); (idx, conn_id, seen, evicted) })); } diff --git a/src/proxy/tests/middle_relay_security_tests.rs b/src/proxy/tests/middle_relay_security_tests.rs index 874e5ea..5bb6d45 100644 --- a/src/proxy/tests/middle_relay_security_tests.rs +++ b/src/proxy/tests/middle_relay_security_tests.rs @@ -1,27 +1,27 @@ use super::*; -use crate::proxy::handshake::HandshakeSuccess; -use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; -use bytes::Bytes; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; use crate::crypto::AesCtr; use crate::crypto::SecureRandom; -use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; use crate::network::probe::NetworkDecision; +use crate::proxy::handshake::HandshakeSuccess; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::MePool; +use bytes::Bytes; use rand::rngs::StdRng; use rand::{RngExt, SeedableRng}; use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::sync::{Mutex, OnceLock}; use std::thread; -use tokio::sync::Barrier; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::io::duplex; +use tokio::sync::Barrier; use tokio::time::{Duration as TokioDuration, timeout}; -use std::sync::{Mutex, OnceLock}; fn make_pooled_payload(data: &[u8]) -> PooledBuffer { let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); @@ -46,8 +46,14 @@ fn quota_user_lock_test_lock() -> &'static Mutex<()> { #[test] fn should_yield_sender_only_on_budget_with_backlog() { assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); + assert!(!should_yield_c2me_sender( + C2ME_SENDER_FAIRNESS_BUDGET - 1, + true + )); + assert!(!should_yield_c2me_sender( + C2ME_SENDER_FAIRNESS_BUDGET, + false + )); assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); } @@ -125,14 +131,7 @@ async fn enqueue_c2me_command_closed_channel_recycles_payload() { let (tx, rx) = mpsc::channel::(1); drop(rx); - let result = enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload, - flags: 0, - }, - ) - .await; + let result = enqueue_c2me_command(&tx, C2MeCommand::Data { payload, flags: 0 }).await; assert!(result.is_err(), "closed queue must fail enqueue"); drop(result); @@ -314,9 +313,7 @@ fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() { return; } - panic!( - "unable to observe stable saturated lock-cache precondition after bounded retries" - ); + panic!("unable to observe stable saturated lock-cache precondition after bounded retries"); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -390,14 +387,7 @@ async fn stress_quota_race_under_lock_cache_saturation_never_allows_double_succe 12_000 + round, barrier.clone(), ); - let two = run_quota_race_attempt( - &stats, - &bytes_me2c, - &user, - 0x72, - 13_000 + round, - barrier, - ); + let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x72, 13_000 + round, barrier); let (r1, r2) = tokio::join!(one, two); assert!( @@ -823,7 +813,9 @@ fn full_cache_gate_lock_poison_is_fail_closed_without_panic() { // Poison the full-cache gate lock intentionally. let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None)); let _ = std::panic::catch_unwind(|| { - let _lock = gate.lock().expect("gate lock must be lockable before poison"); + let _lock = gate + .lock() + .expect("gate lock must be lockable before poison"); panic!("intentional gate poison for fail-closed regression"); }); @@ -1208,7 +1200,11 @@ async fn read_client_payload_large_intermediate_frame_is_exact() { let (frame, quickack) = read; assert!(!quickack, "quickack flag must be unset"); - assert_eq!(frame.len(), payload_len, "payload size must match wire length"); + assert_eq!( + frame.len(), + payload_len, + "payload size must match wire length" + ); for (idx, byte) in frame.iter().enumerate() { assert_eq!(*byte, (idx as u8).wrapping_mul(31)); } @@ -1376,7 +1372,10 @@ async fn read_client_payload_abridged_extended_len_sets_quickack() { .expect("frame must be present"); let (frame, quickack) = read; - assert!(quickack, "quickack bit must be propagated from abridged header"); + assert!( + quickack, + "quickack bit must be propagated from abridged header" + ); assert_eq!(frame.len(), payload_len); assert_eq!(frame_counter, 1, "one abridged frame must be counted"); } @@ -1436,7 +1435,11 @@ async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() { let pool = Arc::new(BufferPool::with_config(64, 2)); pool.preallocate(1); - assert_eq!(pool.stats().pooled, 1, "one pooled buffer must be available"); + assert_eq!( + pool.stats().pooled, + 1, + "one pooled buffer must be available" + ); let (reader, mut writer) = duplex(1024); let mut crypto_reader = make_crypto_reader(reader); @@ -1491,7 +1494,8 @@ async fn enqueue_c2me_close_unblocks_after_queue_drain() { .unwrap(); let tx2 = tx.clone(); - let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); + let close_task = + tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); tokio::time::sleep(TokioDuration::from_millis(10)).await; @@ -1501,7 +1505,10 @@ async fn enqueue_c2me_close_unblocks_after_queue_drain() { .expect("first queued item must be present"); assert!(matches!(first, C2MeCommand::Data { .. })); - close_task.await.unwrap().expect("close enqueue must succeed after drain"); + close_task + .await + .unwrap() + .expect("close enqueue must succeed after drain"); let second = timeout(TokioDuration::from_millis(100), rx.recv()) .await @@ -1521,7 +1528,8 @@ async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() { .unwrap(); let tx2 = tx.clone(); - let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); + let close_task = + tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); tokio::time::sleep(TokioDuration::from_millis(10)).await; drop(rx); @@ -1756,7 +1764,8 @@ async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoo } #[tokio::test] -async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() { +async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() + { let (writer_side, mut reader_side) = duplex(1024); let mut writer = make_crypto_writer(writer_side); let rng = SecureRandom::new(); @@ -1851,11 +1860,17 @@ async fn middle_relay_abort_midflight_releases_route_gauge() { } }) .await; - assert!(started.is_ok(), "middle relay must increment route gauge before abort"); + assert!( + started.is_ok(), + "middle relay must increment route gauge before abort" + ); relay_task.abort(); let joined = relay_task.await; - assert!(joined.is_err(), "aborted middle relay task must return join error"); + assert!( + joined.is_err(), + "aborted middle relay task must return join error" + ); tokio::time::sleep(TokioDuration::from_millis(20)).await; assert_eq!( @@ -2014,8 +2029,14 @@ async fn abridged_max_extended_length_fails_closed_without_panic_or_partial_read ) .await; - assert!(result.is_err(), "oversized abridged length must fail closed"); - assert_eq!(frame_counter, 0, "oversized frame must not be counted as accepted"); + assert!( + result.is_err(), + "oversized abridged length must fail closed" + ); + assert_eq!( + frame_counter, 0, + "oversized frame must not be counted as accepted" + ); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -2067,14 +2088,7 @@ async fn stress_quota_race_bursts_never_allow_double_success_per_round() { 6000 + round, barrier.clone(), ); - let two = run_quota_race_attempt( - &stats, - &bytes_me2c, - &user, - 0x44, - 7000 + round, - barrier, - ); + let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x44, 7000 + round, barrier); let (r1, r2) = tokio::join!(one, two); assert!( @@ -2274,18 +2288,18 @@ async fn secure_padding_distribution_in_relay_writer() { async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() { let (client_reader_side, client_writer_side) = duplex(1024); let (_relay_reader_side, relay_writer_side) = duplex(1024); - + let key = [0u8; 32]; let iv = 0u128; let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - + let stats = Arc::new(Stats::new()); let config = Arc::new(ProxyConfig::default()); let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); let rng = Arc::new(SecureRandom::new()); let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); - + // Create an ME pool. let me_pool = make_me_pool_for_abort_test(stats.clone()).await; @@ -2296,7 +2310,7 @@ async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() drop(probe_rx); me_pool.registry().unregister(probe_conn_id).await; let target_conn_id = probe_conn_id.wrapping_add(1); - + let success = HandshakeSuccess { user: "test-user".to_string(), peer: "127.0.0.1:12345".parse().unwrap(), diff --git a/src/proxy/tests/relay_adversarial_tests.rs b/src/proxy/tests/relay_adversarial_tests.rs index f87d82b..14754cd 100644 --- a/src/proxy/tests/relay_adversarial_tests.rs +++ b/src/proxy/tests/relay_adversarial_tests.rs @@ -3,7 +3,7 @@ use crate::error::ProxyError; use crate::stats::Stats; use crate::stream::BufferPool; use std::sync::Arc; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::time::{Duration, Instant, timeout}; // ------------------------------------------------------------------ @@ -14,7 +14,7 @@ use tokio::time::{Duration, Instant, timeout}; async fn relay_hol_blocking_prevention_regression() { let stats = Arc::new(Stats::new()); let user = "hol-user"; - + let (client_peer, relay_client) = duplex(65536); let (relay_server, server_peer) = duplex(65536); @@ -42,7 +42,7 @@ async fn relay_hol_blocking_prevention_regression() { let s2c_handle = tokio::spawn(async move { sp_writer.write_all(&s2c_payload).await.unwrap(); - + let mut total_read = 0; let mut buf = [0u8; 10]; while total_read < payload_size { @@ -54,12 +54,16 @@ async fn relay_hol_blocking_prevention_regression() { let start = Instant::now(); cp_writer.write_all(&c2s_payload).await.unwrap(); - + let mut server_buf = vec![0u8; payload_size]; sp_reader.read_exact(&mut server_buf).await.unwrap(); let elapsed = start.elapsed(); - assert!(elapsed < Duration::from_millis(1000), "C->S must not be blocked by slow S->C (HOL blocking): {:?}", elapsed); + assert!( + elapsed < Duration::from_millis(1000), + "C->S must not be blocked by slow S->C (HOL blocking): {:?}", + elapsed + ); assert_eq!(server_buf, c2s_payload); s2c_handle.abort(); @@ -75,7 +79,7 @@ async fn relay_quota_mid_session_cutoff() { let stats = Arc::new(Stats::new()); let user = "quota-mid-user"; let quota = 5000; - + let (client_peer, relay_client) = duplex(8192); let (relay_server, server_peer) = duplex(8192); @@ -106,9 +110,9 @@ async fn relay_quota_mid_session_cutoff() { // Send another 2000 bytes (Total 6000 > 5000) let buf2 = vec![0x42; 2000]; let _ = cp_writer.write_all(&buf2).await; - + let relay_res = timeout(Duration::from_secs(1), relay_task).await.unwrap(); - + match relay_res { Ok(Err(ProxyError::DataQuotaExceeded { .. })) => { // Expected @@ -155,7 +159,10 @@ async fn relay_chaos_half_close_crossfire_terminates_without_hang() { .await .expect("relay must terminate after bilateral half-close") .expect("relay task must not panic"); - assert!(done.is_ok(), "relay must terminate cleanly under half-close crossfire"); + assert!( + done.is_ok(), + "relay must terminate cleanly under half-close crossfire" + ); } #[tokio::test] diff --git a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs index 7a2f8b7..080240a 100644 --- a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs +++ b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs @@ -5,8 +5,8 @@ use crate::stream::BufferPool; use rand::rngs::StdRng; use rand::{RngExt, SeedableRng}; use std::sync::Arc; -use tokio::io::{duplex, AsyncRead, AsyncReadExt, AsyncWriteExt}; -use tokio::time::{timeout, Duration, Instant}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; async fn read_available(reader: &mut R, budget: Duration) -> usize { let start = Instant::now(); @@ -52,7 +52,10 @@ async fn integration_full_duplex_exact_budget_then_hard_cutoff() { Arc::new(BufferPool::new()), )); - client_peer.write_all(&[0x10, 0x11, 0x12, 0x13]).await.unwrap(); + client_peer + .write_all(&[0x10, 0x11, 0x12, 0x13]) + .await + .unwrap(); let mut c2s = [0u8; 4]; server_peer.read_exact(&mut c2s).await.unwrap(); assert_eq!(c2s, [0x10, 0x11, 0x12, 0x13]); @@ -70,8 +73,16 @@ async fn integration_full_duplex_exact_budget_then_hard_cutoff() { let mut probe_server = [0u8; 1]; let mut probe_client = [0u8; 1]; - let leaked_to_server = timeout(Duration::from_millis(120), server_peer.read(&mut probe_server)).await; - let leaked_to_client = timeout(Duration::from_millis(120), client_peer.read(&mut probe_client)).await; + let leaked_to_server = timeout( + Duration::from_millis(120), + server_peer.read(&mut probe_server), + ) + .await; + let leaked_to_client = timeout( + Duration::from_millis(120), + client_peer.read(&mut probe_client), + ) + .await; assert!( !matches!(leaked_to_server, Ok(Ok(n)) if n > 0), @@ -126,14 +137,23 @@ async fn negative_preloaded_quota_blocks_both_directions_immediately() { let leaked_to_server = read_available(&mut server_peer, Duration::from_millis(120)).await; let leaked_to_client = read_available(&mut client_peer, Duration::from_millis(120)).await; - assert_eq!(leaked_to_server, 0, "preloaded limit must block C->S immediately"); - assert_eq!(leaked_to_client, 0, "preloaded limit must block S->C immediately"); + assert_eq!( + leaked_to_server, 0, + "preloaded limit must block C->S immediately" + ); + assert_eq!( + leaked_to_client, 0, + "preloaded limit must block S->C immediately" + ); let relay_result = timeout(Duration::from_secs(2), relay) .await .expect("relay must terminate under preloaded cutoff") .expect("relay task must not panic"); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); assert!(stats.get_user_total_octets(user) <= 5); } @@ -160,19 +180,24 @@ async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet() Arc::new(BufferPool::new()), )); - let _ = tokio::join!(client_peer.write_all(&[0xAA]), server_peer.write_all(&[0xBB])); + let _ = tokio::join!( + client_peer.write_all(&[0xAA]), + server_peer.write_all(&[0xBB]) + ); let mut to_server = [0u8; 1]; let mut to_client = [0u8; 1]; - let delivered_server = match timeout(Duration::from_millis(120), server_peer.read(&mut to_server)).await { - Ok(Ok(n)) => n, - _ => 0, - }; - let delivered_client = match timeout(Duration::from_millis(120), client_peer.read(&mut to_client)).await { - Ok(Ok(n)) => n, - _ => 0, - }; + let delivered_server = + match timeout(Duration::from_millis(120), server_peer.read(&mut to_server)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + let delivered_client = + match timeout(Duration::from_millis(120), client_peer.read(&mut to_client)).await { + Ok(Ok(n)) => n, + _ => 0, + }; assert!( delivered_server + delivered_client <= 1, @@ -183,7 +208,10 @@ async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet() .await .expect("relay must terminate under quota=1") .expect("relay task must not panic"); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); assert!(stats.get_user_total_octets(user) <= 1); } @@ -241,7 +269,10 @@ async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_glo .expect("relay must terminate under black-hat jitter attack") .expect("relay task must not panic"); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); assert!( delivered_to_server + delivered_to_client <= quota as usize, "combined forwarded bytes must never exceed configured quota" @@ -291,13 +322,17 @@ async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invar if rng.random::() { let _ = client_peer.write_all(&[rng.random::()]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(3), server_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(3), server_peer.read(&mut one)).await + { delivered_total = delivered_total.saturating_add(n); } } else { let _ = server_peer.write_all(&[rng.random::()]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(3), client_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(3), client_peer.read(&mut one)).await + { delivered_total = delivered_total.saturating_add(n); } } @@ -312,7 +347,8 @@ async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invar .expect("fuzz relay task must not panic"); assert!( - relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), "relay must either close cleanly or terminate via typed quota error" ); assert!( @@ -371,18 +407,25 @@ async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quo if ((step as usize + worker_id as usize) & 1) == 0 { let _ = client_peer.write_all(&[step ^ 0x3C]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(3), server_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(3), server_peer.read(&mut one)).await + { delivered = delivered.saturating_add(n); } } else { let _ = server_peer.write_all(&[step ^ 0xC3]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(3), client_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(3), client_peer.read(&mut one)).await + { delivered = delivered.saturating_add(n); } } - tokio::time::sleep(Duration::from_millis((((worker_id as u64) + (step as u64)) % 3) + 1)).await; + tokio::time::sleep(Duration::from_millis( + (((worker_id as u64) + (step as u64)) % 3) + 1, + )) + .await; } drop(client_peer); @@ -393,7 +436,8 @@ async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quo .expect("stress relay task must not panic"); assert!( - relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), "stress relay must either close cleanly or terminate via typed quota error" ); delivered @@ -402,7 +446,8 @@ async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quo let mut delivered_sum = 0usize; for worker in workers { - delivered_sum = delivered_sum.saturating_add(worker.await.expect("stress worker must not panic")); + delivered_sum = + delivered_sum.saturating_add(worker.await.expect("stress worker must not panic")); } assert!( diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs index 4add5f0..e29e86e 100644 --- a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs @@ -6,7 +6,7 @@ use dashmap::DashMap; use std::sync::Arc; use std::sync::atomic::AtomicBool; use std::time::Duration; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::sync::Barrier; use tokio::time::Instant; @@ -62,7 +62,10 @@ fn quota_lock_unique_users_materialize_distinct_entries() { } for user in &users { - assert!(map.get(user).is_some(), "lock cache must contain entry for {user}"); + assert!( + map.get(user).is_some(), + "lock cache must contain entry for {user}" + ); } } @@ -160,7 +163,10 @@ fn quota_lock_saturated_same_user_must_not_return_distinct_locks() { let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-saturated-held-{}-{idx}", std::process::id()))); + retained.push(quota_user_lock(&format!( + "quota-saturated-held-{}-{idx}", + std::process::id() + ))); } let overflow_user = format!("quota-saturated-same-user-{}", std::process::id()); @@ -183,7 +189,10 @@ async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() { let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-saturated-race-held-{}-{idx}", std::process::id()))); + retained.push(quota_user_lock(&format!( + "quota-saturated-race-held-{}-{idx}", + std::process::id() + ))); } let stats = Arc::new(Stats::new()); @@ -234,7 +243,10 @@ async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() { let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-saturated-stress-held-{}-{idx}", std::process::id()))); + retained.push(quota_user_lock(&format!( + "quota-saturated-stress-held-{}-{idx}", + std::process::id() + ))); } for round in 0..128u32 { @@ -355,7 +367,8 @@ async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() { .expect("client write must succeed"); let mut probe = [0u8; 1]; - let forwarded = tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await; + let forwarded = + tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await; if let Ok(Ok(n)) = forwarded { assert_eq!(n, 0, "zero quota path must not forward payload bytes"); } @@ -392,14 +405,26 @@ async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() { let c2s = vec![0xA5; 2048]; let s2c = vec![0x5A; 1536]; - client_peer.write_all(&c2s).await.expect("client burst write must succeed"); + client_peer + .write_all(&c2s) + .await + .expect("client burst write must succeed"); let mut got_c2s = vec![0u8; c2s.len()]; - server_peer.read_exact(&mut got_c2s).await.expect("server must receive c2s burst"); + server_peer + .read_exact(&mut got_c2s) + .await + .expect("server must receive c2s burst"); assert_eq!(got_c2s, c2s); - server_peer.write_all(&s2c).await.expect("server burst write must succeed"); + server_peer + .write_all(&s2c) + .await + .expect("server burst write must succeed"); let mut got_s2c = vec![0u8; s2c.len()]; - client_peer.read_exact(&mut got_s2c).await.expect("client must receive s2c burst"); + client_peer + .read_exact(&mut got_s2c) + .await + .expect("client must receive s2c burst"); assert_eq!(got_s2c, s2c); drop(client_peer); diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs index e9e6a61..5714f48 100644 --- a/src/proxy/tests/relay_quota_model_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -5,9 +5,9 @@ use crate::stream::BufferPool; use rand::rngs::StdRng; use rand::{RngExt, SeedableRng}; use std::sync::Arc; -use tokio::io::{duplex, AsyncRead, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex}; use tokio::sync::Barrier; -use tokio::time::{timeout, Duration}; +use tokio::time::{Duration, timeout}; fn assert_is_prefix(received: &[u8], sent: &[u8], direction: &str) { assert!( @@ -110,7 +110,8 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() .expect("fuzz relay task must not panic"); assert!( - relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), "fuzz case {case}: relay must end cleanly or with typed quota error" ); @@ -172,11 +173,21 @@ async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byt let mut got_at_server = [0u8; 1]; let mut got_at_client = [0u8; 1]; - let n_server = match timeout(Duration::from_millis(120), server_peer.read(&mut got_at_server)).await { + let n_server = match timeout( + Duration::from_millis(120), + server_peer.read(&mut got_at_server), + ) + .await + { Ok(Ok(n)) => n, _ => 0, }; - let n_client = match timeout(Duration::from_millis(120), client_peer.read(&mut got_at_client)).await { + let n_client = match timeout( + Duration::from_millis(120), + client_peer.read(&mut got_at_client), + ) + .await + { Ok(Ok(n)) => n, _ => 0, }; @@ -194,7 +205,10 @@ async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byt .expect("quota race relay must terminate") .expect("quota race relay task must not panic"); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); assert!(stats.get_user_total_octets(user) <= 1); } @@ -276,7 +290,8 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode .expect("stress relay task must not panic"); assert!( - relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), "stress relay must end cleanly or with typed quota error" ); diff --git a/src/proxy/tests/relay_quota_overflow_regression_tests.rs b/src/proxy/tests/relay_quota_overflow_regression_tests.rs index 207d603..dfbab85 100644 --- a/src/proxy/tests/relay_quota_overflow_regression_tests.rs +++ b/src/proxy/tests/relay_quota_overflow_regression_tests.rs @@ -3,8 +3,8 @@ use crate::error::ProxyError; use crate::stats::Stats; use crate::stream::BufferPool; use std::sync::Arc; -use tokio::io::{duplex, AsyncRead, AsyncReadExt, AsyncWriteExt}; -use tokio::time::{timeout, Duration}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, timeout}; async fn read_available(reader: &mut R, budget_ms: u64) -> usize { let mut total = 0usize; @@ -46,7 +46,10 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ )); // Single chunk attempts to cross remaining budget (4 > 1). - client_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap(); + client_peer + .write_all(&[0x11, 0x22, 0x33, 0x44]) + .await + .unwrap(); client_peer.shutdown().await.unwrap(); let forwarded = read_available(&mut server_peer, 60).await; @@ -60,7 +63,10 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ forwarded, 0, "overflowing C->S chunk must not be forwarded when it exceeds remaining quota" ); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); assert!( stats.get_user_total_octets(user) <= 10, "accounted bytes must never exceed quota after overflowing chunk" @@ -94,7 +100,10 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of )); // Exact boundary write should pass once. - client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap(); + client_peer + .write_all(&[0xAA, 0xBB, 0xCC, 0xDD]) + .await + .unwrap(); let mut exact = [0u8; 4]; timeout(Duration::from_secs(1), server_peer.read_exact(&mut exact)) @@ -118,7 +127,10 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of leaked_after, 0, "no bytes may pass after exact boundary is consumed" ); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); assert!(stats.get_user_total_octets(user) <= 10); } @@ -171,7 +183,8 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { .expect("stress relay task must not panic"); assert!( - relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), "stress relay must finish cleanly or with typed quota error" ); forwarded diff --git a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs index 1cd5920..9f68258 100644 --- a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs +++ b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs @@ -186,7 +186,10 @@ async fn integration_parallel_waiters_resume_after_single_release_event() { timeout(Duration::from_secs(1), async { for waiter in waiters { let outcome = waiter.await.expect("waiter must not panic"); - assert!(outcome.is_ok(), "waiter must resume and complete after release"); + assert!( + outcome.is_ok(), + "waiter must resume and complete after release" + ); } }) .await @@ -235,7 +238,10 @@ async fn light_fuzz_release_timing_matrix_preserves_liveness() { .await .expect("fuzz round writer must complete") .expect("fuzz writer task must not panic"); - assert!(done.is_ok(), "fuzz round writer must not stall after release"); + assert!( + done.is_ok(), + "fuzz round writer must not stall after release" + ); } } diff --git a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs index 2dabaa3..fa4878a 100644 --- a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs @@ -5,7 +5,7 @@ use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::task::{Context, Waker}; -use tokio::io::{ReadBuf, AsyncWriteExt}; +use tokio::io::{AsyncWriteExt, ReadBuf}; use tokio::time::{Duration, timeout}; #[derive(Default)] @@ -83,7 +83,10 @@ async fn positive_contended_writer_emits_deferred_wake_for_liveness() { drop(held_guard); let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); - assert!(ready.is_ready(), "writer must progress after contention release"); + assert!( + ready.is_ready(), + "writer must progress after contention release" + ); } #[tokio::test] @@ -117,7 +120,10 @@ async fn adversarial_blackhat_writer_contention_does_not_create_waker_storm() { for _ in 0..512 { let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xBE]); - assert!(poll.is_pending(), "writer must stay pending while lock is held"); + assert!( + poll.is_pending(), + "writer must stay pending while lock is held" + ); tokio::task::yield_now().await; } diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs index b9b3478..50cdfa3 100644 --- a/src/proxy/tests/relay_security_tests.rs +++ b/src/proxy/tests/relay_security_tests.rs @@ -6,10 +6,10 @@ use std::future::poll_fn; use std::io; use std::pin::Pin; use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Mutex; -use std::task::{Context, Poll}; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::task::Waker; +use std::task::{Context, Poll}; use tokio::io::{AsyncRead, ReadBuf}; use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; use tokio::time::{Duration, timeout}; @@ -60,7 +60,10 @@ async fn quota_lock_contention_does_not_self_wake_pending_writer() { let mut cx = Context::from_waker(&waker); let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!(poll.is_pending(), "writer must remain pending while lock is contended"); + assert!( + poll.is_pending(), + "writer must remain pending while lock is contended" + ); assert_eq!( wake_counter.wakes.load(Ordering::Relaxed), 0, @@ -99,7 +102,10 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_ let mut cx = Context::from_waker(&waker); let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!(first.is_pending(), "writer must remain pending while lock is contended"); + assert!( + first.is_pending(), + "writer must remain pending while lock is contended" + ); assert_eq!( wake_counter.wakes.load(Ordering::Relaxed), 0, @@ -123,7 +129,10 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_ ); let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); - assert!(second.is_pending(), "writer remains pending while lock is still held"); + assert!( + second.is_pending(), + "writer remains pending while lock is still held" + ); for _ in 0..8 { tokio::task::yield_now().await; @@ -136,7 +145,10 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_ drop(held_lock); let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); - assert!(released.is_ready(), "writer must make progress once quota lock is released"); + assert!( + released.is_ready(), + "writer must make progress once quota lock is released" + ); } #[tokio::test] @@ -172,7 +184,10 @@ async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() let mut buf = ReadBuf::new(&mut storage); let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(first.is_pending(), "reader must remain pending while lock is contended"); + assert!( + first.is_pending(), + "reader must remain pending while lock is contended" + ); assert_eq!( wake_counter.wakes.load(Ordering::Relaxed), 0, @@ -193,7 +208,10 @@ async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() drop(held_lock); let mut buf_after_release = ReadBuf::new(&mut storage); let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release); - assert!(released.is_ready(), "reader must make progress once quota lock is released"); + assert!( + released.is_ready(), + "reader must make progress once quota lock is released" + ); } #[tokio::test] @@ -297,7 +315,8 @@ async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhaus } #[tokio::test] -async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() { +async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() + { let stats = Arc::new(Stats::new()); let quota_user = "partial-leak-user"; stats.add_user_octets_from(quota_user, 3); @@ -569,7 +588,7 @@ async fn relay_bidirectional_terminates_on_activity_timeout() { // Wait past the activity timeout threshold (1800 seconds) + buffer tokio::time::sleep(Duration::from_secs(1805)).await; - + // Resume time to process timeouts tokio::time::resume(); @@ -582,7 +601,7 @@ async fn relay_bidirectional_terminates_on_activity_timeout() { relay_result.is_ok(), "relay should complete successfully on scheduled inactivity timeout" ); - + // Verify client/server sockets are closed drop(client_peer); drop(server_peer); @@ -634,12 +653,13 @@ async fn relay_bidirectional_watchdog_resists_premature_execution() { relay_result.is_err(), "Relay must not exit prematurely as long as activity was received before timeout" ); - + // Explicitly drop sockets to cleanly shut down relay loop drop(client_peer); drop(server_peer); - - let completion = timeout(Duration::from_secs(1), relay_task).await + + let completion = timeout(Duration::from_secs(1), relay_task) + .await .expect("relay task must complete securely after client disconnection") .expect("relay task must not panic"); assert!(completion.is_ok(), "relay exits clean"); @@ -654,16 +674,29 @@ async fn relay_bidirectional_half_closure_terminates_cleanly() { let (server_reader, server_writer) = tokio::io::split(relay_server); let relay_task = tokio::spawn(relay_bidirectional( - client_reader, client_writer, server_reader, server_writer, 1024, 1024, "half-close", stats, None, Arc::new(BufferPool::new()), + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "half-close", + stats, + None, + Arc::new(BufferPool::new()), )); - + // Half closure: drop the client completely but leave the server active. drop(client_peer); - + // Check that we don't immediately crash. Bidirectional relay stays open for the server -> client flush. // Eventually dropping the server cleanly closes the task. drop(server_peer); - timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap(); + timeout(Duration::from_secs(1), relay_task) + .await + .unwrap() + .unwrap() + .unwrap(); } #[tokio::test] @@ -675,7 +708,16 @@ async fn relay_bidirectional_zero_length_noise_fuzzing() { let (server_reader, server_writer) = tokio::io::split(relay_server); let relay_task = tokio::spawn(relay_bidirectional( - client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz", stats, None, Arc::new(BufferPool::new()), + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "fuzz", + stats, + None, + Arc::new(BufferPool::new()), )); // Flood with zero-length payloads (edge cases in stream framing logic sometimes loop) @@ -684,45 +726,62 @@ async fn relay_bidirectional_zero_length_noise_fuzzing() { } client_peer.write_all(&[1, 2, 3]).await.unwrap(); client_peer.flush().await.unwrap(); - + let mut buf = [0u8; 3]; server_peer.read_exact(&mut buf).await.unwrap(); assert_eq!(&buf, &[1, 2, 3]); - + drop(client_peer); drop(server_peer); - timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap(); + timeout(Duration::from_secs(1), relay_task) + .await + .unwrap() + .unwrap() + .unwrap(); } #[tokio::test] async fn relay_bidirectional_asymmetric_backpressure() { let stats = Arc::new(Stats::new()); // Give the client stream an extremely narrow throughput limit explicitly - let (client_peer, relay_client) = duplex(1024); + let (client_peer, relay_client) = duplex(1024); let (relay_server, mut server_peer) = duplex(4096); let (client_reader, client_writer) = tokio::io::split(relay_client); let (server_reader, server_writer) = tokio::io::split(relay_server); let relay_task = tokio::spawn(relay_bidirectional( - client_reader, client_writer, server_reader, server_writer, 1024, 1024, "slowloris", stats, None, Arc::new(BufferPool::new()), + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "slowloris", + stats, + None, + Arc::new(BufferPool::new()), )); let payload = vec![0xba; 65536]; // 64k payload - + // Server attempts to shove 64KB into a relay whose client pipe only holds 1KB! - let write_res = tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await; - + let write_res = + tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await; + assert!( - write_res.is_err(), + write_res.is_err(), "Relay backpressure MUST halt the server writer from unbounded buffering when client stream is full!" ); - + drop(client_peer); drop(server_peer); - - let completion = timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap(); + + let completion = timeout(Duration::from_secs(1), relay_task) + .await + .unwrap() + .unwrap(); assert!( - completion.is_ok() || completion.is_err(), + completion.is_ok() || completion.is_err(), "Task must unwind reliably (either Ok or BrokenPipe Err) when dropped despite active backpressure locks" ); } @@ -739,27 +798,43 @@ async fn relay_bidirectional_light_fuzzing_temporal_jitter() { let (server_reader, server_writer) = tokio::io::split(relay_server); let mut relay_task = tokio::spawn(relay_bidirectional( - client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz-user", stats, None, Arc::new(BufferPool::new()), + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "fuzz-user", + stats, + None, + Arc::new(BufferPool::new()), )); let mut rng = StdRng::seed_from_u64(0xDEADBEEF); - + for _ in 0..10 { // Vary timing significantly up to 1600 seconds (limit is 1800s) - let jitter = rng.random_range(100..1600); + let jitter = rng.random_range(100..1600); tokio::time::sleep(Duration::from_secs(jitter)).await; - + client_peer.write_all(&[0x11]).await.unwrap(); client_peer.flush().await.unwrap(); - + // Ensure task has not died let res = timeout(Duration::from_millis(10), &mut relay_task).await; - assert!(res.is_err(), "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses"); + assert!( + res.is_err(), + "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses" + ); } - + drop(client_peer); drop(server_peer); - timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap(); + timeout(Duration::from_secs(1), relay_task) + .await + .unwrap() + .unwrap() + .unwrap(); } struct FaultyReader { @@ -1038,11 +1113,14 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { server_peer_b.write_all(&[0x04]), ); - let _ = timeout(Duration::from_millis(50), poll_fn(|cx| { - let mut one = [0u8; 1]; - let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one)); - Poll::Ready(()) - })) + let _ = timeout( + Duration::from_millis(50), + poll_fn(|cx| { + let mut one = [0u8; 1]; + let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one)); + Poll::Ready(()) + }), + ) .await; drop(client_peer_a); @@ -1063,7 +1141,10 @@ async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { impl FaultyReader { fn permission_denied_with_message(message: impl Into) -> Self { Self { - error_once: Some(io::Error::new(io::ErrorKind::PermissionDenied, message.into())), + error_once: Some(io::Error::new( + io::ErrorKind::PermissionDenied, + message.into(), + )), } } } @@ -1179,14 +1260,20 @@ async fn relay_half_close_keeps_reverse_direction_progressing() { Arc::new(BufferPool::new()), )); - sp_writer.write_all(&[0x10, 0x20, 0x30, 0x40]).await.unwrap(); + sp_writer + .write_all(&[0x10, 0x20, 0x30, 0x40]) + .await + .unwrap(); sp_writer.shutdown().await.unwrap(); let mut inbound = [0u8; 4]; cp_reader.read_exact(&mut inbound).await.unwrap(); assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]); - cp_writer.write_all(&[0xaa, 0xbb, 0xcc, 0xdd]).await.unwrap(); + cp_writer + .write_all(&[0xaa, 0xbb, 0xcc, 0xdd]) + .await + .unwrap(); let mut outbound = [0u8; 4]; sp_reader.read_exact(&mut outbound).await.unwrap(); assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]); diff --git a/src/proxy/tests/relay_watchdog_delta_security_tests.rs b/src/proxy/tests/relay_watchdog_delta_security_tests.rs index f05ee62..8b9b209 100644 --- a/src/proxy/tests/relay_watchdog_delta_security_tests.rs +++ b/src/proxy/tests/relay_watchdog_delta_security_tests.rs @@ -44,7 +44,10 @@ fn light_fuzz_mixed_pairs_match_saturating_sub_contract() { let expected = current.saturating_sub(previous); let actual = watchdog_delta(current, previous); - assert_eq!(actual, expected, "delta mismatch for ({current}, {previous})"); + assert_eq!( + actual, expected, + "delta mismatch for ({current}, {previous})" + ); } } diff --git a/src/proxy/tests/route_mode_coherence_adversarial_tests.rs b/src/proxy/tests/route_mode_coherence_adversarial_tests.rs index 4f255d4..b7f816e 100644 --- a/src/proxy/tests/route_mode_coherence_adversarial_tests.rs +++ b/src/proxy/tests/route_mode_coherence_adversarial_tests.rs @@ -18,7 +18,10 @@ fn positive_direct_cutover_sets_timestamp_and_snapshot_coherently() { .expect("middle->direct must emit cutover"); let observed = *rx.borrow(); - assert_eq!(observed, emitted, "watch snapshot must match emitted cutover"); + assert_eq!( + observed, emitted, + "watch snapshot must match emitted cutover" + ); assert_eq!(observed.mode, RelayRouteMode::Direct); assert!( runtime.direct_since_epoch_secs().is_some(), @@ -64,7 +67,10 @@ fn edge_middle_cutover_clears_timestamp() { .expect("direct->middle must emit cutover"); let observed = *rx.borrow(); - assert_eq!(observed, emitted, "watch snapshot must match emitted cutover"); + assert_eq!( + observed, emitted, + "watch snapshot must match emitted cutover" + ); assert_eq!(observed.mode, RelayRouteMode::Middle); assert!( runtime.direct_since_epoch_secs().is_none(), diff --git a/src/proxy/tests/route_mode_security_tests.rs b/src/proxy/tests/route_mode_security_tests.rs index 49cbb66..e5925fc 100644 --- a/src/proxy/tests/route_mode_security_tests.rs +++ b/src/proxy/tests/route_mode_security_tests.rs @@ -1,6 +1,6 @@ use super::*; -use rand::{RngExt, SeedableRng}; use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; @@ -19,14 +19,7 @@ fn cutover_stagger_delay_stays_within_budget_bounds() { // Black-hat model: censors trigger many cutovers and correlate disconnect timing. // Keep delay inside a narrow coarse window to avoid long-tail spikes. for generation in [0u64, 1, 2, 3, 16, 128, u32::MAX as u64, u64::MAX] { - for session_id in [ - 0u64, - 1, - 2, - 0xdead_beef, - 0xfeed_face_cafe_babe, - u64::MAX, - ] { + for session_id in [0u64, 1, 2, 0xdead_beef, 0xfeed_face_cafe_babe, u64::MAX] { let delay = cutover_stagger_delay(session_id, generation); assert!( (1000..=1999).contains(&delay.as_millis()), @@ -216,7 +209,10 @@ fn light_fuzz_set_mode_generation_tracks_only_real_transitions() { let changed = runtime.set_mode(candidate); if candidate == expected_mode { - assert!(changed.is_none(), "idempotent set_mode must not emit cutover state"); + assert!( + changed.is_none(), + "idempotent set_mode must not emit cutover state" + ); } else { expected_mode = candidate; expected_generation = expected_generation.saturating_add(1); @@ -298,7 +294,9 @@ fn stress_concurrent_transition_count_matches_final_generation() { } for worker in workers { - worker.join().expect("route mode transition worker must not panic"); + worker + .join() + .expect("route mode transition worker must not panic"); } }); @@ -391,8 +389,8 @@ fn stress_cutover_stagger_delay_distribution_remains_stable_across_generations() for generation in [0u64, 1, 7, 31, 255, 1024, u32::MAX as u64, u64::MAX - 1] { let mut buckets = [0usize; 1000]; for session_id in 0..100_000u64 { - let delay_ms = cutover_stagger_delay(session_id ^ 0x9E37_79B9, generation) - .as_millis() as usize; + let delay_ms = + cutover_stagger_delay(session_id ^ 0x9E37_79B9, generation).as_millis() as usize; buckets[delay_ms - 1000] += 1; } diff --git a/src/startup.rs b/src/startup.rs index f6f857c..36b1506 100644 --- a/src/startup.rs +++ b/src/startup.rs @@ -175,7 +175,11 @@ impl StartupTracker { pub async fn start_component(&self, id: &'static str, details: Option) { let mut guard = self.state.write().await; guard.current_stage = id.to_string(); - if let Some(component) = guard.components.iter_mut().find(|component| component.id == id) { + if let Some(component) = guard + .components + .iter_mut() + .find(|component| component.id == id) + { if component.started_at_epoch_ms.is_none() { component.started_at_epoch_ms = Some(now_epoch_ms()); } @@ -208,7 +212,11 @@ impl StartupTracker { ) { let mut guard = self.state.write().await; let finished_at = now_epoch_ms(); - if let Some(component) = guard.components.iter_mut().find(|component| component.id == id) { + if let Some(component) = guard + .components + .iter_mut() + .find(|component| component.id == id) + { if component.started_at_epoch_ms.is_none() { component.started_at_epoch_ms = Some(finished_at); component.attempts = component.attempts.saturating_add(1); diff --git a/src/stats/beobachten.rs b/src/stats/beobachten.rs index 2e87fcc..3d3a2da 100644 --- a/src/stats/beobachten.rs +++ b/src/stats/beobachten.rs @@ -110,8 +110,8 @@ impl BeobachtenStore { } fn cleanup(inner: &mut BeobachtenInner, now: Instant, ttl: Duration) { - inner.entries.retain(|_, entry| { - now.saturating_duration_since(entry.last_seen) <= ttl - }); + inner + .entries + .retain(|_, entry| now.saturating_duration_since(entry.last_seen) <= ttl); } } diff --git a/src/stats/mod.rs b/src/stats/mod.rs index c9fc318..bdabe81 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -5,20 +5,20 @@ pub mod beobachten; pub mod telemetry; -use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; -use std::sync::Arc; -use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use dashmap::DashMap; -use parking_lot::Mutex; use lru::LruCache; -use std::num::NonZeroUsize; -use std::hash::{Hash, Hasher}; -use std::collections::hash_map::DefaultHasher; +use parking_lot::Mutex; use std::collections::VecDeque; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::num::NonZeroUsize; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use tracing::debug; -use crate::config::{MeTelemetryLevel, MeWriterPickMode}; use self::telemetry::TelemetryPolicy; +use crate::config::{MeTelemetryLevel, MeWriterPickMode}; #[derive(Clone, Copy)] enum RouteConnectionGauge { @@ -264,8 +264,7 @@ impl Stats { let last_cleanup_epoch_secs = self .user_stats_last_cleanup_epoch_secs .load(Ordering::Relaxed); - if now_epoch_secs.saturating_sub(last_cleanup_epoch_secs) - < USER_STATS_CLEANUP_INTERVAL_SECS + if now_epoch_secs.saturating_sub(last_cleanup_epoch_secs) < USER_STATS_CLEANUP_INTERVAL_SECS { return; } @@ -307,7 +306,7 @@ impl Stats { me_level: self.telemetry_me_level(), } } - + pub fn increment_connects_all(&self) { if self.telemetry_core_enabled() { self.connects_all.fetch_add(1, Ordering::Relaxed); @@ -319,7 +318,8 @@ impl Stats { } } pub fn increment_current_connections_direct(&self) { - self.current_connections_direct.fetch_add(1, Ordering::Relaxed); + self.current_connections_direct + .fetch_add(1, Ordering::Relaxed); } pub fn decrement_current_connections_direct(&self) { Self::decrement_atomic_saturating(&self.current_connections_direct); @@ -460,7 +460,8 @@ impl Stats { } pub fn increment_me_keepalive_timeout_by(&self, value: u64) { if self.telemetry_me_allows_normal() { - self.me_keepalive_timeout.fetch_add(value, Ordering::Relaxed); + self.me_keepalive_timeout + .fetch_add(value, Ordering::Relaxed); } } pub fn increment_me_rpc_proxy_req_signal_sent_total(&self) { @@ -505,7 +506,8 @@ impl Stats { } pub fn increment_me_handshake_reject_total(&self) { if self.telemetry_me_allows_normal() { - self.me_handshake_reject_total.fetch_add(1, Ordering::Relaxed); + self.me_handshake_reject_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_handshake_error_code(&self, code: i32) { @@ -570,22 +572,26 @@ impl Stats { } pub fn increment_me_route_drop_channel_closed(&self) { if self.telemetry_me_allows_normal() { - self.me_route_drop_channel_closed.fetch_add(1, Ordering::Relaxed); + self.me_route_drop_channel_closed + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_route_drop_queue_full(&self) { if self.telemetry_me_allows_normal() { - self.me_route_drop_queue_full.fetch_add(1, Ordering::Relaxed); + self.me_route_drop_queue_full + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_route_drop_queue_full_base(&self) { if self.telemetry_me_allows_normal() { - self.me_route_drop_queue_full_base.fetch_add(1, Ordering::Relaxed); + self.me_route_drop_queue_full_base + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_route_drop_queue_full_high(&self) { if self.telemetry_me_allows_normal() { - self.me_route_drop_queue_full_high.fetch_add(1, Ordering::Relaxed); + self.me_route_drop_queue_full_high + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_writer_pick_success_try_total(&self, mode: MeWriterPickMode) { @@ -677,12 +683,14 @@ impl Stats { } pub fn increment_me_socks_kdf_strict_reject(&self) { if self.telemetry_me_allows_normal() { - self.me_socks_kdf_strict_reject.fetch_add(1, Ordering::Relaxed); + self.me_socks_kdf_strict_reject + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_socks_kdf_compat_fallback(&self) { if self.telemetry_me_allows_debug() { - self.me_socks_kdf_compat_fallback.fetch_add(1, Ordering::Relaxed); + self.me_socks_kdf_compat_fallback + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_secure_padding_invalid(&self) { @@ -714,13 +722,16 @@ impl Stats { self.desync_frames_bucket_0.fetch_add(1, Ordering::Relaxed); } 1..=2 => { - self.desync_frames_bucket_1_2.fetch_add(1, Ordering::Relaxed); + self.desync_frames_bucket_1_2 + .fetch_add(1, Ordering::Relaxed); } 3..=10 => { - self.desync_frames_bucket_3_10.fetch_add(1, Ordering::Relaxed); + self.desync_frames_bucket_3_10 + .fetch_add(1, Ordering::Relaxed); } _ => { - self.desync_frames_bucket_gt_10.fetch_add(1, Ordering::Relaxed); + self.desync_frames_bucket_gt_10 + .fetch_add(1, Ordering::Relaxed); } } } @@ -771,17 +782,20 @@ impl Stats { } pub fn increment_me_writer_removed_unexpected_total(&self) { if self.telemetry_me_allows_normal() { - self.me_writer_removed_unexpected_total.fetch_add(1, Ordering::Relaxed); + self.me_writer_removed_unexpected_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_refill_triggered_total(&self) { if self.telemetry_me_allows_debug() { - self.me_refill_triggered_total.fetch_add(1, Ordering::Relaxed); + self.me_refill_triggered_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_refill_skipped_inflight_total(&self) { if self.telemetry_me_allows_debug() { - self.me_refill_skipped_inflight_total.fetch_add(1, Ordering::Relaxed); + self.me_refill_skipped_inflight_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_refill_failed_total(&self) { @@ -803,7 +817,8 @@ impl Stats { } pub fn increment_me_no_writer_failfast_total(&self) { if self.telemetry_me_allows_normal() { - self.me_no_writer_failfast_total.fetch_add(1, Ordering::Relaxed); + self.me_no_writer_failfast_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_async_recovery_trigger_total(&self) { @@ -814,7 +829,8 @@ impl Stats { } pub fn increment_me_inline_recovery_total(&self) { if self.telemetry_me_allows_normal() { - self.me_inline_recovery_total.fetch_add(1, Ordering::Relaxed); + self.me_inline_recovery_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_ip_reservation_rollback_tcp_limit_total(&self) { @@ -986,12 +1002,14 @@ impl Stats { } pub fn increment_me_floor_cap_block_total(&self) { if self.telemetry_me_allows_normal() { - self.me_floor_cap_block_total.fetch_add(1, Ordering::Relaxed); + self.me_floor_cap_block_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_floor_swap_idle_total(&self) { if self.telemetry_me_allows_normal() { - self.me_floor_swap_idle_total.fetch_add(1, Ordering::Relaxed); + self.me_floor_swap_idle_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_floor_swap_idle_failed_total(&self) { @@ -1000,8 +1018,12 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } - pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } - pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } + pub fn get_connects_all(&self) -> u64 { + self.connects_all.load(Ordering::Relaxed) + } + pub fn get_connects_bad(&self) -> u64 { + self.connects_bad.load(Ordering::Relaxed) + } pub fn get_current_connections_direct(&self) -> u64 { self.current_connections_direct.load(Ordering::Relaxed) } @@ -1012,10 +1034,18 @@ impl Stats { self.get_current_connections_direct() .saturating_add(self.get_current_connections_me()) } - pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) } - pub fn get_me_keepalive_failed(&self) -> u64 { self.me_keepalive_failed.load(Ordering::Relaxed) } - pub fn get_me_keepalive_pong(&self) -> u64 { self.me_keepalive_pong.load(Ordering::Relaxed) } - pub fn get_me_keepalive_timeout(&self) -> u64 { self.me_keepalive_timeout.load(Ordering::Relaxed) } + pub fn get_me_keepalive_sent(&self) -> u64 { + self.me_keepalive_sent.load(Ordering::Relaxed) + } + pub fn get_me_keepalive_failed(&self) -> u64 { + self.me_keepalive_failed.load(Ordering::Relaxed) + } + pub fn get_me_keepalive_pong(&self) -> u64 { + self.me_keepalive_pong.load(Ordering::Relaxed) + } + pub fn get_me_keepalive_timeout(&self) -> u64 { + self.me_keepalive_timeout.load(Ordering::Relaxed) + } pub fn get_me_rpc_proxy_req_signal_sent_total(&self) -> u64 { self.me_rpc_proxy_req_signal_sent_total .load(Ordering::Relaxed) @@ -1036,8 +1066,12 @@ impl Stats { self.me_rpc_proxy_req_signal_close_sent_total .load(Ordering::Relaxed) } - pub fn get_me_reconnect_attempts(&self) -> u64 { self.me_reconnect_attempts.load(Ordering::Relaxed) } - pub fn get_me_reconnect_success(&self) -> u64 { self.me_reconnect_success.load(Ordering::Relaxed) } + pub fn get_me_reconnect_attempts(&self) -> u64 { + self.me_reconnect_attempts.load(Ordering::Relaxed) + } + pub fn get_me_reconnect_success(&self) -> u64 { + self.me_reconnect_success.load(Ordering::Relaxed) + } pub fn get_me_handshake_reject_total(&self) -> u64 { self.me_handshake_reject_total.load(Ordering::Relaxed) } @@ -1057,10 +1091,15 @@ impl Stats { self.relay_pressure_evict_total.load(Ordering::Relaxed) } pub fn get_relay_protocol_desync_close_total(&self) -> u64 { - self.relay_protocol_desync_close_total.load(Ordering::Relaxed) + self.relay_protocol_desync_close_total + .load(Ordering::Relaxed) + } + pub fn get_me_crc_mismatch(&self) -> u64 { + self.me_crc_mismatch.load(Ordering::Relaxed) + } + pub fn get_me_seq_mismatch(&self) -> u64 { + self.me_seq_mismatch.load(Ordering::Relaxed) } - pub fn get_me_crc_mismatch(&self) -> u64 { self.me_crc_mismatch.load(Ordering::Relaxed) } - pub fn get_me_seq_mismatch(&self) -> u64 { self.me_seq_mismatch.load(Ordering::Relaxed) } pub fn get_me_endpoint_quarantine_total(&self) -> u64 { self.me_endpoint_quarantine_total.load(Ordering::Relaxed) } @@ -1071,8 +1110,7 @@ impl Stats { self.me_kdf_port_only_drift_total.load(Ordering::Relaxed) } pub fn get_me_hardswap_pending_reuse_total(&self) -> u64 { - self.me_hardswap_pending_reuse_total - .load(Ordering::Relaxed) + self.me_hardswap_pending_reuse_total.load(Ordering::Relaxed) } pub fn get_me_hardswap_pending_ttl_expired_total(&self) -> u64 { self.me_hardswap_pending_ttl_expired_total @@ -1153,12 +1191,10 @@ impl Stats { .load(Ordering::Relaxed) } pub fn get_me_writers_active_current_gauge(&self) -> u64 { - self.me_writers_active_current_gauge - .load(Ordering::Relaxed) + self.me_writers_active_current_gauge.load(Ordering::Relaxed) } pub fn get_me_writers_warm_current_gauge(&self) -> u64 { - self.me_writers_warm_current_gauge - .load(Ordering::Relaxed) + self.me_writers_warm_current_gauge.load(Ordering::Relaxed) } pub fn get_me_floor_cap_block_total(&self) -> u64 { self.me_floor_cap_block_total.load(Ordering::Relaxed) @@ -1178,7 +1214,9 @@ impl Stats { out.sort_by_key(|(code, _)| *code); out } - pub fn get_me_route_drop_no_conn(&self) -> u64 { self.me_route_drop_no_conn.load(Ordering::Relaxed) } + pub fn get_me_route_drop_no_conn(&self) -> u64 { + self.me_route_drop_no_conn.load(Ordering::Relaxed) + } pub fn get_me_route_drop_channel_closed(&self) -> u64 { self.me_route_drop_channel_closed.load(Ordering::Relaxed) } @@ -1283,22 +1321,26 @@ impl Stats { self.me_writer_removed_total.load(Ordering::Relaxed) } pub fn get_me_writer_removed_unexpected_total(&self) -> u64 { - self.me_writer_removed_unexpected_total.load(Ordering::Relaxed) + self.me_writer_removed_unexpected_total + .load(Ordering::Relaxed) } pub fn get_me_refill_triggered_total(&self) -> u64 { self.me_refill_triggered_total.load(Ordering::Relaxed) } pub fn get_me_refill_skipped_inflight_total(&self) -> u64 { - self.me_refill_skipped_inflight_total.load(Ordering::Relaxed) + self.me_refill_skipped_inflight_total + .load(Ordering::Relaxed) } pub fn get_me_refill_failed_total(&self) -> u64 { self.me_refill_failed_total.load(Ordering::Relaxed) } pub fn get_me_writer_restored_same_endpoint_total(&self) -> u64 { - self.me_writer_restored_same_endpoint_total.load(Ordering::Relaxed) + self.me_writer_restored_same_endpoint_total + .load(Ordering::Relaxed) } pub fn get_me_writer_restored_fallback_total(&self) -> u64 { - self.me_writer_restored_fallback_total.load(Ordering::Relaxed) + self.me_writer_restored_fallback_total + .load(Ordering::Relaxed) } pub fn get_me_no_writer_failfast_total(&self) -> u64 { self.me_no_writer_failfast_total.load(Ordering::Relaxed) @@ -1317,7 +1359,7 @@ impl Stats { self.ip_reservation_rollback_quota_limit_total .load(Ordering::Relaxed) } - + pub fn increment_user_connects(&self, user: &str) { if !self.telemetry_user_enabled() { return; @@ -1332,7 +1374,7 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.connects.fetch_add(1, Ordering::Relaxed); } - + pub fn increment_user_curr_connects(&self, user: &str) { if !self.telemetry_user_enabled() { return; @@ -1360,7 +1402,9 @@ impl Stats { let counter = &stats.curr_connects; let mut current = counter.load(Ordering::Relaxed); loop { - if let Some(max) = limit && current >= max { + if let Some(max) = limit + && current >= max + { return false; } match counter.compare_exchange_weak( @@ -1374,7 +1418,7 @@ impl Stats { } } } - + pub fn decrement_user_curr_connects(&self, user: &str) { self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { @@ -1397,13 +1441,14 @@ impl Stats { } } } - + pub fn get_user_curr_connects(&self, user: &str) -> u64 { - self.user_stats.get(user) + self.user_stats + .get(user) .map(|s| s.curr_connects.load(Ordering::Relaxed)) .unwrap_or(0) } - + pub fn add_user_octets_from(&self, user: &str, bytes: u64) { if !self.telemetry_user_enabled() { return; @@ -1418,7 +1463,7 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); } - + pub fn add_user_octets_to(&self, user: &str, bytes: u64) { if !self.telemetry_user_enabled() { return; @@ -1433,7 +1478,7 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); } - + pub fn increment_user_msgs_from(&self, user: &str) { if !self.telemetry_user_enabled() { return; @@ -1448,7 +1493,7 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); } - + pub fn increment_user_msgs_to(&self, user: &str) { if !self.telemetry_user_enabled() { return; @@ -1463,17 +1508,20 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); } - + pub fn get_user_total_octets(&self, user: &str) -> u64 { - self.user_stats.get(user) + self.user_stats + .get(user) .map(|s| { - s.octets_from_client.load(Ordering::Relaxed) + - s.octets_to_client.load(Ordering::Relaxed) + s.octets_from_client.load(Ordering::Relaxed) + + s.octets_to_client.load(Ordering::Relaxed) }) .unwrap_or(0) } - - pub fn get_handshake_timeouts(&self) -> u64 { self.handshake_timeouts.load(Ordering::Relaxed) } + + pub fn get_handshake_timeouts(&self) -> u64 { + self.handshake_timeouts.load(Ordering::Relaxed) + } pub fn get_upstream_connect_attempt_total(&self) -> u64 { self.upstream_connect_attempt_total.load(Ordering::Relaxed) } @@ -1488,10 +1536,12 @@ impl Stats { .load(Ordering::Relaxed) } pub fn get_upstream_connect_attempts_bucket_1(&self) -> u64 { - self.upstream_connect_attempts_bucket_1.load(Ordering::Relaxed) + self.upstream_connect_attempts_bucket_1 + .load(Ordering::Relaxed) } pub fn get_upstream_connect_attempts_bucket_2(&self) -> u64 { - self.upstream_connect_attempts_bucket_2.load(Ordering::Relaxed) + self.upstream_connect_attempts_bucket_2 + .load(Ordering::Relaxed) } pub fn get_upstream_connect_attempts_bucket_3_4(&self) -> u64 { self.upstream_connect_attempts_bucket_3_4 @@ -1539,7 +1589,8 @@ impl Stats { } pub fn uptime_secs(&self) -> f64 { - self.start_time.read() + self.start_time + .read() .map(|t| t.elapsed().as_secs_f64()) .unwrap_or(0.0) } @@ -1578,7 +1629,7 @@ impl ReplayShard { seq_counter: 0, } } - + fn next_seq(&mut self) -> u64 { self.seq_counter += 1; self.seq_counter @@ -1589,13 +1640,13 @@ impl ReplayShard { return; } let cutoff = now.checked_sub(window).unwrap_or(now); - + while let Some((ts, _, _)) = self.queue.front() { if *ts >= cutoff { break; } let (_, key, queue_seq) = self.queue.pop_front().unwrap(); - + // Use key.as_ref() to get &[u8] — avoids Borrow ambiguity // between Borrow<[u8]> and Borrow> if let Some(entry) = self.cache.peek(key.as_ref()) @@ -1605,23 +1656,24 @@ impl ReplayShard { } } } - + fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool { self.cleanup(now, window); // key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]> self.cache.get(key).is_some() } - + fn add(&mut self, key: &[u8], now: Instant, window: Duration) { self.cleanup(now, window); - + let seq = self.next_seq(); let boxed_key: Box<[u8]> = key.into(); - - self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq }); + + self.cache + .put(boxed_key.clone(), ReplayEntry { seen_at: now, seq }); self.queue.push_back((now, boxed_key, seq)); } - + fn len(&self) -> usize { self.cache.len() } @@ -1696,15 +1748,19 @@ impl ReplayChecker { } // Compatibility helpers (non-atomic split operations) — prefer check_and_add_*. - pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) } + pub fn check_handshake(&self, data: &[u8]) -> bool { + self.check_and_add_handshake(data) + } pub fn add_handshake(&self, data: &[u8]) { self.add_only(data, &self.handshake_shards, self.window) } - pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) } + pub fn check_tls_digest(&self, data: &[u8]) -> bool { + self.check_and_add_tls_digest(data) + } pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data, &self.tls_shards, self.tls_window) } - + pub fn stats(&self) -> ReplayStats { let mut total_entries = 0; let mut total_queue_len = 0; @@ -1718,7 +1774,7 @@ impl ReplayChecker { total_entries += s.cache.len(); total_queue_len += s.queue.len(); } - + ReplayStats { total_entries, total_queue_len, @@ -1730,20 +1786,20 @@ impl ReplayChecker { window_secs: self.window.as_secs(), } } - + pub async fn run_periodic_cleanup(&self) { let interval = if self.window.as_secs() > 60 { Duration::from_secs(30) } else { Duration::from_secs(self.window.as_secs().max(1) / 2) }; - + loop { tokio::time::sleep(interval).await; - + let now = Instant::now(); let mut cleaned = 0usize; - + for shard_mutex in &self.handshake_shards { let mut shard = shard_mutex.lock(); let before = shard.len(); @@ -1758,9 +1814,9 @@ impl ReplayChecker { let after = shard.len(); cleaned += before.saturating_sub(after); } - + self.cleanups.fetch_add(1, Ordering::Relaxed); - + if cleaned > 0 { debug!(cleaned = cleaned, "Replay checker: periodic cleanup"); } @@ -1782,13 +1838,19 @@ pub struct ReplayStats { impl ReplayStats { pub fn hit_rate(&self) -> f64 { - if self.total_checks == 0 { 0.0 } - else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 } + if self.total_checks == 0 { + 0.0 + } else { + (self.total_hits as f64 / self.total_checks as f64) * 100.0 + } } - + pub fn ghost_ratio(&self) -> f64 { - if self.total_entries == 0 { 0.0 } - else { self.total_queue_len as f64 / self.total_entries as f64 } + if self.total_entries == 0 { + 0.0 + } else { + self.total_queue_len as f64 / self.total_entries as f64 + } } } @@ -1797,7 +1859,7 @@ mod tests { use super::*; use crate::config::MeTelemetryLevel; use std::sync::Arc; - + #[test] fn test_stats_shared_counters() { let stats = Arc::new(Stats::new()); @@ -1840,15 +1902,15 @@ mod tests { assert_eq!(stats.get_me_keepalive_sent(), 0); assert_eq!(stats.get_me_route_drop_queue_full(), 0); } - + #[test] fn test_replay_checker_basic() { let checker = ReplayChecker::new(100, Duration::from_secs(60)); assert!(!checker.check_handshake(b"test1")); // first time, inserts - assert!(checker.check_handshake(b"test1")); // duplicate + assert!(checker.check_handshake(b"test1")); // duplicate assert!(!checker.check_handshake(b"test2")); // new key inserts } - + #[test] fn test_replay_checker_duplicate_add() { let checker = ReplayChecker::new(100, Duration::from_secs(60)); @@ -1856,7 +1918,7 @@ mod tests { checker.add_handshake(b"dup"); assert!(checker.check_handshake(b"dup")); } - + #[test] fn test_replay_checker_expiration() { let checker = ReplayChecker::new(100, Duration::from_millis(50)); @@ -1865,7 +1927,7 @@ mod tests { std::thread::sleep(Duration::from_millis(100)); assert!(!checker.check_handshake(b"expire")); } - + #[test] fn test_replay_checker_stats() { let checker = ReplayChecker::new(100, Duration::from_secs(60)); @@ -1878,7 +1940,7 @@ mod tests { assert_eq!(stats.total_checks, 4); assert_eq!(stats.total_hits, 1); } - + #[test] fn test_replay_checker_many_keys() { let checker = ReplayChecker::new(10_000, Duration::from_secs(60)); diff --git a/src/stats/tests/connection_lease_security_tests.rs b/src/stats/tests/connection_lease_security_tests.rs index 69ae89a..1d15773 100644 --- a/src/stats/tests/connection_lease_security_tests.rs +++ b/src/stats/tests/connection_lease_security_tests.rs @@ -56,7 +56,10 @@ fn direct_connection_lease_balances_on_panic_unwind() { panic!("intentional panic to verify lease drop path"); })); - assert!(panic_result.is_err(), "panic must propagate from test closure"); + assert!( + panic_result.is_err(), + "panic must propagate from test closure" + ); assert_eq!( stats.get_current_connections_direct(), 0, @@ -74,7 +77,10 @@ fn middle_connection_lease_balances_on_panic_unwind() { panic!("intentional panic to verify middle lease drop path"); })); - assert!(panic_result.is_err(), "panic must propagate from test closure"); + assert!( + panic_result.is_err(), + "panic must propagate from test closure" + ); assert_eq!( stats.get_current_connections_me(), 0, @@ -109,9 +115,7 @@ async fn concurrent_mixed_route_lease_churn_balances_to_zero() { } for worker in workers { - worker - .await - .expect("lease churn worker must not panic"); + worker.await.expect("lease churn worker must not panic"); } assert_eq!( @@ -168,7 +172,9 @@ async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() { tokio::time::timeout(Duration::from_secs(2), async { loop { - if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 { + if stats.get_current_connections_direct() == 0 + && stats.get_current_connections_me() == 0 + { break; } tokio::time::sleep(Duration::from_millis(10)).await; @@ -197,9 +203,7 @@ fn saturating_route_decrements_do_not_underflow_under_race() { } for worker in workers { - worker - .join() - .expect("decrement race worker must not panic"); + worker.join().expect("decrement race worker must not panic"); } assert_eq!( diff --git a/src/stream/buffer_pool.rs b/src/stream/buffer_pool.rs index dac0fb5..6cdac60 100644 --- a/src/stream/buffer_pool.rs +++ b/src/stream/buffer_pool.rs @@ -8,8 +8,8 @@ use bytes::BytesMut; use crossbeam_queue::ArrayQueue; use std::ops::{Deref, DerefMut}; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; // ============= Configuration ============= @@ -42,7 +42,7 @@ impl BufferPool { pub fn new() -> Self { Self::with_config(DEFAULT_BUFFER_SIZE, DEFAULT_MAX_BUFFERS) } - + /// Create a buffer pool with custom configuration pub fn with_config(buffer_size: usize, max_buffers: usize) -> Self { Self { @@ -54,7 +54,7 @@ impl BufferPool { hits: AtomicUsize::new(0), } } - + /// Get a buffer from the pool, or create a new one if empty pub fn get(self: &Arc) -> PooledBuffer { match self.buffers.pop() { @@ -76,7 +76,7 @@ impl BufferPool { } } } - + /// Try to get a buffer, returns None if pool is empty pub fn try_get(self: &Arc) -> Option { self.buffers.pop().map(|mut buffer| { @@ -88,12 +88,12 @@ impl BufferPool { } }) } - + /// Return a buffer to the pool fn return_buffer(&self, mut buffer: BytesMut) { // Clear the buffer but keep capacity buffer.clear(); - + // Only return if we haven't exceeded max and buffer is right size if buffer.capacity() >= self.buffer_size { // Try to push to pool, if full just drop @@ -103,7 +103,7 @@ impl BufferPool { // Actually we don't decrement here because the buffer might have been // grown beyond our size - we just let it go } - + /// Get pool statistics pub fn stats(&self) -> PoolStats { PoolStats { @@ -115,17 +115,21 @@ impl BufferPool { misses: self.misses.load(Ordering::Relaxed), } } - + /// Get buffer size pub fn buffer_size(&self) -> usize { self.buffer_size } - + /// Preallocate buffers to fill the pool pub fn preallocate(&self, count: usize) { let to_alloc = count.min(self.max_buffers); for _ in 0..to_alloc { - if self.buffers.push(BytesMut::with_capacity(self.buffer_size)).is_err() { + if self + .buffers + .push(BytesMut::with_capacity(self.buffer_size)) + .is_err() + { break; } self.allocated.fetch_add(1, Ordering::Relaxed); @@ -183,22 +187,22 @@ impl PooledBuffer { pub fn take(mut self) -> BytesMut { self.buffer.take().unwrap() } - + /// Get the capacity of the buffer pub fn capacity(&self) -> usize { self.buffer.as_ref().map(|b| b.capacity()).unwrap_or(0) } - + /// Check if buffer is empty pub fn is_empty(&self) -> bool { self.buffer.as_ref().map(|b| b.is_empty()).unwrap_or(true) } - + /// Get the length of data in buffer pub fn len(&self) -> usize { self.buffer.as_ref().map(|b| b.len()).unwrap_or(0) } - + /// Clear the buffer pub fn clear(&mut self) { if let Some(ref mut b) = self.buffer { @@ -209,7 +213,7 @@ impl PooledBuffer { impl Deref for PooledBuffer { type Target = BytesMut; - + fn deref(&self) -> &Self::Target { self.buffer.as_ref().expect("buffer taken") } @@ -259,7 +263,7 @@ impl<'a> ScopedBuffer<'a> { impl<'a> Deref for ScopedBuffer<'a> { type Target = BytesMut; - + fn deref(&self) -> &Self::Target { self.buffer.deref() } @@ -280,108 +284,108 @@ impl<'a> Drop for ScopedBuffer<'a> { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_pool_basic() { let pool = Arc::new(BufferPool::with_config(1024, 10)); - + // Get a buffer let mut buf1 = pool.get(); buf1.extend_from_slice(b"hello"); assert_eq!(&buf1[..], b"hello"); - + // Drop returns to pool drop(buf1); - + let stats = pool.stats(); assert_eq!(stats.pooled, 1); assert_eq!(stats.hits, 0); assert_eq!(stats.misses, 1); - + // Get again - should reuse let buf2 = pool.get(); assert!(buf2.is_empty()); // Buffer was cleared - + let stats = pool.stats(); assert_eq!(stats.pooled, 0); assert_eq!(stats.hits, 1); } - + #[test] fn test_pool_multiple_buffers() { let pool = Arc::new(BufferPool::with_config(1024, 10)); - + // Get multiple buffers let buf1 = pool.get(); let buf2 = pool.get(); let buf3 = pool.get(); - + let stats = pool.stats(); assert_eq!(stats.allocated, 3); assert_eq!(stats.pooled, 0); - + // Return all drop(buf1); drop(buf2); drop(buf3); - + let stats = pool.stats(); assert_eq!(stats.pooled, 3); } - + #[test] fn test_pool_overflow() { let pool = Arc::new(BufferPool::with_config(1024, 2)); - + // Get 3 buffers (more than max) let buf1 = pool.get(); let buf2 = pool.get(); let buf3 = pool.get(); - + // Return all - only 2 should be pooled drop(buf1); drop(buf2); drop(buf3); - + let stats = pool.stats(); assert_eq!(stats.pooled, 2); } - + #[test] fn test_pool_take() { let pool = Arc::new(BufferPool::with_config(1024, 10)); - + let mut buf = pool.get(); buf.extend_from_slice(b"data"); - + // Take ownership, buffer should not return to pool let taken = buf.take(); assert_eq!(&taken[..], b"data"); - + let stats = pool.stats(); assert_eq!(stats.pooled, 0); } - + #[test] fn test_pool_preallocate() { let pool = Arc::new(BufferPool::with_config(1024, 10)); pool.preallocate(5); - + let stats = pool.stats(); assert_eq!(stats.pooled, 5); assert_eq!(stats.allocated, 5); } - + #[test] fn test_pool_try_get() { let pool = Arc::new(BufferPool::with_config(1024, 10)); - + // Pool is empty, try_get returns None assert!(pool.try_get().is_none()); - + // Add a buffer to pool pool.preallocate(1); - + // Now try_get should succeed once while the buffer is held let buf = pool.try_get(); assert!(buf.is_some()); @@ -391,50 +395,50 @@ mod tests { drop(buf); assert!(pool.try_get().is_some()); } - + #[test] fn test_hit_rate() { let pool = Arc::new(BufferPool::with_config(1024, 10)); - + // First get is a miss let buf1 = pool.get(); drop(buf1); - + // Second get is a hit let buf2 = pool.get(); drop(buf2); - + // Third get is a hit let _buf3 = pool.get(); - + let stats = pool.stats(); assert_eq!(stats.hits, 2); assert_eq!(stats.misses, 1); assert!((stats.hit_rate() - 66.67).abs() < 1.0); } - + #[test] fn test_scoped_buffer() { let pool = Arc::new(BufferPool::with_config(1024, 10)); let mut buf = pool.get(); - + { let mut scoped = ScopedBuffer::new(&mut buf); scoped.extend_from_slice(b"scoped data"); assert_eq!(&scoped[..], b"scoped data"); } - + // After scoped is dropped, buffer is cleared assert!(buf.is_empty()); } - + #[test] fn test_concurrent_access() { use std::thread; - + let pool = Arc::new(BufferPool::with_config(1024, 100)); let mut handles = vec![]; - + for _ in 0..10 { let pool_clone = Arc::clone(&pool); handles.push(thread::spawn(move || { @@ -445,11 +449,11 @@ mod tests { } })); } - + for handle in handles { handle.join().unwrap(); } - + let stats = pool.stats(); // All buffers should be returned assert!(stats.pooled > 0); diff --git a/src/stream/crypto_stream.rs b/src/stream/crypto_stream.rs index 744b186..d962321 100644 --- a/src/stream/crypto_stream.rs +++ b/src/stream/crypto_stream.rs @@ -37,7 +37,7 @@ //! //! Backpressure //! - pending ciphertext buffer is bounded (configurable per connection) -//! - pending is full and upstream is pending +//! - pending is full and upstream is pending //! -> poll_write returns Poll::Pending //! -> do not accept any plaintext //! @@ -59,8 +59,8 @@ use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{debug, trace}; -use crate::crypto::AesCtr; use super::state::{StreamState, YieldBuffer}; +use crate::crypto::AesCtr; // ============= Constants ============= @@ -152,9 +152,9 @@ impl CryptoReader { fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - CryptoReaderState::Poisoned { error } => error.take().unwrap_or_else(|| { - io::Error::other("stream previously poisoned") - }), + CryptoReaderState::Poisoned { error } => error + .take() + .unwrap_or_else(|| io::Error::other("stream previously poisoned")), _ => io::Error::other("stream not poisoned"), } } @@ -221,7 +221,11 @@ impl AsyncRead for CryptoReader { let filled = buf.filled_mut(); this.decryptor.apply(&mut filled[before..after]); - trace!(bytes_read, state = this.state_name(), "CryptoReader decrypted chunk"); + trace!( + bytes_read, + state = this.state_name(), + "CryptoReader decrypted chunk" + ); return Poll::Ready(Ok(())); } @@ -503,9 +507,9 @@ impl CryptoWriter { fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - CryptoWriterState::Poisoned { error } => error.take().unwrap_or_else(|| { - io::Error::other("stream previously poisoned") - }), + CryptoWriterState::Poisoned { error } => error + .take() + .unwrap_or_else(|| io::Error::other("stream previously poisoned")), _ => io::Error::other("stream not poisoned"), } } @@ -525,7 +529,11 @@ impl CryptoWriter { } /// Select how many plaintext bytes can be accepted in buffering path - fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize, max_pending: usize) -> usize { + fn select_to_accept_for_buffering( + state: &CryptoWriterState, + buf_len: usize, + max_pending: usize, + ) -> usize { if buf_len == 0 { return 0; } @@ -602,11 +610,7 @@ impl CryptoWriter { } impl AsyncWrite for CryptoWriter { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let this = self.get_mut(); // Poisoned? @@ -629,8 +633,11 @@ impl AsyncWrite for CryptoWriter { Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => { // Upstream blocked. Apply ideal backpressure - let to_accept = - Self::select_to_accept_for_buffering(&this.state, buf.len(), this.max_pending_write); + let to_accept = Self::select_to_accept_for_buffering( + &this.state, + buf.len(), + this.max_pending_write, + ); if to_accept == 0 { trace!( diff --git a/src/stream/frame.rs b/src/stream/frame.rs index 5c93ea7..08baf4c 100644 --- a/src/stream/frame.rs +++ b/src/stream/frame.rs @@ -9,8 +9,8 @@ use bytes::{Bytes, BytesMut}; use std::io::Result; use std::sync::Arc; -use crate::protocol::constants::ProtoTag; use crate::crypto::SecureRandom; +use crate::protocol::constants::ProtoTag; // ============= Frame Types ============= @@ -31,27 +31,27 @@ impl Frame { meta: FrameMeta::default(), } } - + /// Create a new frame with data and metadata pub fn with_meta(data: Bytes, meta: FrameMeta) -> Self { Self { data, meta } } - + /// Create an empty frame pub fn empty() -> Self { Self::new(Bytes::new()) } - + /// Check if frame is empty pub fn is_empty(&self) -> bool { self.data.is_empty() } - + /// Get frame length pub fn len(&self) -> usize { self.data.len() } - + /// Create a QuickAck request frame pub fn quickack(data: Bytes) -> Self { Self { @@ -62,7 +62,7 @@ impl Frame { }, } } - + /// Create a simple ACK frame pub fn simple_ack(data: Bytes) -> Self { Self { @@ -91,25 +91,25 @@ impl FrameMeta { pub fn new() -> Self { Self::default() } - + /// Create with quickack flag pub fn with_quickack(mut self) -> Self { self.quickack = true; self } - + /// Create with simple_ack flag pub fn with_simple_ack(mut self) -> Self { self.simple_ack = true; self } - + /// Create with padding length pub fn with_padding(mut self, len: u8) -> Self { self.padding_len = len; self } - + /// Check if any special flags are set pub fn has_flags(&self) -> bool { self.quickack || self.simple_ack @@ -122,12 +122,12 @@ impl FrameMeta { pub trait FrameCodec: Send + Sync { /// Get the protocol tag for this codec fn proto_tag(&self) -> ProtoTag; - + /// Encode a frame into the destination buffer /// /// Returns the number of bytes written. fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> Result; - + /// Try to decode a frame from the source buffer /// /// Returns: @@ -137,10 +137,10 @@ pub trait FrameCodec: Send + Sync { /// /// On success, the consumed bytes are removed from `src`. fn decode(&self, src: &mut BytesMut) -> Result>; - + /// Get the minimum bytes needed to determine frame length fn min_header_size(&self) -> usize; - + /// Get the maximum allowed frame size fn max_frame_size(&self) -> usize { // Default: 16MB @@ -162,30 +162,28 @@ pub fn create_codec(proto_tag: ProtoTag, rng: Arc) -> Box Self { self.max_frame_size = size; self } - + /// Get protocol tag pub fn proto_tag(&self) -> ProtoTag { self.proto_tag @@ -56,7 +56,7 @@ impl FrameCodec { impl Decoder for FrameCodec { type Item = Frame; type Error = io::Error; - + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match self.proto_tag { ProtoTag::Abridged => decode_abridged(src, self.max_frame_size), @@ -68,7 +68,7 @@ impl Decoder for FrameCodec { impl Encoder for FrameCodec { type Error = io::Error; - + fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { match self.proto_tag { ProtoTag::Abridged => encode_abridged(&frame, dst), @@ -84,18 +84,18 @@ fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result= 0x80 { meta.quickack = true; } - + let header_len; - + if len_words == 0x7f { // Extended length (3 more bytes needed) if src.len() < 4 { @@ -106,46 +106,49 @@ fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result max_size { return Err(Error::new( ErrorKind::InvalidData, - format!("frame too large: {} bytes (max {})", byte_len, max_size) + format!("frame too large: {} bytes (max {})", byte_len, max_size), )); } - + let total_len = header_len + byte_len; - + if src.len() < total_len { // Reserve space for the rest of the frame src.reserve(total_len - src.len()); return Ok(None); } - + // Extract data let _ = src.split_to(header_len); let data = src.split_to(byte_len).freeze(); - + Ok(Some(Frame::with_meta(data, meta))) } fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { let data = &frame.data; - + // Validate alignment if !data.len().is_multiple_of(4) { return Err(Error::new( ErrorKind::InvalidInput, - format!("abridged frame must be 4-byte aligned, got {} bytes", data.len()) + format!( + "abridged frame must be 4-byte aligned, got {} bytes", + data.len() + ), )); } - + // Simple ACK: send reversed data without header if frame.meta.simple_ack { dst.reserve(data.len()); @@ -154,9 +157,9 @@ fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { } return Ok(()); } - + let len_words = data.len() / 4; - + if len_words < 0x7f { // Short header dst.reserve(1 + data.len()); @@ -178,10 +181,10 @@ fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { } else { return Err(Error::new( ErrorKind::InvalidInput, - format!("frame too large: {} bytes", data.len()) + format!("frame too large: {} bytes", data.len()), )); } - + dst.extend_from_slice(data); Ok(()) } @@ -192,58 +195,58 @@ fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result