security: harden handshake/masking flows and add adversarial regressions

- forward valid-TLS/invalid-MTProto clients to mask backend in both client paths\n- harden TLS validation against timing and clock edge cases\n- move replay tracking behind successful authentication to avoid cache pollution\n- tighten secret decoding and key-material handling paths\n- add dedicated security test modules for tls/client/handshake/masking\n- include production-path regression for ClientHandler fallback behavior
This commit is contained in:
David Osipov
2026-03-16 20:04:41 +04:00
parent dcab19a64f
commit 6ffbc51fb0
11 changed files with 2669 additions and 502 deletions

View File

@@ -13,6 +13,7 @@ use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH};
use num_bigint::BigUint;
use num_traits::One;
use subtle::ConstantTimeEq;
// ============= Public Constants =============
@@ -125,7 +126,7 @@ impl TlsExtensionBuilder {
// protocol name length (1 byte)
// protocol name bytes
let proto_len = proto.len() as u8;
let list_len: u16 = 1 + proto_len as u16;
let list_len: u16 = 1 + u16::from(proto_len);
let ext_len: u16 = 2 + list_len;
self.extensions.extend_from_slice(&ext_len.to_be_bytes());
@@ -273,13 +274,41 @@ impl ServerHelloBuilder {
// ============= Public Functions =============
/// Validate TLS ClientHello against user secrets
/// Validate TLS ClientHello against user secrets.
///
/// Returns validation result if a matching user is found.
/// The result **must** be used — ignoring it silently bypasses authentication.
#[must_use]
pub fn validate_tls_handshake(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
ignore_time_skew: bool,
) -> Option<TlsValidation> {
// Only pay the clock syscall when we will actually compare against it.
// If `ignore_time_skew` is set, a broken or unavailable system clock
// must not block legitimate clients — that would be a DoS via clock failure.
let now = if !ignore_time_skew {
system_time_to_unix_secs(SystemTime::now())?
} else {
0_i64
};
validate_tls_handshake_at_time(handshake, secrets, ignore_time_skew, now)
}
fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
// `try_from` rejects values that overflow i64 (> ~292 billion years CE),
// whereas `as i64` would silently wrap to a negative timestamp and corrupt
// every subsequent time-skew comparison.
let d = now.duration_since(UNIX_EPOCH).ok()?;
i64::try_from(d.as_secs()).ok()
}
fn validate_tls_handshake_at_time(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
ignore_time_skew: bool,
now: i64,
) -> Option<TlsValidation> {
if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 {
return None;
@@ -305,50 +334,56 @@ pub fn validate_tls_handshake(
let mut msg = handshake.to_vec();
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
// Get current time
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
let mut first_match: Option<TlsValidation> = None;
for (user, secret) in secrets {
let computed = sha256_hmac(secret, &msg);
// XOR digests
let xored: Vec<u8> = digest.iter()
.zip(computed.iter())
.map(|(a, b)| a ^ b)
.collect();
// Check that first 28 bytes are zeros (timestamp in last 4)
if !xored[..28].iter().all(|&b| b == 0) {
// Constant-time equality check on the 28-byte HMAC window.
// A variable-time short-circuit here lets an active censor measure how many
// bytes matched, enabling secret brute-force via timing side-channels.
// Direct comparison on the original arrays avoids a heap allocation and
// removes the `try_into().unwrap()` that the intermediate Vec would require.
if !bool::from(digest[..28].ct_eq(&computed[..28])) {
continue;
}
// Extract timestamp
let timestamp = u32::from_le_bytes(xored[28..32].try_into().unwrap());
let time_diff = now - timestamp as i64;
// Check time skew
// The last 4 bytes encode the timestamp as XOR(digest[28..32], computed[28..32]).
// Inline array construction is infallible: both slices are [u8; 32] by construction.
let timestamp = u32::from_le_bytes([
digest[28] ^ computed[28],
digest[29] ^ computed[29],
digest[30] ^ computed[30],
digest[31] ^ computed[31],
]);
// time_diff is only meaningful (and `now` is only valid) when we are
// actually checking the window. Keep both inside the guard to make
// the dead-code path explicit and prevent accidental future use of
// a sentinel `now` value outside its intended scope.
if !ignore_time_skew {
// Allow very small timestamps (boot time instead of unix time)
// This is a quirk in some clients that use uptime instead of real time
let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds
if !is_boot_time && !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
continue;
if !is_boot_time {
let time_diff = now - i64::from(timestamp);
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
continue;
}
}
}
return Some(TlsValidation {
user: user.clone(),
session_id,
digest,
timestamp,
});
if first_match.is_none() {
first_match = Some(TlsValidation {
user: user.clone(),
session_id: session_id.clone(),
digest,
timestamp,
});
}
}
None
first_match
}
fn curve25519_prime() -> BigUint {
@@ -667,291 +702,29 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_tls_handshake() {
assert!(is_tls_handshake(&[0x16, 0x03, 0x01]));
assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00]));
assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); // Application data
assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); // Wrong version
assert!(!is_tls_handshake(&[0x16, 0x03])); // Too short
}
#[test]
fn test_parse_tls_record_header() {
let header = [0x16, 0x03, 0x01, 0x02, 0x00];
let result = parse_tls_record_header(&header).unwrap();
assert_eq!(result.0, TLS_RECORD_HANDSHAKE);
assert_eq!(result.1, 512);
let header = [0x17, 0x03, 0x03, 0x40, 0x00];
let result = parse_tls_record_header(&header).unwrap();
assert_eq!(result.0, TLS_RECORD_APPLICATION);
assert_eq!(result.1, 16384);
}
#[test]
fn test_gen_fake_x25519_key() {
let rng = SecureRandom::new();
let key1 = gen_fake_x25519_key(&rng);
let key2 = gen_fake_x25519_key(&rng);
assert_eq!(key1.len(), 32);
assert_eq!(key2.len(), 32);
assert_ne!(key1, key2); // Should be random
}
// ============= Compile-time Security Invariants =============
#[test]
fn test_fake_x25519_key_is_quadratic_residue() {
let rng = SecureRandom::new();
let key = gen_fake_x25519_key(&rng);
let p = curve25519_prime();
let k_num = BigUint::from_bytes_le(&key);
let exponent = (&p - BigUint::one()) >> 1;
let legendre = k_num.modpow(&exponent, &p);
assert_eq!(legendre, BigUint::one());
}
#[test]
fn test_tls_extension_builder() {
let key = [0x42u8; 32];
let mut builder = TlsExtensionBuilder::new();
builder.add_key_share(&key);
builder.add_supported_versions(0x0304);
let result = builder.build();
// Check length prefix
let len = u16::from_be_bytes([result[0], result[1]]) as usize;
assert_eq!(len, result.len() - 2);
// Check key_share extension is present
assert!(result.len() > 40); // At least key share
}
#[test]
fn test_server_hello_builder() {
let session_id = vec![0x01, 0x02, 0x03, 0x04];
let key = [0x55u8; 32];
let builder = ServerHelloBuilder::new(session_id.clone())
.with_x25519_key(&key)
.with_tls13_version();
let record = builder.build_record();
// Validate structure
validate_server_hello_structure(&record).expect("Invalid ServerHello structure");
// Check record type
assert_eq!(record[0], TLS_RECORD_HANDSHAKE);
// Check version
assert_eq!(&record[1..3], &TLS_VERSION);
// Check message type (ServerHello = 0x02)
assert_eq!(record[5], 0x02);
}
#[test]
fn test_build_server_hello_structure() {
let secret = b"test secret";
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let rng = SecureRandom::new();
let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng, None, 0);
// Should have at least 3 records
assert!(response.len() > 100);
// First record should be ServerHello
assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
// Validate ServerHello structure
validate_server_hello_structure(&response).expect("Invalid ServerHello");
// Find Change Cipher Spec
let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_start = server_hello_len;
assert!(response.len() > ccs_start + 6);
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
// Find Application Data
let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
let app_start = ccs_start + ccs_len;
assert!(response.len() > app_start + 5);
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
}
#[test]
fn test_build_server_hello_digest() {
let secret = b"test secret key here";
let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32];
let rng = SecureRandom::new();
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0);
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0);
// Digest position should have non-zero data
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
assert!(!digest1.iter().all(|&b| b == 0));
// Different calls should have different digests (due to random cert)
let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
assert_ne!(digest1, digest2);
}
#[test]
fn test_server_hello_extensions_length() {
let session_id = vec![0x01; 32];
let key = [0x55u8; 32];
let builder = ServerHelloBuilder::new(session_id)
.with_x25519_key(&key)
.with_tls13_version();
let record = builder.build_record();
// Parse to find extensions
let msg_start = 5; // After record header
let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize;
// Skip to session ID
let session_id_pos = msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32)
let session_id_len = record[session_id_pos] as usize;
// Skip to extensions
let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1)
let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize;
// Verify extensions length matches actual data
let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len];
assert_eq!(ext_len, extensions_data.len(),
"Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len());
}
#[test]
fn test_validate_tls_handshake_format() {
// Build a minimal ClientHello-like structure
let mut handshake = vec![0u8; 100];
// Put a valid-looking digest at position 11
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&[0x42; 32]);
// Session ID length
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32;
// This won't validate (wrong HMAC) but shouldn't panic
let secrets = vec![("test".to_string(), b"secret".to_vec())];
let result = validate_tls_handshake(&handshake, &secrets, true);
// Should return None (no match) but not panic
assert!(result.is_none());
}
/// Compile-time checks that enforce invariants the rest of the code relies on.
/// Using `static_assertions` ensures these can never silently break across
/// refactors without a compile error.
mod compile_time_security_checks {
use super::{TLS_DIGEST_LEN, TLS_DIGEST_HALF_LEN};
use static_assertions::const_assert;
fn build_client_hello_with_exts(exts: Vec<(u16, Vec<u8>)>, host: &str) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION); // legacy version
body.extend_from_slice(&[0u8; 32]); // random
body.push(0); // session id len
body.extend_from_slice(&2u16.to_be_bytes()); // cipher suites len
body.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
body.push(1); // compression len
body.push(0); // null compression
// The digest must be exactly one SHA-256 output.
const_assert!(TLS_DIGEST_LEN == 32);
// Build SNI extension
let host_bytes = host.as_bytes();
let mut sni_ext = Vec::new();
sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes());
sni_ext.push(0);
sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
sni_ext.extend_from_slice(host_bytes);
// Replay-dedup stores the first half; verify it is literally half.
const_assert!(TLS_DIGEST_HALF_LEN * 2 == TLS_DIGEST_LEN);
let mut ext_blob = Vec::new();
for (typ, data) in exts {
ext_blob.extend_from_slice(&typ.to_be_bytes());
ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&data);
}
// SNI last
ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_ext);
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01); // ClientHello
let len_bytes = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&len_bytes[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record
}
#[test]
fn test_extract_sni_with_grease_extension() {
// GREASE type 0x0a0a with zero length before SNI
let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com");
let sni = extract_sni_from_client_hello(&ch);
assert_eq!(sni.as_deref(), Some("example.com"));
}
#[test]
fn test_extract_sni_tolerates_empty_unknown_extension() {
let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local");
let sni = extract_sni_from_client_hello(&ch);
assert_eq!(sni.as_deref(), Some("test.local"));
}
#[test]
fn test_extract_alpn_single() {
let mut alpn_data = Vec::new();
// list length = 3 (1 length byte + "h2")
alpn_data.extend_from_slice(&3u16.to_be_bytes());
alpn_data.push(2);
alpn_data.extend_from_slice(b"h2");
let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test");
let alpn = extract_alpn_from_client_hello(&ch);
let alpn_str: Vec<String> = alpn
.iter()
.map(|p| std::str::from_utf8(p).unwrap().to_string())
.collect();
assert_eq!(alpn_str, vec!["h2"]);
}
#[test]
fn test_extract_alpn_multiple() {
let mut alpn_data = Vec::new();
// list length = 11 (sum of per-proto lengths including length bytes)
alpn_data.extend_from_slice(&11u16.to_be_bytes());
alpn_data.push(2);
alpn_data.extend_from_slice(b"h2");
alpn_data.push(4);
alpn_data.extend_from_slice(b"spdy");
alpn_data.push(2);
alpn_data.extend_from_slice(b"h3");
let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test");
let alpn = extract_alpn_from_client_hello(&ch);
let alpn_str: Vec<String> = alpn
.iter()
.map(|p| std::str::from_utf8(p).unwrap().to_string())
.collect();
assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]);
}
// The HMAC check window (28 bytes) plus the embedded timestamp (4 bytes)
// must exactly fill the digest. If TLS_DIGEST_LEN ever changes, these
// assertions will catch the mismatch before any timing-oracle fix is broke.
const_assert!(28 + 4 == TLS_DIGEST_LEN);
}
// ============= Security-focused regression tests =============
#[cfg(test)]
#[path = "tls_security_tests.rs"]
mod security_tests;

