diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs index 5d7d9b6..15a864f 100644 --- a/src/ip_tracker.rs +++ b/src/ip_tracker.rs @@ -392,12 +392,18 @@ impl UserIpTracker { let now = Instant::now(); let (mut active_ips, mut recent_ips) = self.active_and_recent_write().await; - let user_active = active_ips - .entry(username.to_string()) - .or_insert_with(HashMap::new); - let user_recent = recent_ips - .entry(username.to_string()) - .or_insert_with(HashMap::new); + if !active_ips.contains_key(username) { + active_ips.insert(username.to_string(), HashMap::new()); + } + if !recent_ips.contains_key(username) { + recent_ips.insert(username.to_string(), HashMap::new()); + } + let Some(user_active) = active_ips.get_mut(username) else { + return Err(format!("IP tracker active entry unavailable for user '{username}'")); + }; + let Some(user_recent) = recent_ips.get_mut(username) else { + return Err(format!("IP tracker recent entry unavailable for user '{username}'")); + }; let pruned_recent_entries = Self::prune_recent(user_recent, now, window); Self::decrement_counter(&self.recent_entry_count, pruned_recent_entries); let recent_contains_ip = user_recent.contains_key(&ip); diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 3cd8799..85d737f 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -515,8 +515,13 @@ fn exclusive_mask_target_for_sni<'a>( config: &'a ProxyConfig, sni: &str, ) -> Option> { - for (domain, target) in &config.censorship.exclusive_mask { - if domain.eq_ignore_ascii_case(sni) { + if let Some(target) = config.censorship.exclusive_mask.get(sni) { + return parse_exclusive_mask_target(target); + } + + if sni.bytes().any(|byte| byte.is_ascii_uppercase()) { + let normalized_sni = sni.to_ascii_lowercase(); + if let Some(target) = config.censorship.exclusive_mask.get(&normalized_sni) { return parse_exclusive_mask_target(target); } } @@ -529,17 +534,27 @@ fn mask_host_for_initial_data<'a>(config: &'a ProxyConfig, initial_data: &[u8]) mask_tcp_target_for_initial_data(config, initial_data).host } +#[cfg(test)] fn mask_tcp_target_for_initial_data<'a>( config: &'a ProxyConfig, initial_data: &[u8], ) -> MaskTcpTarget<'a> { - if let Some(target) = tls::extract_sni_from_client_hello(initial_data) + let sni = tls::extract_sni_from_client_hello(initial_data); + if let Some(target) = sni .as_deref() .and_then(|sni| exclusive_mask_target_for_sni(config, sni)) { return target; } + default_mask_tcp_target_for_initial_data(config, initial_data, sni.as_deref()) +} + +fn default_mask_tcp_target_for_initial_data<'a>( + config: &'a ProxyConfig, + initial_data: &[u8], + sni: Option<&str>, +) -> MaskTcpTarget<'a> { let configured_mask_host = config .censorship .mask_host @@ -553,8 +568,13 @@ fn mask_tcp_target_for_initial_data<'a>( }; } - let host = tls::extract_sni_from_client_hello(initial_data) - .as_deref() + let extracted_sni = if sni.is_none() { + tls::extract_sni_from_client_hello(initial_data) + } else { + None + }; + let host = sni + .or(extracted_sni.as_deref()) .and_then(|sni| matching_tls_domain_for_sni(config, sni)) .unwrap_or(configured_mask_host); MaskTcpTarget { @@ -858,7 +878,8 @@ pub async fn handle_bad_client( return; } - let exclusive_tcp_target = tls::extract_sni_from_client_hello(initial_data) + let client_sni = tls::extract_sni_from_client_hello(initial_data); + let exclusive_tcp_target = client_sni .as_deref() .and_then(|sni| exclusive_mask_target_for_sni(config, sni)); @@ -943,8 +964,9 @@ pub async fn handle_bad_client( return; } - let mask_target = exclusive_tcp_target - .unwrap_or_else(|| mask_tcp_target_for_initial_data(config, initial_data)); + let mask_target = exclusive_tcp_target.unwrap_or_else(|| { + default_mask_tcp_target_for_initial_data(config, initial_data, client_sni.as_deref()) + }); let mask_host = mask_target.host; let mask_port = mask_target.port; diff --git a/src/stats/mod.rs b/src/stats/mod.rs index e13cf63..76d464d 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -2678,9 +2678,10 @@ struct ReplayEntry { } struct ReplayShard { - cache: LruCache, ReplayEntry>, - queue: VecDeque<(Instant, Box<[u8]>, u64)>, + cache: LruCache, ReplayEntry>, + queue: VecDeque<(Instant, Arc<[u8]>, u64)>, seq_counter: u64, + capacity: usize, } impl ReplayShard { @@ -2689,6 +2690,7 @@ impl ReplayShard { cache: LruCache::new(cap), queue: VecDeque::with_capacity(cap.get()), seq_counter: 0, + capacity: cap.get(), } } @@ -2709,15 +2711,19 @@ impl ReplayShard { if *ts >= cutoff { break; } - let (_, key, queue_seq) = self.queue.pop_front().unwrap(); + self.evict_queue_front(); + } + } - // Use key.as_ref() to get &[u8] — avoids Borrow ambiguity - // between Borrow<[u8]> and Borrow> - if let Some(entry) = self.cache.peek(key.as_ref()) - && entry.seq == queue_seq - { - self.cache.pop(key.as_ref()); - } + fn evict_queue_front(&mut self) { + let Some((_, key, queue_seq)) = self.queue.pop_front() else { + return; + }; + + if let Some(entry) = self.cache.peek(key.as_ref()) + && entry.seq == queue_seq + { + self.cache.pop(key.as_ref()); } } @@ -2738,13 +2744,16 @@ impl ReplayShard { if self.cache.peek(key).is_some() { return; } + while self.queue.len() >= self.capacity { + self.evict_queue_front(); + } let seq = self.next_seq(); - let boxed_key: Box<[u8]> = key.into(); + let shared_key: Arc<[u8]> = Arc::from(key); self.cache - .put(boxed_key.clone(), ReplayEntry { seen_at: now, seq }); - self.queue.push_back((now, boxed_key, seq)); + .put(Arc::clone(&shared_key), ReplayEntry { seen_at: now, seq }); + self.queue.push_back((now, shared_key, seq)); } fn len(&self) -> usize { diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index c341d50..d820854 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -9,6 +9,7 @@ use crate::protocol::constants::*; pub(crate) enum WriterCommand { Data(Bytes), DataAndFlush(Bytes), + ControlAndFlush([u8; 12]), Close, } @@ -42,6 +43,13 @@ pub(crate) fn rpc_crc(mode: RpcChecksumMode, data: &[u8]) -> u32 { } } +pub(crate) fn build_control_payload(tag: u32, value: u64) -> [u8; 12] { + let mut payload = [0u8; 12]; + payload[..4].copy_from_slice(&tag.to_le_bytes()); + payload[4..].copy_from_slice(&value.to_le_bytes()); + payload +} + pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8], crc_mode: RpcChecksumMode) -> Vec { let total_len = (4 + 4 + payload.len() + 4) as u32; let mut frame = Vec::with_capacity(total_len as usize); diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 0644e8d..218be30 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -5,7 +5,6 @@ use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; use std::time::{Duration, Instant}; -use bytes::Bytes; use bytes::BytesMut; use rand::RngExt; use tokio::sync::mpsc; @@ -17,7 +16,7 @@ use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::protocol::constants::{RPC_CLOSE_EXT_U32, RPC_PING_U32}; -use super::codec::{RpcWriter, WriterCommand}; +use super::codec::{RpcWriter, WriterCommand, build_control_payload}; use super::pool::{MePool, MeWriter, WriterContour}; use super::reader::reader_loop; use super::wire::build_proxy_req_payload; @@ -61,6 +60,9 @@ async fn writer_command_loop( Some(WriterCommand::DataAndFlush(payload)) => { rpc_writer.send_and_flush(&payload).await?; } + Some(WriterCommand::ControlAndFlush(payload)) => { + rpc_writer.send_and_flush(&payload).await?; + } Some(WriterCommand::Close) | None => return Ok(()), } } @@ -130,9 +132,7 @@ async fn ping_loop( _ = tokio::time::sleep(wait) => {} } let sent_id = ping_id; - let mut p = Vec::with_capacity(12); - p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); - p.extend_from_slice(&sent_id.to_le_bytes()); + let payload = build_control_payload(RPC_PING_U32, sent_id as u64); { let mut tracker = ping_tracker_ping.lock().await; cleanup_tick = cleanup_tick.wrapping_add(1); @@ -149,7 +149,7 @@ async fn ping_loop( ping_id = ping_id.wrapping_add(1); stats_ping.increment_me_keepalive_sent(); if tx_ping - .send(WriterCommand::DataAndFlush(Bytes::from(p))) + .send(WriterCommand::ControlAndFlush(payload)) .await .is_err() { @@ -253,12 +253,10 @@ async fn rpc_proxy_req_signal_loop( stats_signal.increment_me_rpc_proxy_req_signal_response_total(); } - let mut close_payload = Vec::with_capacity(12); - close_payload.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); - close_payload.extend_from_slice(&conn_id.to_le_bytes()); + let close_payload = build_control_payload(RPC_CLOSE_EXT_U32, conn_id); if tx_signal - .send(WriterCommand::DataAndFlush(Bytes::from(close_payload))) + .send(WriterCommand::ControlAndFlush(close_payload)) .await .is_err() { diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 2dae1f1..66bf83d 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -19,7 +19,7 @@ use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; use crate::stats::Stats; -use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc}; +use super::codec::{RpcChecksumMode, WriterCommand, build_control_payload, rpc_crc}; use super::fairness::{ AdmissionDecision, DispatchAction, DispatchFeedback, PressureState, SchedulerDecision, WorkerFairnessConfig, WorkerFairnessSnapshot, WorkerFairnessState, @@ -464,10 +464,8 @@ pub(crate) async fn reader_loop( } else if pt == RPC_PING_U32 && body.len() >= 8 { let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); trace!(ping_id, "RPC_PING -> RPC_PONG"); - let mut pong = Vec::with_capacity(12); - pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes()); - pong.extend_from_slice(&ping_id.to_le_bytes()); - match tx.try_send(WriterCommand::DataAndFlush(Bytes::from(pong))) { + let pong = build_control_payload(RPC_PONG_U32, ping_id as u64); + match tx.try_send(WriterCommand::ControlAndFlush(pong)) { Ok(()) => {} Err(TrySendError::Full(_)) => { debug!(ping_id, "PONG dropped: writer command channel is full"); @@ -667,10 +665,8 @@ mod tests { } async fn send_close_conn(tx: &mpsc::Sender, conn_id: u64) { - let mut p = Vec::with_capacity(12); - p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); - p.extend_from_slice(&conn_id.to_le_bytes()); - match tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) { + let payload = build_control_payload(RPC_CLOSE_CONN_U32, conn_id); + match tx.try_send(WriterCommand::ControlAndFlush(payload)) { Ok(()) => {} Err(TrySendError::Full(_)) => { debug!( diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 9a5c828..a637948 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -7,7 +7,6 @@ use std::sync::Arc; use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; -use bytes::Bytes; use tokio::sync::mpsc::error::TrySendError; use tracing::{debug, warn}; @@ -17,7 +16,7 @@ use crate::network::IpFamily; use crate::protocol::constants::{RPC_CLOSE_CONN_U32, RPC_CLOSE_EXT_U32}; use super::MePool; -use super::codec::WriterCommand; +use super::codec::{WriterCommand, build_control_payload}; use super::pool::WriterContour; use super::registry::ConnMeta; use super::wire::build_proxy_req_payload; @@ -735,11 +734,9 @@ impl MePool { pub async fn send_close(self: &Arc, conn_id: u64) -> Result<()> { if let Some(w) = self.registry.get_writer(conn_id).await { - let mut p = Vec::with_capacity(12); - p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); - p.extend_from_slice(&conn_id.to_le_bytes()); + let payload = build_control_payload(RPC_CLOSE_EXT_U32, conn_id); if w.tx - .send(WriterCommand::DataAndFlush(Bytes::from(p))) + .send(WriterCommand::ControlAndFlush(payload)) .await .is_err() { @@ -756,10 +753,8 @@ impl MePool { pub async fn send_close_conn(self: &Arc, conn_id: u64) -> Result<()> { if let Some(w) = self.registry.get_writer(conn_id).await { - let mut p = Vec::with_capacity(12); - p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); - p.extend_from_slice(&conn_id.to_le_bytes()); - match w.tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) { + let payload = build_control_payload(RPC_CLOSE_CONN_U32, conn_id); + match w.tx.try_send(WriterCommand::ControlAndFlush(payload)) { Ok(()) => {} Err(TrySendError::Full(cmd)) => { let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await; diff --git a/src/transport/middle_proxy/tests/send_adversarial_tests.rs b/src/transport/middle_proxy/tests/send_adversarial_tests.rs index 6e556d4..074fa33 100644 --- a/src/transport/middle_proxy/tests/send_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/send_adversarial_tests.rs @@ -165,6 +165,7 @@ async fn recv_data_count(rx: &mut mpsc::Receiver, budget: Duratio match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await { Ok(Some(WriterCommand::Data(_))) => data_count += 1, Ok(Some(WriterCommand::DataAndFlush(_))) => data_count += 1, + Ok(Some(WriterCommand::ControlAndFlush(_))) => data_count += 1, Ok(Some(WriterCommand::Close)) => {} Ok(None) => break, Err(_) => break,