Phase 2 implemented with additional guards

This commit is contained in:
David Osipov
2026-04-03 02:08:59 +04:00
parent a9f695623d
commit 6ea867ce36
27 changed files with 2513 additions and 1131 deletions

View File

@@ -4,13 +4,16 @@
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
#[cfg(test)]
use std::collections::HashSet;
#[cfg(test)]
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash, Hasher};
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv6Addr};
use std::sync::Arc;
use std::sync::{Mutex, OnceLock};
#[cfg(test)]
use std::sync::Mutex;
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, info, trace, warn};
@@ -21,15 +24,15 @@ use crate::crypto::{AesCtr, SecureRandom, sha256};
use crate::error::{HandshakeResult, ProxyError};
use crate::protocol::constants::*;
use crate::protocol::tls;
use crate::proxy::shared_state::ProxySharedState;
use crate::stats::ReplayChecker;
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
use crate::tls_front::{TlsFrontCache, emulator};
#[cfg(test)]
use rand::RngExt;
const ACCESS_SECRET_BYTES: usize = 16;
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
const UNKNOWN_SNI_WARN_COOLDOWN_SECS: u64 = 5;
static UNKNOWN_SNI_WARN_NEXT_ALLOWED: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
#[cfg(test)]
const WARNED_SECRET_MAX_ENTRIES: usize = 64;
#[cfg(not(test))]
@@ -55,48 +58,30 @@ const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 16;
const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 1_000;
#[derive(Clone, Copy)]
struct AuthProbeState {
pub(crate) struct AuthProbeState {
fail_streak: u32,
blocked_until: Instant,
last_seen: Instant,
}
#[derive(Clone, Copy)]
struct AuthProbeSaturationState {
pub(crate) struct AuthProbeSaturationState {
fail_streak: u32,
blocked_until: Instant,
last_seen: Instant,
}
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
static AUTH_PROBE_SATURATION_STATE: OnceLock<Mutex<Option<AuthProbeSaturationState>>> =
OnceLock::new();
static AUTH_PROBE_EVICTION_HASHER: OnceLock<RandomState> = OnceLock::new();
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
AUTH_PROBE_STATE.get_or_init(DashMap::new)
}
fn auth_probe_saturation_state() -> &'static Mutex<Option<AuthProbeSaturationState>> {
AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None))
}
fn auth_probe_saturation_state_lock()
-> std::sync::MutexGuard<'static, Option<AuthProbeSaturationState>> {
auth_probe_saturation_state()
fn unknown_sni_warn_state_lock_in(
shared: &ProxySharedState,
) -> std::sync::MutexGuard<'_, Option<Instant>> {
shared
.handshake
.unknown_sni_warn_next_allowed
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn unknown_sni_warn_state_lock() -> std::sync::MutexGuard<'static, Option<Instant>> {
UNKNOWN_SNI_WARN_NEXT_ALLOWED
.get_or_init(|| Mutex::new(None))
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn should_emit_unknown_sni_warn(now: Instant) -> bool {
let mut guard = unknown_sni_warn_state_lock();
fn should_emit_unknown_sni_warn_in(shared: &ProxySharedState, now: Instant) -> bool {
let mut guard = unknown_sni_warn_state_lock_in(shared);
if let Some(next_allowed) = *guard
&& now < next_allowed
{
@@ -133,15 +118,16 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
now.duration_since(state.last_seen) > retention
}
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
let hasher_state = AUTH_PROBE_EVICTION_HASHER.get_or_init(RandomState::new);
fn auth_probe_eviction_offset_in(shared: &ProxySharedState, peer_ip: IpAddr, now: Instant) -> usize {
let hasher_state = &shared.handshake.auth_probe_eviction_hasher;
let mut hasher = hasher_state.build_hasher();
peer_ip.hash(&mut hasher);
now.hash(&mut hasher);
hasher.finish() as usize
}
fn auth_probe_scan_start_offset(
fn auth_probe_scan_start_offset_in(
shared: &ProxySharedState,
peer_ip: IpAddr,
now: Instant,
state_len: usize,
@@ -151,12 +137,12 @@ fn auth_probe_scan_start_offset(
return 0;
}
auth_probe_eviction_offset(peer_ip, now) % state_len
auth_probe_eviction_offset_in(shared, peer_ip, now) % state_len
}
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
fn auth_probe_is_throttled_in(shared: &ProxySharedState, peer_ip: IpAddr, now: Instant) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
let state = &shared.handshake.auth_probe;
let Some(entry) = state.get(&peer_ip) else {
return false;
};
@@ -168,9 +154,13 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
now < entry.blocked_until
}
fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool {
fn auth_probe_saturation_grace_exhausted_in(
shared: &ProxySharedState,
peer_ip: IpAddr,
now: Instant,
) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
let state = &shared.handshake.auth_probe;
let Some(entry) = state.get(&peer_ip) else {
return false;
};
@@ -183,20 +173,28 @@ fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool
entry.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
}
fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bool {
if !auth_probe_is_throttled(peer_ip, now) {
fn auth_probe_should_apply_preauth_throttle_in(
shared: &ProxySharedState,
peer_ip: IpAddr,
now: Instant,
) -> bool {
if !auth_probe_is_throttled_in(shared, peer_ip, now) {
return false;
}
if !auth_probe_saturation_is_throttled(now) {
if !auth_probe_saturation_is_throttled_in(shared, now) {
return true;
}
auth_probe_saturation_grace_exhausted(peer_ip, now)
auth_probe_saturation_grace_exhausted_in(shared, peer_ip, now)
}
fn auth_probe_saturation_is_throttled(now: Instant) -> bool {
let mut guard = auth_probe_saturation_state_lock();
fn auth_probe_saturation_is_throttled_in(shared: &ProxySharedState, now: Instant) -> bool {
let mut guard = shared
.handshake
.auth_probe_saturation
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let Some(state) = guard.as_mut() else {
return false;
@@ -214,8 +212,12 @@ fn auth_probe_saturation_is_throttled(now: Instant) -> bool {
false
}
fn auth_probe_note_saturation(now: Instant) {
let mut guard = auth_probe_saturation_state_lock();
fn auth_probe_note_saturation_in(shared: &ProxySharedState, now: Instant) {
let mut guard = shared
.handshake
.auth_probe_saturation
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
match guard.as_mut() {
Some(state)
@@ -237,13 +239,14 @@ fn auth_probe_note_saturation(now: Instant) {
}
}
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
fn auth_probe_record_failure_in(shared: &ProxySharedState, peer_ip: IpAddr, now: Instant) {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
auth_probe_record_failure_with_state(state, peer_ip, now);
let state = &shared.handshake.auth_probe;
auth_probe_record_failure_with_state_in(shared, state, peer_ip, now);
}
fn auth_probe_record_failure_with_state(
fn auth_probe_record_failure_with_state_in(
shared: &ProxySharedState,
state: &DashMap<IpAddr, AuthProbeState>,
peer_ip: IpAddr,
now: Instant,
@@ -277,7 +280,7 @@ fn auth_probe_record_failure_with_state(
while state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
rounds += 1;
if rounds > 8 {
auth_probe_note_saturation(now);
auth_probe_note_saturation_in(shared, now);
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
let key = *entry.key();
@@ -320,7 +323,7 @@ fn auth_probe_record_failure_with_state(
}
} else {
let start_offset =
auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
auth_probe_scan_start_offset_in(shared, peer_ip, now, state_len, scan_limit);
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
let key = *entry.key();
@@ -369,11 +372,11 @@ fn auth_probe_record_failure_with_state(
}
let Some((evict_key, _, _)) = eviction_candidate else {
auth_probe_note_saturation(now);
auth_probe_note_saturation_in(shared, now);
return;
};
state.remove(&evict_key);
auth_probe_note_saturation(now);
auth_probe_note_saturation_in(shared, now);
}
}
@@ -387,89 +390,58 @@ fn auth_probe_record_failure_with_state(
}
}
fn auth_probe_record_success(peer_ip: IpAddr) {
fn auth_probe_record_success_in(shared: &ProxySharedState, peer_ip: IpAddr) {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
let state = &shared.handshake.auth_probe;
state.remove(&peer_ip);
}
#[cfg(test)]
fn clear_auth_probe_state_for_testing() {
if let Some(state) = AUTH_PROBE_STATE.get() {
state.clear();
}
if AUTH_PROBE_SATURATION_STATE.get().is_some() {
let mut guard = auth_probe_saturation_state_lock();
*guard = None;
}
pub(crate) fn auth_probe_record_failure_for_testing(
shared: &ProxySharedState,
peer_ip: IpAddr,
now: Instant,
) {
auth_probe_record_failure_in(shared, peer_ip, now);
}
#[cfg(test)]
fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option<u32> {
pub(crate) fn auth_probe_fail_streak_for_testing_in_shared(
shared: &ProxySharedState,
peer_ip: IpAddr,
) -> Option<u32> {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = AUTH_PROBE_STATE.get()?;
state.get(&peer_ip).map(|entry| entry.fail_streak)
shared
.handshake
.auth_probe
.get(&peer_ip)
.map(|entry| entry.fail_streak)
}
#[cfg(test)]
fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool {
auth_probe_is_throttled(peer_ip, Instant::now())
}
#[cfg(test)]
fn auth_probe_saturation_is_throttled_for_testing() -> bool {
auth_probe_saturation_is_throttled(Instant::now())
}
#[cfg(test)]
fn auth_probe_saturation_is_throttled_at_for_testing(now: Instant) -> bool {
auth_probe_saturation_is_throttled(now)
}
#[cfg(test)]
fn auth_probe_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn unknown_sni_warn_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn clear_unknown_sni_warn_state_for_testing() {
if UNKNOWN_SNI_WARN_NEXT_ALLOWED.get().is_some() {
let mut guard = unknown_sni_warn_state_lock();
*guard = None;
pub(crate) fn clear_auth_probe_state_for_testing_in_shared(shared: &ProxySharedState) {
shared.handshake.auth_probe.clear();
match shared.handshake.auth_probe_saturation.lock() {
Ok(mut saturation) => {
*saturation = None;
}
Err(poisoned) => {
let mut saturation = poisoned.into_inner();
*saturation = None;
shared.handshake.auth_probe_saturation.clear_poison();
}
}
}
#[cfg(test)]
fn should_emit_unknown_sni_warn_for_testing(now: Instant) -> bool {
should_emit_unknown_sni_warn(now)
}
#[cfg(test)]
fn clear_warned_secrets_for_testing() {
if let Some(warned) = INVALID_SECRET_WARNED.get()
&& let Ok(mut guard) = warned.lock()
{
guard.clear();
}
}
#[cfg(test)]
fn warned_secrets_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
fn warn_invalid_secret_once_in(
shared: &ProxySharedState,
name: &str,
reason: &str,
expected: usize,
got: Option<usize>,
) {
let key = (name.to_string(), reason.to_string());
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
let should_warn = match warned.lock() {
let should_warn = match shared.handshake.invalid_secret_warned.lock() {
Ok(mut guard) => {
if !guard.contains(&key) && guard.len() >= WARNED_SECRET_MAX_ENTRIES {
false
@@ -502,11 +474,12 @@ fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Opti
}
}
fn decode_user_secret(name: &str, secret_hex: &str) -> Option<Vec<u8>> {
fn decode_user_secret(shared: &ProxySharedState, name: &str, secret_hex: &str) -> Option<Vec<u8>> {
match hex::decode(secret_hex) {
Ok(bytes) if bytes.len() == ACCESS_SECRET_BYTES => Some(bytes),
Ok(bytes) => {
warn_invalid_secret_once(
warn_invalid_secret_once_in(
shared,
name,
"invalid_length",
ACCESS_SECRET_BYTES,
@@ -515,7 +488,7 @@ fn decode_user_secret(name: &str, secret_hex: &str) -> Option<Vec<u8>> {
None
}
Err(_) => {
warn_invalid_secret_once(name, "invalid_hex", ACCESS_SECRET_BYTES, None);
warn_invalid_secret_once_in(shared, name, "invalid_hex", ACCESS_SECRET_BYTES, None);
None
}
}
@@ -543,7 +516,8 @@ fn mode_enabled_for_proto(config: &ProxyConfig, proto_tag: ProtoTag, is_tls: boo
}
}
fn decode_user_secrets(
fn decode_user_secrets_in(
shared: &ProxySharedState,
config: &ProxyConfig,
preferred_user: Option<&str>,
) -> Vec<(String, Vec<u8>)> {
@@ -551,7 +525,7 @@ fn decode_user_secrets(
if let Some(preferred) = preferred_user
&& let Some(secret_hex) = config.access.users.get(preferred)
&& let Some(bytes) = decode_user_secret(preferred, secret_hex)
&& let Some(bytes) = decode_user_secret(shared, preferred, secret_hex)
{
secrets.push((preferred.to_string(), bytes));
}
@@ -560,7 +534,7 @@ fn decode_user_secrets(
if preferred_user.is_some_and(|preferred| preferred == name.as_str()) {
continue;
}
if let Some(bytes) = decode_user_secret(name, secret_hex) {
if let Some(bytes) = decode_user_secret(shared, name, secret_hex) {
secrets.push((name.clone(), bytes));
}
}
@@ -568,6 +542,86 @@ fn decode_user_secrets(
secrets
}
#[cfg(test)]
pub(crate) fn auth_probe_state_for_testing_in_shared(
shared: &ProxySharedState,
) -> &DashMap<IpAddr, AuthProbeState> {
&shared.handshake.auth_probe
}
#[cfg(test)]
pub(crate) fn auth_probe_saturation_state_for_testing_in_shared(
shared: &ProxySharedState,
) -> &Mutex<Option<AuthProbeSaturationState>> {
&shared.handshake.auth_probe_saturation
}
#[cfg(test)]
pub(crate) fn auth_probe_saturation_state_lock_for_testing_in_shared(
shared: &ProxySharedState,
) -> std::sync::MutexGuard<'_, Option<AuthProbeSaturationState>> {
shared
.handshake
.auth_probe_saturation
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[cfg(test)]
pub(crate) fn clear_unknown_sni_warn_state_for_testing_in_shared(shared: &ProxySharedState) {
let mut guard = shared
.handshake
.unknown_sni_warn_next_allowed
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
*guard = None;
}
#[cfg(test)]
pub(crate) fn should_emit_unknown_sni_warn_for_testing_in_shared(
shared: &ProxySharedState,
now: Instant,
) -> bool {
should_emit_unknown_sni_warn_in(shared, now)
}
#[cfg(test)]
pub(crate) fn clear_warned_secrets_for_testing_in_shared(shared: &ProxySharedState) {
if let Ok(mut guard) = shared.handshake.invalid_secret_warned.lock() {
guard.clear();
}
}
#[cfg(test)]
pub(crate) fn warned_secrets_for_testing_in_shared(
shared: &ProxySharedState,
) -> &Mutex<HashSet<(String, String)>> {
&shared.handshake.invalid_secret_warned
}
#[cfg(test)]
pub(crate) fn auth_probe_is_throttled_for_testing_in_shared(
shared: &ProxySharedState,
peer_ip: IpAddr,
) -> bool {
auth_probe_is_throttled_in(shared, peer_ip, Instant::now())
}
#[cfg(test)]
pub(crate) fn auth_probe_saturation_is_throttled_for_testing_in_shared(
shared: &ProxySharedState,
) -> bool {
auth_probe_saturation_is_throttled_in(shared, Instant::now())
}
#[cfg(test)]
pub(crate) fn auth_probe_saturation_is_throttled_at_for_testing_in_shared(
shared: &ProxySharedState,
now: Instant,
) -> bool {
auth_probe_saturation_is_throttled_in(shared, now)
}
#[inline]
fn find_matching_tls_domain<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> {
if config.censorship.tls_domain.eq_ignore_ascii_case(sni) {
@@ -635,6 +689,7 @@ impl Drop for HandshakeSuccess {
}
/// Handle fake TLS handshake
#[cfg(test)]
pub async fn handle_tls_handshake<R, W>(
handshake: &[u8],
reader: R,
@@ -645,6 +700,65 @@ pub async fn handle_tls_handshake<R, W>(
rng: &SecureRandom,
tls_cache: Option<Arc<TlsFrontCache>>,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let shared = ProxySharedState::new();
handle_tls_handshake_impl(
handshake,
reader,
writer,
peer,
config,
replay_checker,
rng,
tls_cache,
shared.as_ref(),
)
.await
}
pub async fn handle_tls_handshake_with_shared<R, W>(
handshake: &[u8],
reader: R,
writer: W,
peer: SocketAddr,
config: &ProxyConfig,
replay_checker: &ReplayChecker,
rng: &SecureRandom,
tls_cache: Option<Arc<TlsFrontCache>>,
shared: &ProxySharedState,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
handle_tls_handshake_impl(
handshake,
reader,
writer,
peer,
config,
replay_checker,
rng,
tls_cache,
shared,
)
.await
}
async fn handle_tls_handshake_impl<R, W>(
handshake: &[u8],
reader: R,
mut writer: W,
peer: SocketAddr,
config: &ProxyConfig,
replay_checker: &ReplayChecker,
rng: &SecureRandom,
tls_cache: Option<Arc<TlsFrontCache>>,
shared: &ProxySharedState,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
@@ -652,14 +766,14 @@ where
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
let throttle_now = Instant::now();
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
if auth_probe_should_apply_preauth_throttle_in(shared, peer.ip(), throttle_now) {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
auth_probe_record_failure(peer.ip(), Instant::now());
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient { reader, writer };
@@ -695,11 +809,11 @@ where
};
if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() {
auth_probe_record_failure(peer.ip(), Instant::now());
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
let sni = client_sni.as_deref().unwrap_or_default();
let log_now = Instant::now();
if should_emit_unknown_sni_warn(log_now) {
if should_emit_unknown_sni_warn_in(shared, log_now) {
warn!(
peer = %peer,
sni = %sni,
@@ -722,7 +836,7 @@ where
};
}
let secrets = decode_user_secrets(config, preferred_user_hint);
let secrets = decode_user_secrets_in(shared, config, preferred_user_hint);
let validation = match tls::validate_tls_handshake_with_replay_window(
handshake,
@@ -732,7 +846,7 @@ where
) {
Some(v) => v,
None => {
auth_probe_record_failure(peer.ip(), Instant::now());
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
@@ -746,7 +860,7 @@ where
// Reject known replay digests before expensive cache/domain/ALPN policy work.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
@@ -827,7 +941,7 @@ where
"TLS handshake successful"
);
auth_probe_record_success(peer.ip());
auth_probe_record_success_in(shared, peer.ip());
HandshakeResult::Success((
FakeTlsReader::new(reader),
@@ -837,6 +951,7 @@ where
}
/// Handle MTProto obfuscation handshake
#[cfg(test)]
pub async fn handle_mtproto_handshake<R, W>(
handshake: &[u8; HANDSHAKE_LEN],
reader: R,
@@ -847,6 +962,65 @@ pub async fn handle_mtproto_handshake<R, W>(
is_tls: bool,
preferred_user: Option<&str>,
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
let shared = ProxySharedState::new();
handle_mtproto_handshake_impl(
handshake,
reader,
writer,
peer,
config,
replay_checker,
is_tls,
preferred_user,
shared.as_ref(),
)
.await
}
pub async fn handle_mtproto_handshake_with_shared<R, W>(
handshake: &[u8; HANDSHAKE_LEN],
reader: R,
writer: W,
peer: SocketAddr,
config: &ProxyConfig,
replay_checker: &ReplayChecker,
is_tls: bool,
preferred_user: Option<&str>,
shared: &ProxySharedState,
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
handle_mtproto_handshake_impl(
handshake,
reader,
writer,
peer,
config,
replay_checker,
is_tls,
preferred_user,
shared,
)
.await
}
async fn handle_mtproto_handshake_impl<R, W>(
handshake: &[u8; HANDSHAKE_LEN],
reader: R,
writer: W,
peer: SocketAddr,
config: &ProxyConfig,
replay_checker: &ReplayChecker,
is_tls: bool,
preferred_user: Option<&str>,
shared: &ProxySharedState,
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
@@ -862,7 +1036,7 @@ where
);
let throttle_now = Instant::now();
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
if auth_probe_should_apply_preauth_throttle_in(shared, peer.ip(), throttle_now) {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
@@ -872,7 +1046,7 @@ where
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
let decoded_users = decode_user_secrets(config, preferred_user);
let decoded_users = decode_user_secrets_in(shared, config, preferred_user);
for (user, secret) in decoded_users {
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
@@ -932,7 +1106,7 @@ where
// entry from the cache. We accept the cost of performing the full
// authentication check first to avoid poisoning the replay cache.
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
auth_probe_record_failure(peer.ip(), Instant::now());
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer };
@@ -959,7 +1133,7 @@ where
"MTProto handshake successful"
);
auth_probe_record_success(peer.ip());
auth_probe_record_success_in(shared, peer.ip());
let max_pending = config.general.crypto_pending_buffer;
return HandshakeResult::Success((
@@ -969,7 +1143,7 @@ where
));
}
auth_probe_record_failure(peer.ip(), Instant::now());
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient { reader, writer }