This commit is contained in:
Alexey
2026-03-21 15:45:29 +03:00
parent 7a8f946029
commit d7bbb376c9
154 changed files with 6194 additions and 3775 deletions

View File

@@ -6,10 +6,10 @@
#![allow(dead_code)]
use crate::crypto::{sha256_hmac, SecureRandom};
use super::constants::*;
use crate::crypto::{SecureRandom, sha256_hmac};
#[cfg(test)]
use crate::error::ProxyError;
use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH};
use subtle::ConstantTimeEq;
use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519};
@@ -31,7 +31,7 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16;
/// Operators with known clock-drifted clients should tune deployment config
/// (for example replay-window policy) to match their environment.
pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before
pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after
pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
/// Hard cap for boot-time compatibility bypass to avoid oversized acceptance
@@ -69,7 +69,6 @@ pub struct TlsValidation {
/// Client digest for response generation
pub digest: [u8; TLS_DIGEST_LEN],
/// Timestamp extracted from digest
pub timestamp: u32,
}
@@ -87,60 +86,63 @@ impl TlsExtensionBuilder {
extensions: Vec::with_capacity(128),
}
}
/// Add Key Share extension with X25519 key
fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self {
// Extension type: key_share (0x0033)
self.extensions.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes());
self.extensions
.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes());
// Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes
// Extension data length
let entry_len: u16 = 2 + 2 + 32; // curve + length + key
self.extensions.extend_from_slice(&entry_len.to_be_bytes());
// Named curve: x25519
self.extensions.extend_from_slice(&named_curve::X25519.to_be_bytes());
self.extensions
.extend_from_slice(&named_curve::X25519.to_be_bytes());
// Key length
self.extensions.extend_from_slice(&(32u16).to_be_bytes());
// Key data
self.extensions.extend_from_slice(public_key);
self
}
/// Add Supported Versions extension
fn add_supported_versions(&mut self, version: u16) -> &mut Self {
// Extension type: supported_versions (0x002b)
self.extensions.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes());
self.extensions
.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes());
// Extension data: length (2) + version (2)
self.extensions.extend_from_slice(&(2u16).to_be_bytes());
// Selected version
self.extensions.extend_from_slice(&version.to_be_bytes());
self
}
/// Build final extensions with length prefix
fn build(self) -> Vec<u8> {
let mut result = Vec::with_capacity(2 + self.extensions.len());
// Extensions length (2 bytes)
let len = self.extensions.len() as u16;
result.extend_from_slice(&len.to_be_bytes());
// Extensions data
result.extend_from_slice(&self.extensions);
result
}
/// Get current extensions without length prefix (for calculation)
fn as_bytes(&self) -> &[u8] {
&self.extensions
}
@@ -172,12 +174,12 @@ impl ServerHelloBuilder {
extensions: TlsExtensionBuilder::new(),
}
}
fn with_x25519_key(mut self, key: &[u8; 32]) -> Self {
self.extensions.add_key_share(key);
self
}
fn with_tls13_version(mut self) -> Self {
// TLS 1.3 = 0x0304
self.extensions.add_supported_versions(0x0304);
@@ -188,7 +190,7 @@ impl ServerHelloBuilder {
fn build_message(&self) -> Vec<u8> {
let extensions = self.extensions.extensions.clone();
let extensions_len = extensions.len() as u16;
// Calculate total length
let body_len = 2 + // version
32 + // random
@@ -196,55 +198,55 @@ impl ServerHelloBuilder {
2 + // cipher suite
1 + // compression
2 + extensions.len(); // extensions length + data
let mut message = Vec::with_capacity(4 + body_len);
// Handshake header
message.push(0x02); // ServerHello message type
// 3-byte length
let len_bytes = (body_len as u32).to_be_bytes();
message.extend_from_slice(&len_bytes[1..4]);
// Server version (TLS 1.2 in header, actual version in extension)
message.extend_from_slice(&TLS_VERSION);
// Random (32 bytes) - placeholder, will be replaced with digest
message.extend_from_slice(&self.random);
// Session ID
message.push(self.session_id.len() as u8);
message.extend_from_slice(&self.session_id);
// Cipher suite
message.extend_from_slice(&self.cipher_suite);
// Compression method
message.push(self.compression);
// Extensions length
message.extend_from_slice(&extensions_len.to_be_bytes());
// Extensions data
message.extend_from_slice(&extensions);
message
}
/// Build complete ServerHello TLS record
fn build_record(&self) -> Vec<u8> {
let message = self.build_message();
let mut record = Vec::with_capacity(5 + message.len());
// TLS record header
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(message.len() as u16).to_be_bytes());
// Message
record.extend_from_slice(&message);
record
}
}
@@ -320,7 +322,6 @@ fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
i64::try_from(d.as_secs()).ok()
}
fn validate_tls_handshake_at_time(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
@@ -346,12 +347,12 @@ fn validate_tls_handshake_at_time_with_boot_cap(
if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 {
return None;
}
// Extract digest
let digest: [u8; TLS_DIGEST_LEN] = handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.try_into()
.ok()?;
// Extract session ID
let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN;
let session_id_len = handshake.get(session_id_len_pos).copied()? as usize;
@@ -359,17 +360,17 @@ fn validate_tls_handshake_at_time_with_boot_cap(
return None;
}
let session_id_start = session_id_len_pos + 1;
if handshake.len() < session_id_start + session_id_len {
return None;
}
let session_id = handshake[session_id_start..session_id_start + session_id_len].to_vec();
// Build message for HMAC (with zeroed digest)
let mut msg = handshake.to_vec();
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
let mut first_match: Option<(&String, u32)> = None;
for (user, secret) in secrets {
@@ -408,7 +409,7 @@ fn validate_tls_handshake_at_time_with_boot_cap(
}
}
}
if first_match.is_none() {
first_match = Some((user, timestamp));
}
@@ -453,25 +454,30 @@ pub fn build_server_hello(
const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE;
let fake_cert_len = fake_cert_len.clamp(MIN_APP_DATA, MAX_APP_DATA);
let x25519_key = gen_fake_x25519_key(rng);
// Build ServerHello
let server_hello = ServerHelloBuilder::new(session_id.to_vec())
.with_x25519_key(&x25519_key)
.with_tls13_version()
.build_record();
// Build Change Cipher Spec record
let change_cipher_spec = [
TLS_RECORD_CHANGE_CIPHER,
TLS_VERSION[0], TLS_VERSION[1],
0x00, 0x01, // length = 1
0x01, // CCS byte
TLS_VERSION[0],
TLS_VERSION[1],
0x00,
0x01, // length = 1
0x01, // CCS byte
];
// Build first encrypted flight mimic as opaque ApplicationData bytes.
// Embed a compact EncryptedExtensions-like ALPN block when selected.
let mut fake_cert = Vec::with_capacity(fake_cert_len);
if let Some(proto) = alpn.as_ref().filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) {
if let Some(proto) = alpn
.as_ref()
.filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize)
{
let proto_list_len = 1usize + proto.len();
let ext_data_len = 2usize + proto_list_len;
let marker_len = 4usize + ext_data_len;
@@ -496,7 +502,7 @@ pub fn build_server_hello(
// Fill ApplicationData with fully random bytes of desired length to avoid
// deterministic DPI fingerprints (fixed inner content type markers).
app_data_record.extend_from_slice(&fake_cert);
// Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted;
// here we mimic with opaque ApplicationData records of plausible size).
let mut tickets = Vec::new();
@@ -515,7 +521,10 @@ pub fn build_server_hello(
// Combine all records
let mut response = Vec::with_capacity(
server_hello.len() + change_cipher_spec.len() + app_data_record.len() + tickets.iter().map(|r| r.len()).sum::<usize>()
server_hello.len()
+ change_cipher_spec.len()
+ app_data_record.len()
+ tickets.iter().map(|r| r.len()).sum::<usize>(),
);
response.extend_from_slice(&server_hello);
response.extend_from_slice(&change_cipher_spec);
@@ -523,18 +532,17 @@ pub fn build_server_hello(
for t in &tickets {
response.extend_from_slice(t);
}
// Compute HMAC for the response
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len());
hmac_input.extend_from_slice(client_digest);
hmac_input.extend_from_slice(&response);
let response_digest = sha256_hmac(secret, &hmac_input);
// Insert computed digest into ServerHello
// Position: record header (5) + message type (1) + length (3) + version (2) = 11
response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&response_digest);
response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&response_digest);
response
}
@@ -611,12 +619,14 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
let sn_end = std::cmp::min(sn_pos + list_len, pos + elen);
while sn_pos + 3 <= sn_end {
let name_type = handshake[sn_pos];
let name_len = u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize;
let name_len =
u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize;
sn_pos += 3;
if sn_pos + name_len > sn_end {
break;
}
if name_type == 0 && name_len > 0
if name_type == 0
&& name_len > 0
&& let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len])
{
if is_valid_sni_hostname(host) {
@@ -679,35 +689,49 @@ pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
}
pos += 4; // type + len
pos += 2 + 32; // version + random
if pos >= handshake.len() { return Vec::new(); }
if pos >= handshake.len() {
return Vec::new();
}
let session_id_len = *handshake.get(pos).unwrap_or(&0) as usize;
pos += 1 + session_id_len;
if pos + 2 > handshake.len() { return Vec::new(); }
let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize;
if pos + 2 > handshake.len() {
return Vec::new();
}
let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
pos += 2 + cipher_len;
if pos >= handshake.len() { return Vec::new(); }
if pos >= handshake.len() {
return Vec::new();
}
let comp_len = *handshake.get(pos).unwrap_or(&0) as usize;
pos += 1 + comp_len;
if pos + 2 > handshake.len() { return Vec::new(); }
let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize;
if pos + 2 > handshake.len() {
return Vec::new();
}
let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
pos += 2;
let ext_end = pos + ext_len;
if ext_end > handshake.len() { return Vec::new(); }
if ext_end > handshake.len() {
return Vec::new();
}
let mut out = Vec::new();
while pos + 4 <= ext_end {
let etype = u16::from_be_bytes([handshake[pos], handshake[pos+1]]);
let elen = u16::from_be_bytes([handshake[pos+2], handshake[pos+3]]) as usize;
let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]);
let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize;
pos += 4;
if pos + elen > ext_end { break; }
if pos + elen > ext_end {
break;
}
if etype == extension_type::ALPN && elen >= 3 {
let list_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize;
let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
let mut lp = pos + 2;
let list_end = (pos + 2).saturating_add(list_len).min(pos + elen);
while lp < list_end {
let plen = handshake[lp] as usize;
lp += 1;
if lp + plen > list_end { break; }
out.push(handshake[lp..lp+plen].to_vec());
if lp + plen > list_end {
break;
}
out.push(handshake[lp..lp + plen].to_vec());
lp += plen;
}
break;
@@ -717,16 +741,15 @@ pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
out
}
/// Check if bytes look like a TLS ClientHello
pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
if first_bytes.len() < 3 {
return false;
}
// TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303.
first_bytes[0] == TLS_RECORD_HANDSHAKE
&& first_bytes[1] == 0x03
first_bytes[0] == TLS_RECORD_HANDSHAKE
&& first_bytes[1] == 0x03
&& (first_bytes[2] == 0x01 || first_bytes[2] == 0x03)
}
@@ -735,12 +758,12 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
let record_type = header[0];
let version = [header[1], header[2]];
// We accept both TLS 1.0 header (for ClientHello) and TLS 1.2/1.3
if version != [0x03, 0x01] && version != TLS_VERSION {
return None;
}
let length = u16::from_be_bytes([header[3], header[4]]);
Some((record_type, length))
}
@@ -756,7 +779,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
version: [0, 0],
});
}
// Check record header
if data[0] != TLS_RECORD_HANDSHAKE {
return Err(ProxyError::InvalidTlsRecord {
@@ -764,7 +787,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
version: [data[1], data[2]],
});
}
// Check version
if data[1..3] != TLS_VERSION {
return Err(ProxyError::InvalidTlsRecord {
@@ -772,31 +795,34 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
version: [data[1], data[2]],
});
}
// Check record length
let record_len = u16::from_be_bytes([data[3], data[4]]) as usize;
if data.len() < 5 + record_len {
return Err(ProxyError::InvalidHandshake(
format!("ServerHello record truncated: expected {}, got {}",
5 + record_len, data.len())
));
return Err(ProxyError::InvalidHandshake(format!(
"ServerHello record truncated: expected {}, got {}",
5 + record_len,
data.len()
)));
}
// Check message type
if data[5] != 0x02 {
return Err(ProxyError::InvalidHandshake(
format!("Expected ServerHello (0x02), got 0x{:02x}", data[5])
));
return Err(ProxyError::InvalidHandshake(format!(
"Expected ServerHello (0x02), got 0x{:02x}",
data[5]
)));
}
// Parse message length
let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize;
if msg_len + 4 != record_len {
return Err(ProxyError::InvalidHandshake(
format!("Message length mismatch: {} + 4 != {}", msg_len, record_len)
));
return Err(ProxyError::InvalidHandshake(format!(
"Message length mismatch: {} + 4 != {}",
msg_len, record_len
)));
}
Ok(())
}
@@ -806,7 +832,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
/// Using `static_assertions` ensures these can never silently break across
/// refactors without a compile error.
mod compile_time_security_checks {
use super::{TLS_DIGEST_LEN, TLS_DIGEST_HALF_LEN};
use super::{TLS_DIGEST_HALF_LEN, TLS_DIGEST_LEN};
use static_assertions::const_assert;
// The digest must be exactly one SHA-256 output.