File diff suppressed because it is too large Load Diff

View File

@@ -23,7 +23,7 @@ enum HandshakeOutcome {
use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::{HandshakeResult, ProxyError, Result};
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
use crate::ip_tracker::UserIpTracker;
use crate::protocol::constants::*;
use crate::protocol::tls;
@@ -63,10 +63,12 @@ fn record_handshake_failure_class(
peer_ip: IpAddr,
error: &ProxyError,
) {
let class = if error.to_string().contains("expected 64 bytes, got 0") {
"expected_64_got_0"
} else {
"other"
let class = match error {
ProxyError::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
"expected_64_got_0"
}
ProxyError::Stream(StreamError::UnexpectedEof) => "expected_64_got_0",
_ => "other",
};
record_beobachten_class(beobachten, config, peer_ip, class);
}
@@ -204,9 +206,19 @@ where
&config, &replay_checker, true, Some(tls_user.as_str()),
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader: _, writer: _ } => {
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
handle_bad_client(
reader,
writer,
&mtproto_handshake,
real_peer,
local_addr,
&config,
&beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
}
HandshakeResult::Error(e) => return Err(e),
@@ -590,12 +602,19 @@ impl RunningClientHandler {
.await
{
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient {
reader: _,
writer: _,
} => {
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
handle_bad_client(
reader,
writer,
&mtproto_handshake,
peer,
local_addr,
&config,
&self.beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
}
HandshakeResult::Error(e) => return Err(e),
@@ -806,8 +825,24 @@ impl RunningClientHandler {
});
}
let ip_reserved = match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => true,
if let Some(limit) = config.access.user_max_tcp_conns.get(user)
&& stats.get_user_curr_connects(user) >= *limit as u64
{
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota
{
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {}
Err(reason) => {
warn!(
user = %user,
@@ -819,33 +854,12 @@ impl RunningClientHandler {
user: user.to_string(),
});
}
};
// IP limit check
if let Some(limit) = config.access.user_max_tcp_conns.get(user)
&& stats.get_user_curr_connects(user) >= *limit as u64
{
if ip_reserved {
ip_tracker.remove_ip(user, peer_addr.ip()).await;
stats.increment_ip_reservation_rollback_tcp_limit_total();
}
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota
{
if ip_reserved {
ip_tracker.remove_ip(user, peer_addr.ip()).await;
stats.increment_ip_reservation_rollback_quota_limit_total();
}
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
Ok(())
}
}
#[cfg(test)]
#[path = "client_security_tests.rs"]
mod security_tests;

View File

@@ -0,0 +1,631 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::tls;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
#[tokio::test]
async fn short_tls_probe_is_masked_through_client_pipeline() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10];
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; probe.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&probe).await.unwrap();
let mut observed = vec![0u8; backend_reply.len()];
client_side.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, backend_reply);
drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
accept_task.await.unwrap();
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec<u8> {
let tls_len: usize = 600;
let total_len = 5 + tls_len;
let mut handshake = vec![0x42u8; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
handshake
}
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len());
record.push(0x17);
record.extend_from_slice(&[0x03, 0x03]);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.extend_from_slice(payload);
record
}
#[tokio::test]
async fn valid_tls_path_does_not_fall_back_to_mask_backend() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let secret = [0x11u8; 16];
let client_hello = make_valid_tls_client_hello(&secret, 0);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), "11111111111111111111111111111111".to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(8192);
let peer: SocketAddr = "198.51.100.80:55002".parse().unwrap();
let stats_for_assert = stats.clone();
let bad_before = stats_for_assert.get_connects_bad();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&client_hello).await.unwrap();
let mut record_header = [0u8; 5];
client_side.read_exact(&mut record_header).await.unwrap();
assert_eq!(record_header[0], 0x16);
drop(client_side);
let handler_result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(handler_result.is_err());
let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), listener.accept()).await;
assert!(
no_mask_connect.is_err(),
"Mask backend must not be contacted on authenticated TLS path"
);
let bad_after = stats_for_assert.get_connects_bad();
assert_eq!(
bad_before,
bad_after,
"Authenticated TLS path must not increment connects_bad"
);
}
#[tokio::test]
async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let secret = [0x33u8; 16];
let client_hello = make_valid_tls_client_hello(&secret, 0);
let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN];
let tls_app_record = wrap_tls_application_data(&invalid_mtproto);
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; invalid_mtproto.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, invalid_mtproto);
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), "33333333333333333333333333333333".to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(32768);
let peer: SocketAddr = "198.51.100.90:55111".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&tls_app_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}
#[tokio::test]
async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() {
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = mask_listener.local_addr().unwrap();
let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let front_addr = front_listener.local_addr().unwrap();
let secret = [0x44u8; 16];
let client_hello = make_valid_tls_client_hello(&secret, 0);
let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN];
let tls_app_record = wrap_tls_application_data(&invalid_mtproto);
let mask_accept_task = tokio::spawn(async move {
let (mut stream, _) = mask_listener.accept().await.unwrap();
let mut got = vec![0u8; invalid_mtproto.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, invalid_mtproto);
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), "44444444444444444444444444444444".to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let server_task = {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let route_runtime = route_runtime.clone();
let ip_tracker = ip_tracker.clone();
let beobachten = beobachten.clone();
tokio::spawn(async move {
let (stream, peer) = front_listener.accept().await.unwrap();
let real_peer_report = Arc::new(std::sync::Mutex::new(None));
ClientHandler::new(
stream,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
real_peer_report,
)
.run()
.await
})
};
let mut client = TcpStream::connect(front_addr).await.unwrap();
client.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5];
client.read_exact(&mut tls_response_head).await.unwrap();
assert_eq!(tls_response_head[0], 0x16);
client.write_all(&tls_app_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), mask_accept_task)
.await
.unwrap()
.unwrap();
drop(client);
let _ = tokio::time::timeout(Duration::from_secs(3), server_task)
.await
.unwrap()
.unwrap();
}
#[test]
fn unexpected_eof_is_classified_without_string_matching() {
let beobachten = BeobachtenStore::new();
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
let eof = ProxyError::Io(std::io::Error::from(std::io::ErrorKind::UnexpectedEof));
let peer_ip: IpAddr = "198.51.100.200".parse().unwrap();
record_handshake_failure_class(&beobachten, &config, peer_ip, &eof);
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(
snapshot.contains("[expected_64_got_0]"),
"UnexpectedEof must be classified as expected_64_got_0"
);
assert!(
snapshot.contains("198.51.100.200-1"),
"Classified record must include source IP"
);
}
#[test]
fn non_eof_error_is_classified_as_other() {
let beobachten = BeobachtenStore::new();
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
let non_eof = ProxyError::Io(std::io::Error::other("different error"));
let peer_ip: IpAddr = "203.0.113.201".parse().unwrap();
record_handshake_failure_class(&beobachten, &config, peer_ip, &non_eof);
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(
snapshot.contains("[other]"),
"Non-EOF errors must map to other"
);
assert!(
snapshot.contains("203.0.113.201-1"),
"Classified record must include source IP"
);
assert!(
!snapshot.contains("[expected_64_got_0]"),
"Non-EOF errors must not be misclassified as expected_64_got_0"
);
}
#[tokio::test]
async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() {
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert("user".to_string(), 1);
let stats = Stats::new();
stats.increment_user_curr_connects("user");
let ip_tracker = UserIpTracker::new();
let peer_addr: SocketAddr = "198.51.100.210:50000".parse().unwrap();
let result = RunningClientHandler::check_user_limits_static(
"user",
&config,
&stats,
peer_addr,
&ip_tracker,
)
.await;
assert!(matches!(
result,
Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user"
));
assert_eq!(
ip_tracker.get_active_ip_count("user").await,
0,
"Rejected client must not reserve IP slot"
);
assert_eq!(
stats.get_ip_reservation_rollback_tcp_limit_total(),
0,
"No rollback should occur when reservation is not taken"
);
}
#[tokio::test]
async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() {
let mut config = ProxyConfig::default();
config.access.user_data_quota.insert("user".to_string(), 1024);
let stats = Stats::new();
stats.add_user_octets_from("user", 1024);
let ip_tracker = UserIpTracker::new();
let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap();
let result = RunningClientHandler::check_user_limits_static(
"user",
&config,
&stats,
peer_addr,
&ip_tracker,
)
.await;
assert!(matches!(
result,
Err(ProxyError::DataQuotaExceeded { user }) if user == "user"
));
assert_eq!(
ip_tracker.get_active_ip_count("user").await,
0,
"Quota-rejected client must not reserve IP slot"
);
assert_eq!(
stats.get_ip_reservation_rollback_quota_limit_total(),
0,
"No rollback should occur when reservation is not taken"
);
}
#[tokio::test]
async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() {
const PARALLEL_IPS: usize = 64;
const ATTEMPTS_PER_IP: usize = 8;
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert("user".to_string(), 1);
let config = Arc::new(config);
let stats = Arc::new(Stats::new());
stats.increment_user_curr_connects("user");
let ip_tracker = Arc::new(UserIpTracker::new());
let mut tasks = tokio::task::JoinSet::new();
for i in 0..PARALLEL_IPS {
let config = config.clone();
let stats = stats.clone();
let ip_tracker = ip_tracker.clone();
tasks.spawn(async move {
let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 100, (i + 1) as u8));
for _ in 0..ATTEMPTS_PER_IP {
let peer_addr = SocketAddr::new(ip, 40000 + i as u16);
let result = RunningClientHandler::check_user_limits_static(
"user",
&config,
&stats,
peer_addr,
&ip_tracker,
)
.await;
assert!(matches!(
result,
Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user"
));
}
});
}
while let Some(joined) = tasks.join_next().await {
joined.unwrap();
}
assert_eq!(
ip_tracker.get_active_ip_count("user").await,
0,
"Concurrent rejected attempts must not leave active IP reservations"
);
let recent = ip_tracker
.get_recent_ips_for_users(&["user".to_string()])
.await;
assert!(
recent
.get("user")
.map(|ips| ips.is_empty())
.unwrap_or(true),
"Concurrent rejected attempts must not leave recent IP footprint"
);
assert_eq!(
stats.get_ip_reservation_rollback_tcp_limit_total(),
0,
"No rollback should occur under concurrent rejection storms"
);
}

