feat(proxy): enhance logging and deduplication for unknown datacenters

- Implemented a mechanism to log unknown datacenter indices with a distinct limit to avoid excessive logging.
- Introduced tests to ensure that logging is deduplicated per datacenter index and respects the distinct limit.
- Updated the fallback logic for datacenter resolution to prevent panics when only a single datacenter is available.

feat(proxy): add authentication probe throttling

- Added a pre-authentication probe throttling mechanism to limit the rate of invalid TLS and MTProto handshake attempts.
- Introduced a backoff strategy for repeated failures and ensured that successful handshakes reset the failure count.
- Implemented tests to validate the behavior of the authentication probe under various conditions.

fix(proxy): ensure proper flushing of masked writes

- Added a flush operation after writing initial data to the mask writer to ensure data integrity.

refactor(proxy): optimize desynchronization deduplication

- Replaced the Mutex-based deduplication structure with a DashMap for improved concurrency and performance.
- Implemented a bounded cache for deduplication to limit memory usage and prevent stale entries from persisting.

test(proxy): enhance security tests for middle relay and handshake

- Added comprehensive tests for the middle relay and handshake processes, including scenarios for deduplication and authentication probe behavior.
- Ensured that the tests cover edge cases and validate the expected behavior of the system under load.
This commit is contained in:
David Osipov
2026-03-17 01:29:30 +04:00
parent e4a50f9286
commit 205fc88718
15 changed files with 1124 additions and 150 deletions

View File

@@ -1,12 +1,14 @@
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant};
#[cfg(test)]
use std::sync::Mutex;
use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, watch};
use tracing::{debug, trace, warn};
@@ -30,13 +32,15 @@ enum C2MeCommand {
}
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
static DESYNC_DEDUP: OnceLock<Mutex<HashMap<u64, Instant>>> = OnceLock::new();
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
struct RelayForensicsState {
trace_id: u64,
@@ -90,24 +94,46 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
return true;
}
let dedup = DESYNC_DEDUP.get_or_init(|| Mutex::new(HashMap::new()));
let mut guard = dedup.lock().expect("desync dedup mutex poisoned");
guard.retain(|_, seen_at| now.duration_since(*seen_at) < DESYNC_DEDUP_WINDOW);
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
match guard.get_mut(&key) {
Some(seen_at) => {
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
*seen_at = now;
true
} else {
false
if let Some(mut seen_at) = dedup.get_mut(&key) {
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
*seen_at = now;
return true;
}
return false;
}
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
let mut stale_keys = Vec::new();
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
stale_keys.push(*entry.key());
}
}
None => {
guard.insert(key, now);
true
for stale_key in stale_keys {
dedup.remove(&stale_key);
}
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
return false;
}
}
dedup.insert(key, now);
true
}
#[cfg(test)]
fn clear_desync_dedup_for_testing() {
if let Some(dedup) = DESYNC_DEDUP.get() {
dedup.clear();
}
}
#[cfg(test)]
fn desync_dedup_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
fn report_desync_frame_too_large(
@@ -229,7 +255,7 @@ pub(crate) async fn handle_via_middle_proxy<R, W>(
me_pool: Arc<MePool>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
_buffer_pool: Arc<BufferPool>,
buffer_pool: Arc<BufferPool>,
local_addr: SocketAddr,
rng: Arc<SecureRandom>,
mut route_rx: watch::Receiver<RouteCutoverState>,
@@ -271,7 +297,6 @@ where
};
stats.increment_user_connects(&user);
stats.increment_user_curr_connects(&user);
stats.increment_current_connections_me();
if let Some(cutover) = affected_cutover_state(
@@ -291,7 +316,6 @@ where
let _ = me_pool.send_close(conn_id).await;
me_pool.registry().unregister(conn_id).await;
stats.decrement_current_connections_me();
stats.decrement_user_curr_connects(&user);
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
}
@@ -557,6 +581,7 @@ where
&mut crypto_reader,
proto_tag,
frame_limit,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
@@ -638,7 +663,6 @@ where
);
me_pool.registry().unregister(conn_id).await;
stats.decrement_current_connections_me();
stats.decrement_user_curr_connects(&user);
result
}
@@ -646,6 +670,7 @@ async fn read_client_payload<R>(
client_reader: &mut CryptoReader<R>,
proto_tag: ProtoTag,
max_frame: usize,
buffer_pool: &Arc<BufferPool>,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
stats: &Stats,
@@ -737,18 +762,27 @@ where
len
};
let mut payload = vec![0u8; len];
client_reader
.read_exact(&mut payload)
.await
.map_err(ProxyError::Io)?;
let chunk_cap = buffer_pool.buffer_size().max(1024);
let mut payload = BytesMut::with_capacity(len.min(chunk_cap));
let mut remaining = len;
while remaining > 0 {
let chunk_len = remaining.min(chunk_cap);
let mut chunk = buffer_pool.get();
chunk.resize(chunk_len, 0);
client_reader
.read_exact(&mut chunk[..chunk_len])
.await
.map_err(ProxyError::Io)?;
payload.extend_from_slice(&chunk[..chunk_len]);
remaining -= chunk_len;
}
// Secure Intermediate: strip validated trailing padding bytes.
if proto_tag == ProtoTag::Secure {
payload.truncate(secure_payload_len);
}
*frame_counter += 1;
return Ok(Some((Bytes::from(payload), quickack)));
return Ok(Some((payload.freeze(), quickack)));
}
}
@@ -940,82 +974,5 @@ where
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::{Duration as TokioDuration, timeout};
#[test]
fn should_yield_sender_only_on_budget_with_backlog() {
assert!(!should_yield_c2me_sender(0, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
}
#[tokio::test]
async fn enqueue_c2me_command_uses_try_send_fast_path() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(2);
enqueue_c2me_command(
&tx,
C2MeCommand::Data {
payload: Bytes::from_static(&[1, 2, 3]),
flags: 0,
},
)
.await
.unwrap();
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
.await
.unwrap()
.unwrap();
match recv {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[1, 2, 3]);
assert_eq!(flags, 0);
}
C2MeCommand::Close => panic!("unexpected close command"),
}
}
#[tokio::test]
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Data {
payload: Bytes::from_static(&[9]),
flags: 9,
})
.await
.unwrap();
let tx2 = tx.clone();
let producer = tokio::spawn(async move {
enqueue_c2me_command(
&tx2,
C2MeCommand::Data {
payload: Bytes::from_static(&[7, 7]),
flags: 7,
},
)
.await
.unwrap();
});
let _ = timeout(TokioDuration::from_millis(100), rx.recv())
.await
.unwrap();
producer.await.unwrap();
let recv = timeout(TokioDuration::from_millis(100), rx.recv())
.await
.unwrap()
.unwrap();
match recv {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[7, 7]);
assert_eq!(flags, 7);
}
C2MeCommand::Close => panic!("unexpected close command"),
}
}
}
#[path = "middle_relay_security_tests.rs"]
mod security_tests;