mirror of
https://github.com/telemt/telemt.git
synced 2026-06-15 15:31:43 +03:00
Implement shared MTProto framing and ME address role separation
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
@@ -429,7 +429,7 @@ pub struct GeneralConfig {
|
||||
pub ad_tag: Option<String>,
|
||||
|
||||
/// Public IP override for middle-proxy NAT environments.
|
||||
/// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr".
|
||||
/// When set, this IP is used in ME key derivation and local address translation.
|
||||
#[serde(default)]
|
||||
pub middle_proxy_nat_ip: Option<IpAddr>,
|
||||
|
||||
|
||||
@@ -5,6 +5,9 @@
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::protocol::framing::{
|
||||
secure_version_d_body_len_from_wire_len, secure_version_d_padding_len,
|
||||
};
|
||||
use std::sync::LazyLock;
|
||||
|
||||
// ============= Telegram Datacenters =============
|
||||
@@ -239,10 +242,7 @@ pub fn is_valid_secure_payload_len(data_len: usize) -> bool {
|
||||
/// Secure mode cannot distinguish full-word padding from payload, so only the
|
||||
/// non-aligned tail bytes are stripped.
|
||||
pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option<usize> {
|
||||
if wire_len < 4 {
|
||||
return None;
|
||||
}
|
||||
Some(wire_len - (wire_len % 4))
|
||||
secure_version_d_body_len_from_wire_len(wire_len)
|
||||
}
|
||||
|
||||
/// Generate padding length for Secure Intermediate protocol.
|
||||
@@ -252,7 +252,7 @@ pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
|
||||
is_valid_secure_payload_len(data_len),
|
||||
"Secure payload must be 4-byte aligned, got {data_len}"
|
||||
);
|
||||
rng.range(16)
|
||||
secure_version_d_padding_len(rng)
|
||||
}
|
||||
|
||||
// ============= Timeouts =============
|
||||
|
||||
92
src/protocol/framing.rs
Normal file
92
src/protocol/framing.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
//! Shared MTProto transport framing helpers.
|
||||
|
||||
use crate::crypto::SecureRandom;
|
||||
|
||||
/// QuickACK marker bit used by Intermediate and Secure Intermediate headers.
|
||||
pub(crate) const INTERMEDIATE_QUICKACK_FLAG: u32 = 0x8000_0000;
|
||||
|
||||
/// Payload length mask used by Intermediate and Secure Intermediate headers.
|
||||
pub(crate) const INTERMEDIATE_WIRE_LEN_MASK: u32 = 0x7fff_ffff;
|
||||
|
||||
/// Maximum random tail length used by Telegram Desktop VersionD packets.
|
||||
pub(crate) const SECURE_VERSION_D_PADDING_MAX: usize = 15;
|
||||
|
||||
/// Parsed Intermediate/Secure Intermediate length header.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub(crate) struct IntermediateHeader {
|
||||
/// Payload length on the wire, excluding the four-byte header.
|
||||
pub(crate) wire_len: usize,
|
||||
/// Whether the QuickACK marker bit was set in the length header.
|
||||
pub(crate) quickack: bool,
|
||||
}
|
||||
|
||||
/// Parse an Intermediate/Secure Intermediate length header.
|
||||
pub(crate) fn parse_intermediate_header(header: [u8; 4]) -> IntermediateHeader {
|
||||
let raw = u32::from_le_bytes(header);
|
||||
IntermediateHeader {
|
||||
wire_len: (raw & INTERMEDIATE_WIRE_LEN_MASK) as usize,
|
||||
quickack: (raw & INTERMEDIATE_QUICKACK_FLAG) != 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode an Intermediate/Secure Intermediate length header.
|
||||
pub(crate) fn encode_intermediate_header(wire_len: usize, quickack: bool) -> Option<u32> {
|
||||
if wire_len > INTERMEDIATE_WIRE_LEN_MASK as usize {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut raw = u32::try_from(wire_len).ok()?;
|
||||
if quickack {
|
||||
raw |= INTERMEDIATE_QUICKACK_FLAG;
|
||||
}
|
||||
Some(raw)
|
||||
}
|
||||
|
||||
/// Recover the VersionD body length visible to MTProto from the encrypted wire length.
|
||||
pub(crate) fn secure_version_d_body_len_from_wire_len(wire_len: usize) -> Option<usize> {
|
||||
if wire_len < 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(wire_len - (wire_len % 4))
|
||||
}
|
||||
|
||||
/// Generate Telegram Desktop-compatible VersionD random tail length.
|
||||
pub(crate) fn secure_version_d_padding_len(rng: &SecureRandom) -> usize {
|
||||
rng.range(SECURE_VERSION_D_PADDING_MAX + 1)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn intermediate_header_roundtrip_preserves_quickack_zero_length() {
|
||||
let encoded = encode_intermediate_header(0, true).unwrap();
|
||||
assert_eq!(encoded, INTERMEDIATE_QUICKACK_FLAG);
|
||||
|
||||
let parsed = parse_intermediate_header(encoded.to_le_bytes());
|
||||
assert_eq!(parsed.wire_len, 0);
|
||||
assert!(parsed.quickack);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn intermediate_header_rejects_lengths_above_31_bits() {
|
||||
assert_eq!(
|
||||
encode_intermediate_header(INTERMEDIATE_WIRE_LEN_MASK as usize, false),
|
||||
Some(INTERMEDIATE_WIRE_LEN_MASK)
|
||||
);
|
||||
assert_eq!(
|
||||
encode_intermediate_header(INTERMEDIATE_WIRE_LEN_MASK as usize + 1, false),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secure_version_d_body_len_strips_only_non_word_tail() {
|
||||
assert_eq!(secure_version_d_body_len_from_wire_len(3), None);
|
||||
assert_eq!(secure_version_d_body_len_from_wire_len(8), Some(8));
|
||||
assert_eq!(secure_version_d_body_len_from_wire_len(11), Some(8));
|
||||
assert_eq!(secure_version_d_body_len_from_wire_len(12), Some(12));
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
pub mod constants;
|
||||
pub mod frame;
|
||||
pub(crate) mod framing;
|
||||
pub mod obfuscation;
|
||||
pub mod tls;
|
||||
pub mod tls_fingerprint;
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
|
||||
use dashmap::DashMap;
|
||||
use dashmap::mapref::entry::Entry;
|
||||
use hmac::{Hmac, Mac};
|
||||
#[cfg(test)]
|
||||
use std::collections::HashSet;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
@@ -33,8 +32,10 @@ use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
|
||||
use crate::tls_front::{TlsFrontCache, emulator};
|
||||
#[cfg(test)]
|
||||
use rand::RngExt;
|
||||
use sha2::Sha256;
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
mod tls_auth;
|
||||
|
||||
use self::tls_auth::{parse_tls_auth_material, validate_tls_secret_candidate};
|
||||
|
||||
const ACCESS_SECRET_BYTES: usize = 16;
|
||||
const UNKNOWN_SNI_WARN_COOLDOWN_SECS: u64 = 5;
|
||||
@@ -58,8 +59,6 @@ const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8;
|
||||
const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64;
|
||||
const RECENT_USER_RING_SCAN_LIMIT: usize = 32;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
#[cfg(test)]
|
||||
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
|
||||
#[cfg(not(test))]
|
||||
@@ -104,23 +103,6 @@ fn should_emit_unknown_sni_warn_in(shared: &ProxySharedState, now: Instant) -> b
|
||||
true
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct ParsedTlsAuthMaterial {
|
||||
digest: [u8; tls::TLS_DIGEST_LEN],
|
||||
session_id: [u8; 32],
|
||||
session_id_len: usize,
|
||||
now: i64,
|
||||
ignore_time_skew: bool,
|
||||
boot_time_cap_secs: u32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct TlsCandidateValidation {
|
||||
digest: [u8; tls::TLS_DIGEST_LEN],
|
||||
session_id: [u8; 32],
|
||||
session_id_len: usize,
|
||||
}
|
||||
|
||||
struct MtprotoCandidateValidation {
|
||||
proto_tag: ProtoTag,
|
||||
dc_idx: i16,
|
||||
@@ -251,104 +233,6 @@ fn budget_for_validation(total_users: usize, overload: bool, has_hint: bool) ->
|
||||
total_users.min(cap.max(1))
|
||||
}
|
||||
|
||||
fn parse_tls_auth_material(
|
||||
handshake: &[u8],
|
||||
ignore_time_skew: bool,
|
||||
replay_window_secs: u64,
|
||||
) -> Option<ParsedTlsAuthMaterial> {
|
||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let digest: [u8; tls::TLS_DIGEST_LEN] = handshake
|
||||
[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
|
||||
.try_into()
|
||||
.ok()?;
|
||||
|
||||
let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN;
|
||||
let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?);
|
||||
if session_id_len > 32 {
|
||||
return None;
|
||||
}
|
||||
let session_id_start = session_id_len_pos + 1;
|
||||
if handshake.len() < session_id_start + session_id_len {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut session_id = [0u8; 32];
|
||||
session_id[..session_id_len]
|
||||
.copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]);
|
||||
|
||||
let now = if !ignore_time_skew {
|
||||
let d = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.ok()?;
|
||||
i64::try_from(d.as_secs()).ok()?
|
||||
} else {
|
||||
0_i64
|
||||
};
|
||||
|
||||
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
|
||||
let boot_time_cap_secs = if ignore_time_skew {
|
||||
0
|
||||
} else {
|
||||
tls::BOOT_TIME_MAX_SECS
|
||||
.min(replay_window_u32)
|
||||
.min(tls::BOOT_TIME_COMPAT_MAX_SECS)
|
||||
};
|
||||
|
||||
Some(ParsedTlsAuthMaterial {
|
||||
digest,
|
||||
session_id,
|
||||
session_id_len,
|
||||
now,
|
||||
ignore_time_skew,
|
||||
boot_time_cap_secs,
|
||||
})
|
||||
}
|
||||
|
||||
fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> [u8; 32] {
|
||||
let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
|
||||
mac.update(&handshake[..tls::TLS_DIGEST_POS]);
|
||||
mac.update(&[0u8; tls::TLS_DIGEST_LEN]);
|
||||
mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]);
|
||||
mac.finalize().into_bytes().into()
|
||||
}
|
||||
|
||||
fn validate_tls_secret_candidate(
|
||||
parsed: &ParsedTlsAuthMaterial,
|
||||
handshake: &[u8],
|
||||
secret: &[u8],
|
||||
) -> Option<TlsCandidateValidation> {
|
||||
let computed = compute_tls_hmac_zeroed_digest(secret, handshake);
|
||||
if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let timestamp = u32::from_le_bytes([
|
||||
parsed.digest[28] ^ computed[28],
|
||||
parsed.digest[29] ^ computed[29],
|
||||
parsed.digest[30] ^ computed[30],
|
||||
parsed.digest[31] ^ computed[31],
|
||||
]);
|
||||
|
||||
if !parsed.ignore_time_skew {
|
||||
let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs;
|
||||
if !is_boot_time {
|
||||
let time_diff = parsed.now - i64::from(timestamp);
|
||||
if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(TlsCandidateValidation {
|
||||
digest: parsed.digest,
|
||||
session_id: parsed.session_id,
|
||||
session_id_len: parsed.session_id_len,
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_mtproto_secret_candidate(
|
||||
handshake: &[u8; HANDSHAKE_LEN],
|
||||
dec_prekey: &[u8; PREKEY_LEN],
|
||||
@@ -1857,7 +1741,16 @@ where
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
||||
let validation = matched_validation.expect("validation must exist when matched");
|
||||
let Some(validation) = matched_validation else {
|
||||
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
warn!(
|
||||
peer = %peer,
|
||||
user = %matched_user,
|
||||
"MTProto handshake matched user without validation material"
|
||||
);
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
};
|
||||
|
||||
if config
|
||||
.access
|
||||
|
||||
126
src/proxy/handshake/tls_auth.rs
Normal file
126
src/proxy/handshake/tls_auth.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
use crate::protocol::tls;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// Parsed TLS authentication material extracted from a ClientHello candidate.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(super) struct ParsedTlsAuthMaterial {
|
||||
digest: [u8; tls::TLS_DIGEST_LEN],
|
||||
session_id: [u8; 32],
|
||||
session_id_len: usize,
|
||||
now: i64,
|
||||
ignore_time_skew: bool,
|
||||
boot_time_cap_secs: u32,
|
||||
}
|
||||
|
||||
/// Successful TLS secret validation output used by the handshake state machine.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(super) struct TlsCandidateValidation {
|
||||
pub(super) digest: [u8; tls::TLS_DIGEST_LEN],
|
||||
pub(super) session_id: [u8; 32],
|
||||
pub(super) session_id_len: usize,
|
||||
}
|
||||
|
||||
/// Parse TLS auth digest and session-id material from a candidate handshake.
|
||||
pub(super) fn parse_tls_auth_material(
|
||||
handshake: &[u8],
|
||||
ignore_time_skew: bool,
|
||||
replay_window_secs: u64,
|
||||
) -> Option<ParsedTlsAuthMaterial> {
|
||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let digest: [u8; tls::TLS_DIGEST_LEN] = handshake
|
||||
[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
|
||||
.try_into()
|
||||
.ok()?;
|
||||
|
||||
let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN;
|
||||
let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?);
|
||||
if session_id_len > 32 {
|
||||
return None;
|
||||
}
|
||||
let session_id_start = session_id_len_pos + 1;
|
||||
if handshake.len() < session_id_start + session_id_len {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut session_id = [0u8; 32];
|
||||
session_id[..session_id_len]
|
||||
.copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]);
|
||||
|
||||
let now = if !ignore_time_skew {
|
||||
let d = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.ok()?;
|
||||
i64::try_from(d.as_secs()).ok()?
|
||||
} else {
|
||||
0_i64
|
||||
};
|
||||
|
||||
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
|
||||
let boot_time_cap_secs = if ignore_time_skew {
|
||||
0
|
||||
} else {
|
||||
tls::BOOT_TIME_MAX_SECS
|
||||
.min(replay_window_u32)
|
||||
.min(tls::BOOT_TIME_COMPAT_MAX_SECS)
|
||||
};
|
||||
|
||||
Some(ParsedTlsAuthMaterial {
|
||||
digest,
|
||||
session_id,
|
||||
session_id_len,
|
||||
now,
|
||||
ignore_time_skew,
|
||||
boot_time_cap_secs,
|
||||
})
|
||||
}
|
||||
|
||||
fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> Option<[u8; 32]> {
|
||||
let mut mac = HmacSha256::new_from_slice(secret).ok()?;
|
||||
mac.update(&handshake[..tls::TLS_DIGEST_POS]);
|
||||
mac.update(&[0u8; tls::TLS_DIGEST_LEN]);
|
||||
mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]);
|
||||
Some(mac.finalize().into_bytes().into())
|
||||
}
|
||||
|
||||
/// Validate a candidate secret against parsed TLS authentication material.
|
||||
pub(super) fn validate_tls_secret_candidate(
|
||||
parsed: &ParsedTlsAuthMaterial,
|
||||
handshake: &[u8],
|
||||
secret: &[u8],
|
||||
) -> Option<TlsCandidateValidation> {
|
||||
let computed = compute_tls_hmac_zeroed_digest(secret, handshake)?;
|
||||
if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let timestamp = u32::from_le_bytes([
|
||||
parsed.digest[28] ^ computed[28],
|
||||
parsed.digest[29] ^ computed[29],
|
||||
parsed.digest[30] ^ computed[30],
|
||||
parsed.digest[31] ^ computed[31],
|
||||
]);
|
||||
|
||||
if !parsed.ignore_time_skew {
|
||||
let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs;
|
||||
if !is_boot_time {
|
||||
let time_diff = parsed.now - i64::from(timestamp);
|
||||
if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(TlsCandidateValidation {
|
||||
digest: parsed.digest,
|
||||
session_id: parsed.session_id,
|
||||
session_id_len: parsed.session_id_len,
|
||||
})
|
||||
}
|
||||
@@ -276,20 +276,17 @@ pub(in crate::proxy::middle_relay) fn compute_intermediate_secure_wire_len(
|
||||
let wire_len = data_len
|
||||
.checked_add(padding_len)
|
||||
.ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?;
|
||||
if wire_len > 0x7fff_ffffusize {
|
||||
return Err(ProxyError::Proxy(format!(
|
||||
"Intermediate/Secure frame too large: {wire_len}"
|
||||
)));
|
||||
}
|
||||
|
||||
let len_val =
|
||||
crate::protocol::framing::encode_intermediate_header(wire_len, quickack).ok_or_else(
|
||||
|| {
|
||||
ProxyError::Proxy(format!(
|
||||
"Intermediate/Secure frame too large: {wire_len}"
|
||||
))
|
||||
},
|
||||
)?;
|
||||
let total = 4usize
|
||||
.checked_add(wire_len)
|
||||
.ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?;
|
||||
let mut len_val = u32::try_from(wire_len)
|
||||
.map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?;
|
||||
if quickack {
|
||||
len_val |= 0x8000_0000;
|
||||
}
|
||||
Ok((len_val, total))
|
||||
}
|
||||
|
||||
|
||||
@@ -236,10 +236,10 @@ where
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
let quickack = (len_buf[3] & 0x80) != 0;
|
||||
let header = crate::protocol::framing::parse_intermediate_header(len_buf);
|
||||
(
|
||||
(u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize,
|
||||
quickack,
|
||||
header.wire_len,
|
||||
header.quickack,
|
||||
Some(len_buf),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ use crate::crypto::SecureRandom;
|
||||
use crate::protocol::constants::{
|
||||
ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len,
|
||||
};
|
||||
use crate::protocol::framing::{encode_intermediate_header, parse_intermediate_header};
|
||||
|
||||
// ============= Unified Codec =============
|
||||
|
||||
@@ -197,13 +198,9 @@ fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result<Option
|
||||
}
|
||||
|
||||
let mut meta = FrameMeta::new();
|
||||
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
|
||||
|
||||
// Check QuickACK flag
|
||||
if len >= 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
let header = parse_intermediate_header([src[0], src[1], src[2], src[3]]);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
|
||||
// Validate size
|
||||
if len > max_size {
|
||||
@@ -239,10 +236,12 @@ fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
|
||||
dst.reserve(4 + data.len());
|
||||
|
||||
let mut len = data.len() as u32;
|
||||
if frame.meta.quickack {
|
||||
len |= 0x80000000;
|
||||
}
|
||||
let len = encode_intermediate_header(data.len(), frame.meta.quickack).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("frame too large: {} bytes", data.len()),
|
||||
)
|
||||
})?;
|
||||
|
||||
dst.extend_from_slice(&len.to_le_bytes());
|
||||
dst.extend_from_slice(data);
|
||||
@@ -258,13 +257,9 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
|
||||
}
|
||||
|
||||
let mut meta = FrameMeta::new();
|
||||
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
|
||||
|
||||
// Check QuickACK flag
|
||||
if len >= 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
let header = parse_intermediate_header([src[0], src[1], src[2], src[3]]);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
|
||||
// Validate size
|
||||
if len > max_size {
|
||||
@@ -323,10 +318,12 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
|
||||
let total_len = data.len() + padding_len;
|
||||
dst.reserve(4 + total_len);
|
||||
|
||||
let mut len = total_len as u32;
|
||||
if frame.meta.quickack {
|
||||
len |= 0x80000000;
|
||||
}
|
||||
let len = encode_intermediate_header(total_len, frame.meta.quickack).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("frame too large: {} bytes", total_len),
|
||||
)
|
||||
})?;
|
||||
|
||||
dst.extend_from_slice(&len.to_le_bytes());
|
||||
dst.extend_from_slice(data);
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
use super::traits::{FrameMeta, LayeredStream};
|
||||
use crate::crypto::{SecureRandom, crc32};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::protocol::framing::{encode_intermediate_header, parse_intermediate_header};
|
||||
use bytes::Bytes;
|
||||
use std::io::{Error, ErrorKind, Result};
|
||||
use std::sync::Arc;
|
||||
@@ -105,10 +106,17 @@ impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
|
||||
|
||||
if len_div_4 < 0x7f {
|
||||
// Short length (1 byte)
|
||||
self.upstream.write_all(&[len_div_4 as u8]).await?;
|
||||
let mut first = len_div_4 as u8;
|
||||
if meta.quickack {
|
||||
first |= 0x80;
|
||||
}
|
||||
self.upstream.write_all(&[first]).await?;
|
||||
} else if len_div_4 < (1 << 24) {
|
||||
// Long length (4 bytes: 0x7f + 3 bytes)
|
||||
let mut header = [0x7f, 0, 0, 0];
|
||||
if meta.quickack {
|
||||
header[0] |= 0x80;
|
||||
}
|
||||
header[1..4].copy_from_slice(&(len_div_4 as u32).to_le_bytes()[..3]);
|
||||
self.upstream.write_all(&header).await?;
|
||||
} else {
|
||||
@@ -160,13 +168,9 @@ impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
|
||||
let mut len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
// Check QuickACK flag (high bit)
|
||||
if len > 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
let header = parse_intermediate_header(len_bytes);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; len];
|
||||
@@ -204,7 +208,13 @@ impl<W: AsyncWrite + Unpin> IntermediateFrameWriter<W> {
|
||||
if meta.simple_ack {
|
||||
self.upstream.write_all(data).await?;
|
||||
} else {
|
||||
let len_bytes = (data.len() as u32).to_le_bytes();
|
||||
let len = encode_intermediate_header(data.len(), meta.quickack).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("Frame too large: {} bytes", data.len()),
|
||||
)
|
||||
})?;
|
||||
let len_bytes = len.to_le_bytes();
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(data).await?;
|
||||
}
|
||||
@@ -249,13 +259,9 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
|
||||
let mut len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
// Check QuickACK flag
|
||||
if len > 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
let header = parse_intermediate_header(len_bytes);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
|
||||
// Read data (including padding)
|
||||
let mut data = vec![0u8; len];
|
||||
@@ -316,7 +322,13 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
||||
let padding = self.rng.bytes(padding_len);
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
let len_bytes = (total_len as u32).to_le_bytes();
|
||||
let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("Frame too large: {total_len} bytes"),
|
||||
)
|
||||
})?;
|
||||
let len_bytes = len.to_le_bytes();
|
||||
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(data).await?;
|
||||
@@ -623,6 +635,43 @@ mod tests {
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_intermediate_quickack_zero_length_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = IntermediateFrameWriter::new(client);
|
||||
let mut reader = IntermediateFrameReader::new(server);
|
||||
|
||||
writer
|
||||
.write_frame(&[], &FrameMeta::new().with_quickack())
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, meta) = reader.read_frame().await.unwrap();
|
||||
assert!(received.is_empty());
|
||||
assert!(meta.quickack);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abridged_quickack_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = AbridgedFrameWriter::new(client);
|
||||
let mut reader = AbridgedFrameReader::new(server);
|
||||
|
||||
let data = vec![1u8, 2, 3, 4];
|
||||
writer
|
||||
.write_frame(&data, &FrameMeta::new().with_quickack())
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
assert!(meta.quickack);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_secure_intermediate_padding() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
@@ -293,6 +293,8 @@ impl MePool {
|
||||
};
|
||||
record_bnd_status(bnd_addr_status, bnd_port_status, raw_socks_bound_addr);
|
||||
let socks_bound_kdf_addr = socks_bound_addr.filter(|bound| bound.port() != 0);
|
||||
// SOCKS BND is the only reflected source that can supply both KDF IP and
|
||||
// port. Direct STUN reflection is IP-only and keeps the TCP local port.
|
||||
let reflected = if let Some(bound) = socks_bound_kdf_addr {
|
||||
Some(bound)
|
||||
} else if is_socks_route {
|
||||
@@ -417,6 +419,7 @@ impl MePool {
|
||||
key_selector = format_args!("0x{ks:08x}"),
|
||||
crypto_schema = format_args!("0x{schema:08x}"),
|
||||
skew_secs = skew,
|
||||
socks_kdf_policy = ?self.socks_kdf_policy(),
|
||||
"ME key derivation parameters"
|
||||
);
|
||||
|
||||
|
||||
@@ -464,8 +464,7 @@ impl MePool {
|
||||
if !self.writer_accepts_new_binding(w) {
|
||||
continue;
|
||||
}
|
||||
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
|
||||
let (payload, meta) = build_routed_payload(effective_our_addr);
|
||||
let (payload, meta) = build_routed_payload(our_addr);
|
||||
match w.tx.clone().try_reserve_owned() {
|
||||
Ok(permit) => {
|
||||
if !self.registry.bind_writer(conn_id, w.id, meta).await {
|
||||
@@ -520,8 +519,7 @@ impl MePool {
|
||||
}
|
||||
self.stats
|
||||
.increment_me_writer_pick_blocking_fallback_total();
|
||||
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
|
||||
let (payload, meta) = build_routed_payload(effective_our_addr);
|
||||
let (payload, meta) = build_routed_payload(our_addr);
|
||||
let reserve_result =
|
||||
if let Some(timeout) = self.route_runtime.me_route_blocking_send_timeout {
|
||||
match tokio::time::timeout(timeout, w.tx.clone().reserve_owned()).await {
|
||||
|
||||
@@ -177,6 +177,37 @@ async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duratio
|
||||
data_count
|
||||
}
|
||||
|
||||
async fn recv_first_data_payload(
|
||||
rx: &mut mpsc::Receiver<WriterCommand>,
|
||||
budget: Duration,
|
||||
) -> Option<Vec<u8>> {
|
||||
let start = Instant::now();
|
||||
while Instant::now().duration_since(start) < budget {
|
||||
let remaining = budget.saturating_sub(Instant::now().duration_since(start));
|
||||
match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await {
|
||||
Ok(Some(WriterCommand::Data(payload))) => return Some(payload.to_vec()),
|
||||
Ok(Some(WriterCommand::DataAndFlush(payload))) => return Some(payload.to_vec()),
|
||||
Ok(Some(_)) => {}
|
||||
Ok(None) => break,
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn proxy_req_our_addr_from_payload(payload: &[u8]) -> SocketAddr {
|
||||
const CLIENT_ADDR_WIRE_LEN: usize = 20;
|
||||
const OUR_ADDR_OFFSET: usize = 4 + 4 + 8 + CLIENT_ADDR_WIRE_LEN;
|
||||
|
||||
let our_addr = &payload[OUR_ADDR_OFFSET..OUR_ADDR_OFFSET + CLIENT_ADDR_WIRE_LEN];
|
||||
let ip = Ipv4Addr::new(our_addr[12], our_addr[13], our_addr[14], our_addr[15]);
|
||||
let port = u32::from_le_bytes([our_addr[16], our_addr[17], our_addr[18], our_addr[19]]);
|
||||
SocketAddr::new(
|
||||
IpAddr::V4(ip),
|
||||
u16::try_from(port).expect("test port must fit u16"),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_proxy_req_does_not_replay_when_first_bind_commit_fails() {
|
||||
let (pool, _rng) = make_pool().await;
|
||||
@@ -290,3 +321,47 @@ async fn send_proxy_req_prunes_iterative_stale_bind_failures_without_data_replay
|
||||
drop(writers);
|
||||
assert_eq!(writer_ids, vec![23]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_proxy_req_preserves_client_facing_our_addr_when_writer_source_ip_differs() {
|
||||
let (pool, _rng) = make_pool().await;
|
||||
pool.rr.store(0, Ordering::Relaxed);
|
||||
|
||||
let (conn_id, _rx) = pool.registry.register().await;
|
||||
let mut live_rx = insert_writer(
|
||||
&pool,
|
||||
31,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 2, 31)), 443),
|
||||
true,
|
||||
)
|
||||
.await;
|
||||
|
||||
{
|
||||
let mut writers = pool.writers.write().await;
|
||||
let writer = writers
|
||||
.iter_mut()
|
||||
.find(|writer| writer.id == 31)
|
||||
.expect("test writer must exist");
|
||||
writer.source_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 31));
|
||||
}
|
||||
|
||||
let our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7)), 8443);
|
||||
let result = pool
|
||||
.send_proxy_req(
|
||||
conn_id,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 7)), 30002),
|
||||
our_addr,
|
||||
b"route",
|
||||
0,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let payload = recv_first_data_payload(&mut live_rx, Duration::from_millis(50))
|
||||
.await
|
||||
.expect("writer must receive routed payload");
|
||||
assert_eq!(proxy_req_our_addr_from_payload(&payload), our_addr);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user