View File

@@ -19,6 +19,8 @@ use crate::stats::ReplayChecker;
use crate::config::ProxyConfig;
use crate::tls_front::{TlsFrontCache, emulator};
const ACCESS_SECRET_BYTES: usize = 16;
fn decode_user_secrets(
config: &ProxyConfig,
preferred_user: Option<&str>,
@@ -28,6 +30,7 @@ fn decode_user_secrets(
if let Some(preferred) = preferred_user
&& let Some(secret_hex) = config.access.users.get(preferred)
&& let Ok(bytes) = hex::decode(secret_hex)
&& bytes.len() == ACCESS_SECRET_BYTES
{
secrets.push((preferred.to_string(), bytes));
}
@@ -36,7 +39,9 @@ fn decode_user_secrets(
if preferred_user.is_some_and(|preferred| preferred == name.as_str()) {
continue;
}
if let Ok(bytes) = hex::decode(secret_hex) {
if let Ok(bytes) = hex::decode(secret_hex)
&& bytes.len() == ACCESS_SECRET_BYTES
{
secrets.push((name.clone(), bytes));
}
}
@@ -48,7 +53,7 @@ fn decode_user_secrets(
///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
/// zeroized on drop.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct HandshakeSuccess {
/// Authenticated user name
pub user: String,
@@ -99,14 +104,6 @@ where
return HandshakeResult::BadClient { reader, writer };
}
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secrets = decode_user_secrets(config, None);
let validation = match tls::validate_tls_handshake(
@@ -125,6 +122,14 @@ where
}
};
// Replay tracking is applied only after successful authentication to avoid
// letting unauthenticated probes evict valid entries from the replay cache.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => return HandshakeResult::BadClient { reader, writer },
@@ -254,11 +259,6 @@ where
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer };
}
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
let decoded_users = decode_user_secrets(config, preferred_user);
@@ -273,14 +273,19 @@ where
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
let decrypted = decryptor.decrypt(handshake);
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into()
.unwrap();
let tag_bytes: [u8; 4] = [
decrypted[PROTO_TAG_POS],
decrypted[PROTO_TAG_POS + 1],
decrypted[PROTO_TAG_POS + 2],
decrypted[PROTO_TAG_POS + 3],
];
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag,
@@ -303,9 +308,7 @@ where
continue;
}
let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
);
let dc_idx = i16::from_le_bytes([decrypted[DC_IDX_POS], decrypted[DC_IDX_POS + 1]]);
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
@@ -315,10 +318,19 @@ where
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
let mut enc_iv_arr = [0u8; IV_LEN];
enc_iv_arr.copy_from_slice(enc_iv_bytes);
let enc_iv = u128::from_be_bytes(enc_iv_arr);
let encryptor = AesCtr::new(&enc_key, enc_iv);
// Apply replay tracking only after successful authentication to prevent
// unauthenticated probes from evicting legitimate replay-cache entries.
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer };
}
let success = HandshakeSuccess {
user: user.clone(),
dc_idx,
@@ -365,14 +377,16 @@ pub fn generate_tg_nonce(
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
loop {
let bytes = rng.bytes(HANDSHAKE_LEN);
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
let Ok(mut nonce): Result<[u8; HANDSHAKE_LEN], _> = bytes.try_into() else {
continue;
};
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]];
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]];
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
@@ -390,11 +404,17 @@ pub fn generate_tg_nonce(
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let mut tg_enc_key = [0u8; 32];
tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
let mut tg_enc_iv_arr = [0u8; IV_LEN];
tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]);
let tg_enc_iv = u128::from_be_bytes(tg_enc_iv_arr);
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
let mut tg_dec_key = [0u8; 32];
tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]);
let mut tg_dec_iv_arr = [0u8; IV_LEN];
tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]);
let tg_dec_iv = u128::from_be_bytes(tg_dec_iv_arr);
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
}
@@ -405,11 +425,17 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let mut enc_key = [0u8; 32];
enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
let mut enc_iv_arr = [0u8; IV_LEN];
enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]);
let enc_iv = u128::from_be_bytes(enc_iv_arr);
let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
let mut dec_key = [0u8; 32];
dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4
@@ -429,80 +455,15 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
}
#[cfg(test)]
mod tests {
use super::*;
#[path = "handshake_security_tests.rs"]
mod security_tests;
#[test]
fn test_generate_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee.
mod compile_time_security_checks {
use super::HandshakeSuccess;
use static_assertions::assert_not_impl_all;
let rng = SecureRandom::new();
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
generate_tg_nonce(
ProtoTag::Secure,
2,
&client_dec_key,
client_dec_iv,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
assert_eq!(nonce.len(), HANDSHAKE_LEN);
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
}
#[test]
fn test_encrypt_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) =
generate_tg_nonce(
ProtoTag::Secure,
2,
&client_dec_key,
client_dec_iv,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
let encrypted = encrypt_tg_nonce(&nonce);
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
}
#[test]
fn test_handshake_success_zeroize_on_drop() {
let success = HandshakeSuccess {
user: "test".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Secure,
dec_key: [0xAA; 32],
dec_iv: 0xBBBBBBBB,
enc_key: [0xCC; 32],
enc_iv: 0xDDDDDDDD,
peer: "127.0.0.1:1234".parse().unwrap(),
is_tls: true,
};
assert_eq!(success.dec_key, [0xAA; 32]);
assert_eq!(success.enc_key, [0xCC; 32]);
drop(success);
// Drop impl zeroizes key material without panic
}
assert_not_impl_all!(HandshakeSuccess: Copy, Clone);
}

