diff --git a/Cargo.toml b/Cargo.toml index 725fa26..53082db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ static_assertions = "1.1" # Network socket2 = { version = "0.6", features = ["all"] } -nix = { version = "0.31", default-features = false, features = ["net"] } +nix = { version = "0.31", default-features = false, features = ["net", "fs"] } shadowsocks = { version = "1.24", features = ["aead-cipher-2022"] } # Serialization diff --git a/src/api/mod.rs b/src/api/mod.rs index b622c5e..c1e3557 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; diff --git a/src/api/users.rs b/src/api/users.rs index 4793f89..2ee8b98 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -495,11 +495,11 @@ fn resolve_link_hosts( push_unique_host(&mut hosts, host); continue; } - if let Some(ip) = listener.announce_ip { - if !ip.is_unspecified() { - push_unique_host(&mut hosts, &ip.to_string()); - continue; - } + if let Some(ip) = listener.announce_ip + && !ip.is_unspecified() + { + push_unique_host(&mut hosts, &ip.to_string()); + continue; } if listener.ip.is_unspecified() { let detected_ip = if listener.ip.is_ipv4() { diff --git a/src/maestro/connectivity.rs b/src/maestro/connectivity.rs index ee5fdb9..0cb561d 100644 --- a/src/maestro/connectivity.rs +++ b/src/maestro/connectivity.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use std::sync::Arc; use std::time::Instant; diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index ffa4d1b..35f796f 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -1,3 +1,5 @@ +#![allow(clippy::items_after_test_module)] + use std::path::PathBuf; use std::time::Duration; diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index c668734..022f8ae 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use std::sync::Arc; use std::time::Duration; diff --git a/src/network/probe.rs b/src/network/probe.rs index 098e2eb..1787b92 100644 --- a/src/network/probe.rs +++ b/src/network/probe.rs @@ -1,4 +1,5 @@ #![allow(dead_code)] +#![allow(clippy::items_after_test_module)] use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; @@ -197,11 +198,10 @@ pub async fn run_probe( if nat_probe && probe.reflected_ipv4.is_none() && probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false) + && let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await { - if let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await { - probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0)); - info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback"); - } + probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0)); + info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback"); } probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) { @@ -286,8 +286,6 @@ async fn probe_stun_servers_parallel( while next_idx < servers.len() && join_set.len() < concurrency { let stun_addr = servers[next_idx].clone(); next_idx += 1; - let bind_v4 = bind_v4; - let bind_v6 = bind_v6; join_set.spawn(async move { let res = timeout(STUN_BATCH_TIMEOUT, async { let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?; diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 82527ca..b9bca49 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -5,6 +5,61 @@ //! actually carries MTProto authentication data. #![allow(dead_code)] +#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))] +#![cfg_attr( + not(test), + deny( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::todo, + clippy::unimplemented, + clippy::correctness, + clippy::option_if_let_else, + clippy::or_fun_call, + clippy::branches_sharing_code, + clippy::single_option_map, + clippy::useless_let_if_seq, + clippy::redundant_locals, + clippy::cloned_ref_to_slice_refs, + unsafe_code, + clippy::await_holding_lock, + clippy::await_holding_refcell_ref, + clippy::debug_assert_with_mut_call, + clippy::macro_use_imports, + clippy::cast_ptr_alignment, + clippy::cast_lossless, + clippy::ptr_as_ptr, + clippy::large_stack_arrays, + clippy::same_functions_in_if_condition, + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + rust_2018_idioms + ) +)] +#![cfg_attr( + not(test), + allow( + clippy::use_self, + clippy::redundant_closure, + clippy::too_many_arguments, + clippy::doc_markdown, + clippy::missing_const_for_fn, + clippy::unnecessary_operation, + clippy::redundant_pub_crate, + clippy::derive_partial_eq_without_eq, + clippy::type_complexity, + clippy::new_ret_no_self, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::significant_drop_tightening, + clippy::significant_drop_in_scrutinee, + clippy::float_cmp, + clippy::nursery + ) +)] use super::constants::*; use crate::crypto::{SecureRandom, sha256_hmac}; @@ -127,7 +182,6 @@ impl TlsExtensionBuilder { } /// Build final extensions with length prefix - fn build(self) -> Vec { let mut result = Vec::with_capacity(2 + self.extensions.len()); @@ -142,7 +196,6 @@ impl TlsExtensionBuilder { } /// Get current extensions without length prefix (for calculation) - fn as_bytes(&self) -> &[u8] { &self.extensions } @@ -258,7 +311,6 @@ impl ServerHelloBuilder { /// Returns validation result if a matching user is found. /// The result **must** be used — ignoring it silently bypasses authentication. #[must_use] - pub fn validate_tls_handshake( handshake: &[u8], secrets: &[(String, Vec)], @@ -628,11 +680,10 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { if name_type == 0 && name_len > 0 && let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) + && is_valid_sni_hostname(host) { - if is_valid_sni_hostname(host) { - extracted_sni = Some(host.to_string()); - break; - } + extracted_sni = Some(host.to_string()); + break; } sn_pos += name_len; } @@ -754,7 +805,6 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { } /// Parse TLS record header, returns (record_type, length) - pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> { let record_type = header[0]; let version = [header[1], header[2]]; diff --git a/src/proxy/adaptive_buffers.rs b/src/proxy/adaptive_buffers.rs index bb61858..0c210dd 100644 --- a/src/proxy/adaptive_buffers.rs +++ b/src/proxy/adaptive_buffers.rs @@ -1,3 +1,8 @@ +#![allow(dead_code)] + +// Adaptive buffer policy is staged and retained for deterministic rollout. +// Keep definitions compiled for compatibility and security test scaffolding. + use dashmap::DashMap; use std::cmp::max; use std::sync::OnceLock; diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 7b2572e..5d9c450 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -24,13 +24,13 @@ use crate::proxy::route_mode::{ use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; +#[cfg(unix)] +use nix::fcntl::{Flock, FlockArg, OFlag, openat}; +#[cfg(unix)] +use nix::sys::stat::Mode; -#[cfg(unix)] -use std::os::unix::ffi::OsStrExt; #[cfg(unix)] use std::os::unix::fs::OpenOptionsExt; -#[cfg(unix)] -use std::os::unix::io::{AsRawFd, FromRawFd}; const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; static LOGGED_UNKNOWN_DCS: OnceLock>> = OnceLock::new(); @@ -170,32 +170,16 @@ fn open_unknown_dc_log_append_anchored( .custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC) .open(&path.allowed_parent)?; - let file_name = - std::ffi::CString::new(path.file_name.as_os_str().as_bytes()).map_err(|_| { - std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "unknown DC log file name contains NUL byte", - ) - })?; - - let fd = unsafe { - libc::openat( - parent.as_raw_fd(), - file_name.as_ptr(), - libc::O_CREAT - | libc::O_APPEND - | libc::O_WRONLY - | libc::O_NOFOLLOW - | libc::O_CLOEXEC, - 0o600, - ) - }; - - if fd < 0 { - return Err(std::io::Error::last_os_error()); - } - - let file = unsafe { std::fs::File::from_raw_fd(fd) }; + let oflags = OFlag::O_CREAT + | OFlag::O_APPEND + | OFlag::O_WRONLY + | OFlag::O_NOFOLLOW + | OFlag::O_CLOEXEC; + let mode = Mode::from_bits_truncate(0o600); + let path_component = Path::new(path.file_name.as_os_str()); + let fd = openat(&parent, path_component, oflags, mode) + .map_err(|err| std::io::Error::from_raw_os_error(err as i32))?; + let file = std::fs::File::from(fd); Ok(file) } #[cfg(not(unix))] @@ -211,16 +195,13 @@ fn open_unknown_dc_log_append_anchored( fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> { #[cfg(unix)] { - if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX) } != 0 { - return Err(std::io::Error::last_os_error()); - } - - let write_result = writeln!(file, "dc_idx={dc_idx}"); - - if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) } != 0 { - return Err(std::io::Error::last_os_error()); - } - + let cloned = file.try_clone()?; + let mut locked = Flock::lock(cloned, FlockArg::LockExclusive) + .map_err(|(_, err)| std::io::Error::from_raw_os_error(err as i32))?; + let write_result = writeln!(&mut *locked, "dc_idx={dc_idx}"); + let _ = locked + .unlock() + .map_err(|(_, err)| std::io::Error::from_raw_os_error(err as i32))?; write_result } #[cfg(not(unix))] diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index f3e3727..5632977 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -626,7 +626,7 @@ where let cached = if config.censorship.tls_emulation { if let Some(cache) = tls_cache.as_ref() { let selected_domain = if let Some(sni) = client_sni.as_ref() { - if cache.contains_domain(&sni).await { + if cache.contains_domain(sni).await { sni.clone() } else { config.censorship.tls_domain.clone() @@ -954,7 +954,6 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, A } /// Encrypt nonce for sending to Telegram (legacy function for compatibility) - pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce); encrypted diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index adbb3ad..509b01e 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -316,11 +316,11 @@ pub async fn handle_bad_client( peer, local_addr, ); - if let Some(header) = proxy_header { - if !write_proxy_header_with_timeout(&mut mask_write, &header).await { - wait_mask_outcome_budget(outcome_started, config).await; - return; - } + if let Some(header) = proxy_header + && !write_proxy_header_with_timeout(&mut mask_write, &header).await + { + wait_mask_outcome_budget(outcome_started, config).await; + return; } if timeout( MASK_RELAY_TIMEOUT, @@ -387,11 +387,11 @@ pub async fn handle_bad_client( build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr); let (mask_read, mut mask_write) = stream.into_split(); - if let Some(header) = proxy_header { - if !write_proxy_header_with_timeout(&mut mask_write, &header).await { - wait_mask_outcome_budget(outcome_started, config).await; - return; - } + if let Some(header) = proxy_header + && !write_proxy_header_with_timeout(&mut mask_write, &header).await + { + wait_mask_outcome_budget(outcome_started, config).await; + return; } if timeout( MASK_RELAY_TIMEOUT, diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 21fda15..d8d94d2 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,7 +1,6 @@ use std::collections::hash_map::RandomState; use std::collections::{BTreeSet, HashMap}; -use std::hash::BuildHasher; -use std::hash::{Hash, Hasher}; +use std::hash::{BuildHasher, Hash}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::{Arc, Mutex, OnceLock}; @@ -286,9 +285,7 @@ impl MeD2cFlushPolicy { fn hash_value(value: &T) -> u64 { let state = DESYNC_HASHER.get_or_init(RandomState::new); - let mut hasher = state.build_hasher(); - value.hash(&mut hasher); - hasher.finish() + state.hash_one(value) } fn hash_ip(ip: IpAddr) -> u64 { @@ -686,7 +683,6 @@ where .max(C2ME_CHANNEL_CAPACITY_FALLBACK); let (c2me_tx, mut c2me_rx) = mpsc::channel::(c2me_channel_capacity); let me_pool_c2me = me_pool.clone(); - let effective_tag = effective_tag; let c2me_sender = tokio::spawn(async move { let mut sent_since_yield = 0usize; while let Some(cmd) = c2me_rx.recv().await { @@ -1645,3 +1641,7 @@ mod idle_policy_security_tests; #[cfg(test)] #[path = "tests/middle_relay_desync_all_full_dedup_security_tests.rs"] mod desync_all_full_dedup_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_stub_completion_security_tests.rs"] +mod stub_completion_security_tests; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 3db6000..eebc188 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,5 +1,63 @@ //! Proxy Defs +// Apply strict linting to proxy production code while keeping test builds noise-tolerant. +#![cfg_attr(test, allow(warnings))] +#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))] +#![cfg_attr( + not(test), + deny( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::todo, + clippy::unimplemented, + clippy::correctness, + clippy::option_if_let_else, + clippy::or_fun_call, + clippy::branches_sharing_code, + clippy::single_option_map, + clippy::useless_let_if_seq, + clippy::redundant_locals, + clippy::cloned_ref_to_slice_refs, + unsafe_code, + clippy::await_holding_lock, + clippy::await_holding_refcell_ref, + clippy::debug_assert_with_mut_call, + clippy::macro_use_imports, + clippy::cast_ptr_alignment, + clippy::cast_lossless, + clippy::ptr_as_ptr, + clippy::large_stack_arrays, + clippy::same_functions_in_if_condition, + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + rust_2018_idioms + ) +)] +#![cfg_attr( + not(test), + allow( + clippy::use_self, + clippy::redundant_closure, + clippy::too_many_arguments, + clippy::doc_markdown, + clippy::missing_const_for_fn, + clippy::unnecessary_operation, + clippy::redundant_pub_crate, + clippy::derive_partial_eq_without_eq, + clippy::type_complexity, + clippy::new_ret_no_self, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::significant_drop_tightening, + clippy::significant_drop_in_scrutinee, + clippy::float_cmp, + clippy::nursery + ) +)] + pub mod adaptive_buffers; pub mod client; pub mod direct_relay; diff --git a/src/proxy/session_eviction.rs b/src/proxy/session_eviction.rs index c735cae..800e5b8 100644 --- a/src/proxy/session_eviction.rs +++ b/src/proxy/session_eviction.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + /// Session eviction is intentionally disabled in runtime. /// /// The initial `user+dc` single-lease model caused valid parallel client diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs index 3a5ba78..16fe8da 100644 --- a/src/proxy/tests/direct_relay_security_tests.rs +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -757,6 +757,284 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() { ); } +#[cfg(unix)] +#[test] +fn anchored_open_nix_path_writes_expected_lines() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-anchored-open-ok-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("anchored-open-ok base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let _ = fs::remove_file(&sanitized.resolved_path); + + let mut first = open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored open must create log file in allowed parent"); + append_unknown_dc_line(&mut first, 31_200).expect("first append must succeed"); + + let mut second = open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored reopen must succeed for existing regular file"); + append_unknown_dc_line(&mut second, 31_201).expect("second append must succeed"); + + let content = + fs::read_to_string(&sanitized.resolved_path).expect("anchored log file must be readable"); + let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + assert_eq!(lines.len(), 2, "expected one line per anchored append call"); + assert!( + lines.contains(&"dc_idx=31200") && lines.contains(&"dc_idx=31201"), + "anchored append output must contain both expected dc_idx lines" + ); +} + +#[cfg(unix)] +#[test] +fn anchored_open_parallel_appends_preserve_line_integrity() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-anchored-open-parallel-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("anchored-open-parallel base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let _ = fs::remove_file(&sanitized.resolved_path); + + let mut workers = Vec::new(); + for idx in 0..64i16 { + let sanitized = sanitized.clone(); + workers.push(std::thread::spawn(move || { + let mut file = open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored open must succeed in worker"); + append_unknown_dc_line(&mut file, 32_000 + idx).expect("worker append must succeed"); + })); + } + + for worker in workers { + worker.join().expect("worker must not panic"); + } + + let content = + fs::read_to_string(&sanitized.resolved_path).expect("parallel log file must be readable"); + let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + assert_eq!(lines.len(), 64, "expected one complete line per worker append"); + for line in lines { + assert!( + line.starts_with("dc_idx="), + "line must keep dc_idx prefix and not be interleaved: {line}" + ); + let value = line + .strip_prefix("dc_idx=") + .expect("prefix checked above") + .parse::(); + assert!( + value.is_ok(), + "line payload must remain parseable i16 and not be corrupted: {line}" + ); + } +} + +#[cfg(unix)] +#[test] +fn anchored_open_creates_private_0600_file_permissions() { + use std::os::unix::fs::PermissionsExt; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-anchored-perms-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("anchored-perms base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let _ = fs::remove_file(&sanitized.resolved_path); + + let mut file = open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored open must create file with restricted mode"); + append_unknown_dc_line(&mut file, 31_210).expect("initial append must succeed"); + drop(file); + + let mode = fs::metadata(&sanitized.resolved_path) + .expect("created log file metadata must be readable") + .permissions() + .mode() + & 0o777; + assert_eq!( + mode, 0o600, + "anchored open must create unknown-dc log file with owner-only rw permissions" + ); +} + +#[cfg(unix)] +#[test] +fn anchored_open_rejects_existing_symlink_target() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-anchored-symlink-target-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("anchored-symlink-target base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-anchored-symlink-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "outside\n").expect("outside baseline file must be writable"); + + let _ = fs::remove_file(&sanitized.resolved_path); + symlink(&outside, &sanitized.resolved_path) + .expect("target symlink for anchored-open rejection test must be creatable"); + + let err = open_unknown_dc_log_append_anchored(&sanitized) + .expect_err("anchored open must reject symlinked filename target"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "anchored open should fail closed with ELOOP on symlinked target" + ); +} + +#[cfg(unix)] +#[test] +fn anchored_open_high_contention_multi_write_preserves_complete_lines() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-anchored-contention-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("anchored-contention base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let _ = fs::remove_file(&sanitized.resolved_path); + + let workers = 24usize; + let rounds = 40usize; + let mut threads = Vec::new(); + + for worker in 0..workers { + let sanitized = sanitized.clone(); + threads.push(std::thread::spawn(move || { + for round in 0..rounds { + let mut file = open_unknown_dc_log_append_anchored(&sanitized) + .expect("anchored open must succeed under contention"); + let dc_idx = 20_000i16.wrapping_add((worker * rounds + round) as i16); + append_unknown_dc_line(&mut file, dc_idx) + .expect("each contention append must complete"); + } + })); + } + + for thread in threads { + thread.join().expect("contention worker must not panic"); + } + + let content = fs::read_to_string(&sanitized.resolved_path) + .expect("contention output file must be readable"); + let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + assert_eq!( + lines.len(), + workers * rounds, + "every contention append must produce exactly one line" + ); + + let mut unique = std::collections::HashSet::new(); + for line in lines { + assert!( + line.starts_with("dc_idx="), + "line must preserve expected prefix under heavy contention: {line}" + ); + let value = line + .strip_prefix("dc_idx=") + .expect("prefix validated") + .parse::() + .expect("line payload must remain parseable i16 under contention"); + unique.insert(value); + } + + assert_eq!( + unique.len(), + workers * rounds, + "contention output must not lose or duplicate logical writes" + ); +} + +#[cfg(unix)] +#[test] +fn append_unknown_dc_line_returns_error_for_read_only_descriptor() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!( + "telemt-unknown-dc-append-ro-{}", + std::process::id() + )); + fs::create_dir_all(&base).expect("append-ro base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-append-ro-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = + sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + fs::write(&sanitized.resolved_path, "seed\n").expect("seed file must be writable"); + + let mut readonly = std::fs::OpenOptions::new() + .read(true) + .open(&sanitized.resolved_path) + .expect("readonly file open must succeed"); + + append_unknown_dc_line(&mut readonly, 31_222) + .expect_err("append on readonly descriptor must fail closed"); + + let content_after = + fs::read_to_string(&sanitized.resolved_path).expect("seed file must remain readable"); + assert_eq!( + nonempty_line_count(&content_after), + 1, + "failed readonly append must not modify persisted unknown-dc log content" + ); +} + #[tokio::test] async fn unknown_dc_absolute_log_path_writes_one_entry() { let _guard = unknown_dc_test_lock() diff --git a/src/proxy/tests/middle_relay_security_tests.rs b/src/proxy/tests/middle_relay_security_tests.rs index 5bb6d45..8b4f7f1 100644 --- a/src/proxy/tests/middle_relay_security_tests.rs +++ b/src/proxy/tests/middle_relay_security_tests.rs @@ -953,24 +953,6 @@ fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() { panic!("expected at least one post-window sample to re-emit forensic record"); } -#[test] -#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"] -fn should_emit_full_desync_filters_duplicates() { - unimplemented!("Stub for M-04"); -} - -#[test] -#[ignore = "Tracking for M-04: Verify desync dedup eviction behaves correctly under map-full condition"] -fn desync_dedup_eviction_under_map_full_condition() { - unimplemented!("Stub for M-04"); -} - -#[tokio::test] -#[ignore = "Tracking for M-05: Verify C2ME channel full path yields then sends under backpressure"] -async fn c2me_channel_full_path_yields_then_sends() { - unimplemented!("Stub for M-05"); -} - fn make_forensics_state() -> RelayForensicsState { RelayForensicsState { trace_id: 1, diff --git a/src/proxy/tests/middle_relay_stub_completion_security_tests.rs b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs new file mode 100644 index 0000000..2635a28 --- /dev/null +++ b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs @@ -0,0 +1,168 @@ +use super::*; +use crate::stream::BufferPool; +use std::collections::HashSet; +use std::sync::Arc; +use tokio::time::{Duration as TokioDuration, timeout}; + +fn make_pooled_payload(data: &[u8]) -> PooledBuffer { + let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + +#[test] +#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"] +fn should_emit_full_desync_filters_duplicates() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let key = 0x4D04_0000_0000_0001_u64; + let base = Instant::now(); + + assert!( + should_emit_full_desync(key, false, base), + "first occurrence must emit full forensic record" + ); + assert!( + !should_emit_full_desync(key, false, base), + "duplicate at same timestamp must be suppressed" + ); + + let within_window = base + DESYNC_DEDUP_WINDOW - TokioDuration::from_millis(1); + assert!( + !should_emit_full_desync(key, false, within_window), + "duplicate strictly inside dedup window must stay suppressed" + ); + + let on_window_edge = base + DESYNC_DEDUP_WINDOW; + assert!( + should_emit_full_desync(key, false, on_window_edge), + "duplicate at window boundary must re-emit and refresh" + ); +} + +#[test] +#[ignore = "Tracking for M-04: Verify desync dedup eviction behaves correctly under map-full condition"] +fn desync_dedup_eviction_under_map_full_condition() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let base = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + assert!( + should_emit_full_desync(key, false, base), + "unique key should be inserted while warming dedup cache" + ); + } + + let dedup = DESYNC_DEDUP + .get() + .expect("dedup map must exist after warm-up insertions"); + assert_eq!( + dedup.len(), + DESYNC_DEDUP_MAX_ENTRIES, + "cache warm-up must reach exact hard cap" + ); + + let before_keys: HashSet = dedup.iter().map(|entry| *entry.key()).collect(); + let newcomer_key = 0x4D04_FFFF_FFFF_0001_u64; + + assert!( + should_emit_full_desync(newcomer_key, false, base), + "first newcomer at map-full must emit under bounded full-cache gate" + ); + + let after_keys: HashSet = dedup.iter().map(|entry| *entry.key()).collect(); + assert_eq!( + dedup.len(), + DESYNC_DEDUP_MAX_ENTRIES, + "map-full insertion must preserve hard capacity bound" + ); + assert!( + after_keys.contains(&newcomer_key), + "newcomer must be present after bounded eviction path" + ); + + let removed_count = before_keys.difference(&after_keys).count(); + let added_count = after_keys.difference(&before_keys).count(); + assert_eq!( + removed_count, 1, + "map-full insertion must evict exactly one prior key" + ); + assert_eq!( + added_count, 1, + "map-full insertion must add exactly one newcomer key" + ); + + assert!( + !should_emit_full_desync(newcomer_key, false, base), + "immediate duplicate newcomer must remain suppressed" + ); +} + +#[tokio::test] +#[ignore = "Tracking for M-05: Verify C2ME channel full path yields then sends under backpressure"] +async fn c2me_channel_full_path_yields_then_sends() { + let (tx, mut rx) = mpsc::channel::(1); + + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[0xAA]), + flags: 1, + }) + .await + .expect("priming queue with one frame must succeed"); + + let tx2 = tx.clone(); + let producer = tokio::spawn(async move { + enqueue_c2me_command( + &tx2, + C2MeCommand::Data { + payload: make_pooled_payload(&[0xBB, 0xCC]), + flags: 2, + }, + ) + .await + }); + + tokio::task::yield_now().await; + tokio::time::sleep(TokioDuration::from_millis(10)).await; + assert!( + !producer.is_finished(), + "producer should stay pending while queue is full" + ); + + let first = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .expect("receiver should observe primed frame") + .expect("first queued command must exist"); + match first { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[0xAA]); + assert_eq!(flags, 1); + } + C2MeCommand::Close => panic!("unexpected close command as first item"), + } + + producer + .await + .expect("producer task must not panic") + .expect("blocked enqueue must succeed once receiver drains capacity"); + + let second = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .expect("receiver should observe backpressure-resumed frame") + .expect("second queued command must exist"); + match second { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[0xBB, 0xCC]); + assert_eq!(flags, 2); + } + C2MeCommand::Close => panic!("unexpected close command as second item"), + } +} diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index d7bb153..2542e37 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -652,7 +652,7 @@ mod tests { let mut out = BytesMut::new(); codec.encode(&frame, &mut out).unwrap(); - assert!(out.len() >= 4 + payload.len() + 1); + assert!(out.len() > 4 + payload.len()); let wire_len = u32::from_le_bytes([out[0], out[1], out[2], out[3]]) as usize; assert!( (payload.len() + 1..=payload.len() + 3).contains(&wire_len), diff --git a/src/stream/frame_stream.rs b/src/stream/frame_stream.rs index 5ee66f0..e9f1d3e 100644 --- a/src/stream/frame_stream.rs +++ b/src/stream/frame_stream.rs @@ -584,7 +584,7 @@ mod tests { // Long frame (> 0x7f words = 508 bytes) let data: Vec = (0..1000).map(|i| (i % 256) as u8).collect(); - let padded_len = (data.len() + 3) / 4 * 4; + let padded_len = data.len().div_ceil(4) * 4; let mut padded = data.clone(); padded.resize(padded_len, 0); diff --git a/src/stream/frame_stream_padding_security_tests.rs b/src/stream/frame_stream_padding_security_tests.rs new file mode 100644 index 0000000..83b30f9 --- /dev/null +++ b/src/stream/frame_stream_padding_security_tests.rs @@ -0,0 +1,56 @@ +fn old_padding_round_up_to_4(len: usize) -> Option { + len.checked_add(3) + .map(|sum| sum / 4) + .and_then(|words| words.checked_mul(4)) +} + +fn new_padding_round_up_to_4(len: usize) -> Option { + len.div_ceil(4).checked_mul(4) +} + +#[test] +fn padding_rounding_equivalent_for_extensive_safe_domain() { + for len in 0usize..=200_000usize { + let old = old_padding_round_up_to_4(len).expect("old expression must be safe"); + let new = new_padding_round_up_to_4(len).expect("new expression must be safe"); + assert_eq!(old, new, "mismatch for len={len}"); + assert!(new >= len, "rounded length must not shrink: len={len}, out={new}"); + assert_eq!(new % 4, 0, "rounded length must stay 4-byte aligned"); + } +} + +#[test] +fn padding_rounding_equivalent_near_usize_limit_when_old_is_defined() { + let candidates = [ + usize::MAX - 3, + usize::MAX - 4, + usize::MAX - 5, + usize::MAX - 6, + usize::MAX - 7, + usize::MAX - 8, + usize::MAX - 15, + usize::MAX / 2, + (usize::MAX / 2) + 1, + ]; + + for len in candidates { + let old = old_padding_round_up_to_4(len); + let new = new_padding_round_up_to_4(len); + if let Some(old_val) = old { + assert_eq!(Some(old_val), new, "safe-domain mismatch for len={len}"); + } + } +} + +#[test] +fn padding_rounding_documents_overflow_boundary_behavior() { + // For very large lengths, arithmetic round-up may overflow regardless of spelling. + // This documents the boundary so future changes do not assume universal safety. + assert_eq!(old_padding_round_up_to_4(usize::MAX), None); + assert_eq!(old_padding_round_up_to_4(usize::MAX - 1), None); + assert_eq!(old_padding_round_up_to_4(usize::MAX - 2), None); + + // The div_ceil form avoids `len + 3` overflow, but final `* 4` can still overflow. + assert_eq!(new_padding_round_up_to_4(usize::MAX), None); + assert_eq!(new_padding_round_up_to_4(usize::MAX - 1), None); +} diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 8925fa0..13cd806 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -8,6 +8,9 @@ pub mod state; pub mod tls_stream; pub mod traits; +#[cfg(test)] +mod frame_stream_padding_security_tests; + // Legacy compatibility - will be removed later pub mod frame_stream; diff --git a/src/stream/tls_stream.rs b/src/stream/tls_stream.rs index 7a15365..3f100d1 100644 --- a/src/stream/tls_stream.rs +++ b/src/stream/tls_stream.rs @@ -154,7 +154,7 @@ impl TlsRecordHeader { } TLS_RECORD_HANDSHAKE => { - if len < 4 || len > MAX_TLS_PLAINTEXT_SIZE { + if !(4..=MAX_TLS_PLAINTEXT_SIZE).contains(&len) { return Err(Error::new( ErrorKind::InvalidData, format!( diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs index 972d1ca..80f2b1b 100644 --- a/src/tls_front/emulator.rs +++ b/src/tls_front/emulator.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use crate::crypto::{SecureRandom, sha256_hmac}; use crate::protocol::constants::{ MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 824e155..4408b5a 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use std::sync::Arc; use std::time::Duration; @@ -810,7 +812,8 @@ mod tests { #[test] fn test_encode_tls13_certificate_message_single_cert() { let cert = vec![0x30, 0x03, 0x02, 0x01, 0x01]; - let message = encode_tls13_certificate_message(&[cert.clone()]).expect("message"); + let message = encode_tls13_certificate_message(std::slice::from_ref(&cert)) + .expect("message"); assert_eq!(message[0], 0x0b); assert_eq!(read_u24(&message[1..4]), message.len() - 4); diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index 4479046..8e5a701 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -355,49 +355,49 @@ async fn run_update_cycle( let mut ready_v4: Option<(ProxyConfigData, u64)> = None; let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig").await; - if let Some(cfg_v4) = cfg_v4 { - if snapshot_passes_guards(cfg, &cfg_v4, "getProxyConfig") { - let cfg_v4_hash = hash_proxy_config(&cfg_v4); - let stable_hits = state.config_v4.observe(cfg_v4_hash); - if stable_hits < required_cfg_snapshots { - debug!( - stable_hits, - required_cfg_snapshots, - snapshot = format_args!("0x{cfg_v4_hash:016x}"), - "ME config v4 candidate observed" - ); - } else if state.config_v4.is_applied(cfg_v4_hash) { - debug!( - snapshot = format_args!("0x{cfg_v4_hash:016x}"), - "ME config v4 stable snapshot already applied" - ); - } else { - ready_v4 = Some((cfg_v4, cfg_v4_hash)); - } + if let Some(cfg_v4) = cfg_v4 + && snapshot_passes_guards(cfg, &cfg_v4, "getProxyConfig") + { + let cfg_v4_hash = hash_proxy_config(&cfg_v4); + let stable_hits = state.config_v4.observe(cfg_v4_hash); + if stable_hits < required_cfg_snapshots { + debug!( + stable_hits, + required_cfg_snapshots, + snapshot = format_args!("0x{cfg_v4_hash:016x}"), + "ME config v4 candidate observed" + ); + } else if state.config_v4.is_applied(cfg_v4_hash) { + debug!( + snapshot = format_args!("0x{cfg_v4_hash:016x}"), + "ME config v4 stable snapshot already applied" + ); + } else { + ready_v4 = Some((cfg_v4, cfg_v4_hash)); } } let mut ready_v6: Option<(ProxyConfigData, u64)> = None; let cfg_v6 = retry_fetch("https://core.telegram.org/getProxyConfigV6").await; - if let Some(cfg_v6) = cfg_v6 { - if snapshot_passes_guards(cfg, &cfg_v6, "getProxyConfigV6") { - let cfg_v6_hash = hash_proxy_config(&cfg_v6); - let stable_hits = state.config_v6.observe(cfg_v6_hash); - if stable_hits < required_cfg_snapshots { - debug!( - stable_hits, - required_cfg_snapshots, - snapshot = format_args!("0x{cfg_v6_hash:016x}"), - "ME config v6 candidate observed" - ); - } else if state.config_v6.is_applied(cfg_v6_hash) { - debug!( - snapshot = format_args!("0x{cfg_v6_hash:016x}"), - "ME config v6 stable snapshot already applied" - ); - } else { - ready_v6 = Some((cfg_v6, cfg_v6_hash)); - } + if let Some(cfg_v6) = cfg_v6 + && snapshot_passes_guards(cfg, &cfg_v6, "getProxyConfigV6") + { + let cfg_v6_hash = hash_proxy_config(&cfg_v6); + let stable_hits = state.config_v6.observe(cfg_v6_hash); + if stable_hits < required_cfg_snapshots { + debug!( + stable_hits, + required_cfg_snapshots, + snapshot = format_args!("0x{cfg_v6_hash:016x}"), + "ME config v6 candidate observed" + ); + } else if state.config_v6.is_applied(cfg_v6_hash) { + debug!( + snapshot = format_args!("0x{cfg_v6_hash:016x}"), + "ME config v6 stable snapshot already applied" + ); + } else { + ready_v6 = Some((cfg_v6, cfg_v6_hash)); } } diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index bb47e43..3e53f38 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use std::collections::HashMap; use std::collections::HashSet; use std::net::SocketAddr; diff --git a/src/transport/middle_proxy/ping.rs b/src/transport/middle_proxy/ping.rs index 4432282..bff088b 100644 --- a/src/transport/middle_proxy/ping.rs +++ b/src/transport/middle_proxy/ping.rs @@ -1,3 +1,5 @@ +#![allow(clippy::items_after_test_module)] + use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; @@ -12,6 +14,8 @@ use crate::transport::{UpstreamEgressInfo, UpstreamRouteKind}; use super::MePool; +type MePingGroup = (MePingFamily, i32, Vec<(IpAddr, u16)>); + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MePingFamily { V4, @@ -137,14 +141,14 @@ fn detect_interface_for_ip(ip: IpAddr) -> Option { if let Ok(addrs) = getifaddrs() { for iface in addrs { if let Some(address) = iface.address { - if let Some(v4) = address.as_sockaddr_in() { - if IpAddr::V4(v4.ip()) == ip { - return Some(iface.interface_name); - } - } else if let Some(v6) = address.as_sockaddr_in6() { - if IpAddr::V6(v6.ip()) == ip { - return Some(iface.interface_name); - } + if let Some(v4) = address.as_sockaddr_in() + && IpAddr::V4(v4.ip()) == ip + { + return Some(iface.interface_name); + } else if let Some(v6) = address.as_sockaddr_in6() + && IpAddr::V6(v6.ip()) == ip + { + return Some(iface.interface_name); } } } @@ -329,7 +333,7 @@ pub async fn run_me_ping(pool: &Arc, rng: &SecureRandom) -> Vec)> = Vec::new(); + let mut grouped: Vec = Vec::new(); for (dc, addrs) in v4_map { grouped.push((MePingFamily::V4, dc, addrs)); } diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index f6c17a3..71ab257 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments, clippy::type_complexity)] + use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::sync::Arc; @@ -619,6 +621,7 @@ impl MePool { self.runtime_ready.load(Ordering::Relaxed) } + #[allow(dead_code)] pub(super) fn set_family_runtime_state( &self, family: IpFamily, @@ -982,28 +985,33 @@ impl MePool { Some(Duration::from_secs(secs)) } + #[allow(dead_code)] pub(super) fn drain_soft_evict_enabled(&self) -> bool { self.me_pool_drain_soft_evict_enabled .load(Ordering::Relaxed) } + #[allow(dead_code)] pub(super) fn drain_soft_evict_grace_secs(&self) -> u64 { self.me_pool_drain_soft_evict_grace_secs .load(Ordering::Relaxed) } + #[allow(dead_code)] pub(super) fn drain_soft_evict_per_writer(&self) -> usize { self.me_pool_drain_soft_evict_per_writer .load(Ordering::Relaxed) .max(1) as usize } + #[allow(dead_code)] pub(super) fn drain_soft_evict_budget_per_core(&self) -> usize { self.me_pool_drain_soft_evict_budget_per_core .load(Ordering::Relaxed) .max(1) as usize } + #[allow(dead_code)] pub(super) fn drain_soft_evict_cooldown(&self) -> Duration { Duration::from_millis( self.me_pool_drain_soft_evict_cooldown_ms diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs index d2b234e..918ccd4 100644 --- a/src/transport/middle_proxy/pool_status.rs +++ b/src/transport/middle_proxy/pool_status.rs @@ -293,20 +293,20 @@ impl MePool { WriterContour::Draining => "draining", }; - if !draining { - if let Some(dc_idx) = dc { - *live_writers_by_dc_endpoint - .entry((dc_idx, endpoint)) - .or_insert(0) += 1; - *live_writers_by_dc.entry(dc_idx).or_insert(0) += 1; - if let Some(ema_ms) = rtt_ema_ms { - let entry = dc_rtt_agg.entry(dc_idx).or_insert((0.0, 0)); - entry.0 += ema_ms; - entry.1 += 1; - } - if matches_active_generation && in_desired_map { - *fresh_writers_by_dc.entry(dc_idx).or_insert(0) += 1; - } + if !draining + && let Some(dc_idx) = dc + { + *live_writers_by_dc_endpoint + .entry((dc_idx, endpoint)) + .or_insert(0) += 1; + *live_writers_by_dc.entry(dc_idx).or_insert(0) += 1; + if let Some(ema_ms) = rtt_ema_ms { + let entry = dc_rtt_agg.entry(dc_idx).or_insert((0.0, 0)); + entry.0 += ema_ms; + entry.1 += 1; + } + if matches_active_generation && in_desired_map { + *fresh_writers_by_dc.entry(dc_idx).or_insert(0) += 1; } } diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index bba336b..22fb909 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -268,10 +268,10 @@ impl MePool { cancel_reader_token.cancel(); } } - if let Err(e) = res { - if !idle_close_by_peer { - warn!(error = %e, "ME reader ended"); - } + if let Err(e) = res + && !idle_close_by_peer + { + warn!(error = %e, "ME reader ended"); } let remaining = writers_arc.read().await.len(); debug!(writer_id, remaining, "ME reader task finished"); @@ -386,10 +386,9 @@ impl MePool { if cleanup_for_ping .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) .is_ok() + && let Some(pool) = pool_ping.upgrade() { - if let Some(pool) = pool_ping.upgrade() { - pool.remove_writer_and_close_clients(writer_id).await; - } + pool.remove_writer_and_close_clients(writer_id).await; } break; } @@ -538,6 +537,7 @@ impl MePool { .await } + #[allow(dead_code)] async fn remove_writer_only(self: &Arc, writer_id: u64) -> bool { self.remove_writer_with_mode(writer_id, WriterTeardownMode::Any) .await diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 44d7464..4137b2b 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use std::collections::HashMap; use std::io::ErrorKind; use std::sync::Arc; diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 4226081..0a95e18 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -165,6 +165,7 @@ impl ConnRegistry { None } + #[allow(dead_code)] pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { let tx = { let inner = self.inner.read().await; @@ -438,6 +439,7 @@ impl ConnRegistry { .unwrap_or(true) } + #[allow(dead_code)] pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool { let mut inner = self.inner.write().await; let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else { diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index e73defe..b1cf54e 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + use std::cmp::Reverse; use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; @@ -593,7 +595,7 @@ impl MePool { let round = *hybrid_recovery_round; let target_triggered = self.trigger_async_recovery_for_target_dc(routed_dc).await; - if !target_triggered || round % HYBRID_GLOBAL_BURST_PERIOD_ROUNDS == 0 { + if !target_triggered || round.is_multiple_of(HYBRID_GLOBAL_BURST_PERIOD_ROUNDS) { self.trigger_async_recovery_global().await; } *hybrid_recovery_round = round.saturating_add(1); @@ -672,7 +674,7 @@ impl MePool { if !self.writer_eligible_for_selection(w, include_warm) { continue; } - if w.writer_dc == routed_dc && preferred.iter().any(|endpoint| *endpoint == w.addr) { + if w.writer_dc == routed_dc && preferred.contains(&w.addr) { out.push(idx); } } diff --git a/src/transport/pool.rs b/src/transport/pool.rs index a20ba04..60f8a01 100644 --- a/src/transport/pool.rs +++ b/src/transport/pool.rs @@ -199,8 +199,12 @@ impl ConnectionPool { /// Close all pooled connections pub async fn close_all(&self) { - let pools = self.pools.read(); - for (addr, pool) in pools.iter() { + let pools_snapshot: Vec<(SocketAddr, Arc>)> = { + let pools = self.pools.read(); + pools.iter().map(|(addr, pool)| (*addr, Arc::clone(pool))).collect() + }; + + for (addr, pool) in pools_snapshot { let mut inner = pool.lock().await; let count = inner.connections.len(); inner.connections.clear(); @@ -210,12 +214,15 @@ impl ConnectionPool { /// Get pool statistics pub async fn stats(&self) -> PoolStats { - let pools = self.pools.read(); + let pools_snapshot: Vec>> = { + let pools = self.pools.read(); + pools.values().cloned().collect() + }; let mut total_connections = 0; let mut total_pending = 0; let mut endpoints = 0; - for pool in pools.values() { + for pool in pools_snapshot { let inner = pool.lock().await; total_connections += inner.connections.len(); total_pending += inner.pending;