View File

@@ -0,0 +1,276 @@
use super::*;
use crate::crypto::sha256_hmac;
use std::time::Duration;
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg
}
#[test]
fn test_generate_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128;
let rng = SecureRandom::new();
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
ProtoTag::Secure,
2,
&client_dec_key,
client_dec_iv,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
assert_eq!(nonce.len(), HANDSHAKE_LEN);
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
}
#[test]
fn test_encrypt_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) = generate_tg_nonce(
ProtoTag::Secure,
2,
&client_dec_key,
client_dec_iv,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
let encrypted = encrypt_tg_nonce(&nonce);
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
}
#[test]
fn test_handshake_success_zeroize_on_drop() {
let success = HandshakeSuccess {
user: "test".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Secure,
dec_key: [0xAA; 32],
dec_iv: 0xBBBBBBBB,
enc_key: [0xCC; 32],
enc_iv: 0xDDDDDDDD,
peer: "127.0.0.1:1234".parse().unwrap(),
is_tls: true,
};
assert_eq!(success.dec_key, [0xAA; 32]);
assert_eq!(success.enc_key, [0xCC; 32]);
drop(success);
}
#[tokio::test]
async fn tls_replay_second_identical_handshake_is_rejected() {
let secret = [0x11u8; 16];
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44321".parse().unwrap();
let handshake = make_valid_tls_handshake(&secret, 0);
let first = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(first, HandshakeResult::Success(_)));
let second = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(second, HandshakeResult::BadClient { .. }));
}
#[tokio::test]
async fn invalid_tls_probe_does_not_pollute_replay_cache() {
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44322".parse().unwrap();
let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
let before = replay_checker.stats();
let result = handle_tls_handshake(
&invalid,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
let after = replay_checker.stats();
assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!(before.total_additions, after.total_additions);
assert_eq!(before.total_hits, after.total_hits);
}
#[tokio::test]
async fn empty_decoded_secret_is_rejected() {
let config = test_config_with_secret_hex("");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44323".parse().unwrap();
let handshake = make_valid_tls_handshake(&[], 0);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
#[tokio::test]
async fn wrong_length_decoded_secret_is_rejected() {
let config = test_config_with_secret_hex("aa");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44324".parse().unwrap();
let handshake = make_valid_tls_handshake(&[0xaau8], 0);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
#[tokio::test]
async fn invalid_mtproto_probe_does_not_pollute_replay_cache() {
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "127.0.0.1:44325".parse().unwrap();
let handshake = [0u8; HANDSHAKE_LEN];
let before = replay_checker.stats();
let result = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
let after = replay_checker.stats();
assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!(before.total_additions, after.total_additions);
assert_eq!(before.total_hits, after.total_hits);
}
#[tokio::test]
async fn mixed_secret_lengths_keep_valid_user_authenticating() {
let good_secret = [0x22u8; 16];
let mut config = ProxyConfig::default();
config.access.users.clear();
config
.access
.users
.insert("broken_user".to_string(), "aa".to_string());
config
.access
.users
.insert("valid_user".to_string(), "22222222222222222222222222222222".to_string());
config.access.ignore_time_skew = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44326".parse().unwrap();
let handshake = make_valid_tls_handshake(&good_secret, 0);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::Success(_)));
}

View File

@@ -194,55 +194,48 @@ async fn relay_to_mask<R, W, MR, MW>(
initial_data: &[u8],
)
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
MR: AsyncRead + Unpin + Send + 'static,
MW: AsyncWrite + Unpin + Send + 'static,
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
MR: AsyncRead + Unpin + Send,
MW: AsyncWrite + Unpin + Send,
{
// Send initial data to mask host
if mask_write.write_all(initial_data).await.is_err() {
return;
}
// Relay traffic
let c2m = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop {
match reader.read(&mut buf).await {
Ok(0) | Err(_) => {
let _ = mask_write.shutdown().await;
break;
}
Ok(n) => {
if mask_write.write_all(&buf[..n]).await.is_err() {
let mut client_buf = vec![0u8; MASK_BUFFER_SIZE];
let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE];
loop {
tokio::select! {
client_read = reader.read(&mut client_buf) => {
match client_read {
Ok(0) | Err(_) => {
let _ = mask_write.shutdown().await;
break;
}
Ok(n) => {
if mask_write.write_all(&client_buf[..n]).await.is_err() {
break;
}
}
}
}
mask_read_res = mask_read.read(&mut mask_buf) => {
match mask_read_res {
Ok(0) | Err(_) => {
let _ = writer.shutdown().await;
break;
}
Ok(n) => {
if writer.write_all(&mask_buf[..n]).await.is_err() {
break;
}
}
}
}
}
});
let m2c = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop {
match mask_read.read(&mut buf).await {
Ok(0) | Err(_) => {
let _ = writer.shutdown().await;
break;
}
Ok(n) => {
if writer.write_all(&buf[..n]).await.is_err() {
break;
}
}
}
}
});
// Wait for either to complete
tokio::select! {
_ = c2m => {}
_ = m2c => {}
}
}
@@ -255,3 +248,7 @@ async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
}
}
}
#[cfg(test)]
#[path = "masking_security_tests.rs"]
mod security_tests;

View File

@@ -0,0 +1,257 @@
use super::*;
use crate::config::ProxyConfig;
use tokio::io::{duplex, AsyncBufReadExt, BufReader};
use tokio::net::TcpListener;
use tokio::time::{timeout, Duration};
#[tokio::test]
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_reader, _client_writer) = duplex(256);
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
let beobachten = BeobachtenStore::new();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, backend_reply);
accept_task.await.unwrap();
}
#[tokio::test]
async fn tls_scanner_probe_keeps_http_like_fallback_surface() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04];
let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; probe.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, probe);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "198.51.100.44:55221".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_reader, _client_writer) = duplex(256);
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
let beobachten = BeobachtenStore::new();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, backend_reply);
assert!(observed.starts_with(b"HTTP/"));
accept_task.await.unwrap();
}
#[tokio::test]
async fn backend_unavailable_falls_back_to_silent_consume() {
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = unused_port;
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.11:42425".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n";
let (mut client_reader_side, client_reader) = duplex(256);
let (mut client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client_reader_side.write_all(b"noise").await.unwrap();
drop(client_reader_side);
timeout(Duration::from_secs(3), task).await.unwrap().unwrap();
let mut buf = [0u8; 1];
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
.await
.unwrap()
.unwrap();
assert_eq!(n, 0);
}
#[tokio::test]
async fn mask_disabled_consumes_client_data_without_response() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = false;
let peer: SocketAddr = "198.51.100.12:45454".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let initial = b"scanner";
let (mut client_reader_side, client_reader) = duplex(256);
let (mut client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
initial,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client_reader_side.write_all(b"untrusted payload").await.unwrap();
drop(client_reader_side);
timeout(Duration::from_secs(3), task).await.unwrap().unwrap();
let mut buf = [0u8; 1];
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
.await
.unwrap()
.unwrap();
assert_eq!(n, 0);
}
#[tokio::test]
async fn proxy_protocol_v1_header_is_sent_before_probe() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (stream, _) = listener.accept().await.unwrap();
let mut reader = BufReader::new(stream);
let mut header_line = Vec::new();
reader.read_until(b'\n', &mut header_line).await.unwrap();
let header_text = String::from_utf8(header_line.clone()).unwrap();
assert!(header_text.starts_with("PROXY TCP4 "));
assert!(header_text.ends_with("\r\n"));
let mut received_probe = vec![0u8; probe.len()];
reader.read_exact(&mut received_probe).await.unwrap();
assert_eq!(received_probe, probe);
let mut stream = reader.into_inner();
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 1;
let peer: SocketAddr = "203.0.113.15:50001".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_reader, _client_writer) = duplex(256);
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
let beobachten = BeobachtenStore::new();
handle_bad_client(
client_reader,
client_visible_writer,
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, backend_reply);
accept_task.await.unwrap();
}