Implement shared MTProto framing and ME address role separation

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-06-15 08:50:08 +03:00
parent d81d7dba62
commit 37d0184a0b
13 changed files with 415 additions and 184 deletions

View File

@@ -429,7 +429,7 @@ pub struct GeneralConfig {
pub ad_tag: Option<String>,
/// Public IP override for middle-proxy NAT environments.
/// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr".
/// When set, this IP is used in ME key derivation and local address translation.
#[serde(default)]
pub middle_proxy_nat_ip: Option<IpAddr>,

View File

@@ -5,6 +5,9 @@
use std::net::{IpAddr, Ipv4Addr};
use crate::crypto::SecureRandom;
use crate::protocol::framing::{
secure_version_d_body_len_from_wire_len, secure_version_d_padding_len,
};
use std::sync::LazyLock;
// ============= Telegram Datacenters =============
@@ -239,10 +242,7 @@ pub fn is_valid_secure_payload_len(data_len: usize) -> bool {
/// Secure mode cannot distinguish full-word padding from payload, so only the
/// non-aligned tail bytes are stripped.
pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option<usize> {
if wire_len < 4 {
return None;
}
Some(wire_len - (wire_len % 4))
secure_version_d_body_len_from_wire_len(wire_len)
}
/// Generate padding length for Secure Intermediate protocol.
@@ -252,7 +252,7 @@ pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
is_valid_secure_payload_len(data_len),
"Secure payload must be 4-byte aligned, got {data_len}"
);
rng.range(16)
secure_version_d_padding_len(rng)
}
// ============= Timeouts =============

92
src/protocol/framing.rs Normal file
View File

@@ -0,0 +1,92 @@
//! Shared MTProto transport framing helpers.
use crate::crypto::SecureRandom;
/// QuickACK marker bit used by Intermediate and Secure Intermediate headers.
pub(crate) const INTERMEDIATE_QUICKACK_FLAG: u32 = 0x8000_0000;
/// Payload length mask used by Intermediate and Secure Intermediate headers.
pub(crate) const INTERMEDIATE_WIRE_LEN_MASK: u32 = 0x7fff_ffff;
/// Maximum random tail length used by Telegram Desktop VersionD packets.
pub(crate) const SECURE_VERSION_D_PADDING_MAX: usize = 15;
/// Parsed Intermediate/Secure Intermediate length header.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct IntermediateHeader {
/// Payload length on the wire, excluding the four-byte header.
pub(crate) wire_len: usize,
/// Whether the QuickACK marker bit was set in the length header.
pub(crate) quickack: bool,
}
/// Parse an Intermediate/Secure Intermediate length header.
pub(crate) fn parse_intermediate_header(header: [u8; 4]) -> IntermediateHeader {
let raw = u32::from_le_bytes(header);
IntermediateHeader {
wire_len: (raw & INTERMEDIATE_WIRE_LEN_MASK) as usize,
quickack: (raw & INTERMEDIATE_QUICKACK_FLAG) != 0,
}
}
/// Encode an Intermediate/Secure Intermediate length header.
pub(crate) fn encode_intermediate_header(wire_len: usize, quickack: bool) -> Option<u32> {
if wire_len > INTERMEDIATE_WIRE_LEN_MASK as usize {
return None;
}
let mut raw = u32::try_from(wire_len).ok()?;
if quickack {
raw |= INTERMEDIATE_QUICKACK_FLAG;
}
Some(raw)
}
/// Recover the VersionD body length visible to MTProto from the encrypted wire length.
pub(crate) fn secure_version_d_body_len_from_wire_len(wire_len: usize) -> Option<usize> {
if wire_len < 4 {
return None;
}
Some(wire_len - (wire_len % 4))
}
/// Generate Telegram Desktop-compatible VersionD random tail length.
pub(crate) fn secure_version_d_padding_len(rng: &SecureRandom) -> usize {
rng.range(SECURE_VERSION_D_PADDING_MAX + 1)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn intermediate_header_roundtrip_preserves_quickack_zero_length() {
let encoded = encode_intermediate_header(0, true).unwrap();
assert_eq!(encoded, INTERMEDIATE_QUICKACK_FLAG);
let parsed = parse_intermediate_header(encoded.to_le_bytes());
assert_eq!(parsed.wire_len, 0);
assert!(parsed.quickack);
}
#[test]
fn intermediate_header_rejects_lengths_above_31_bits() {
assert_eq!(
encode_intermediate_header(INTERMEDIATE_WIRE_LEN_MASK as usize, false),
Some(INTERMEDIATE_WIRE_LEN_MASK)
);
assert_eq!(
encode_intermediate_header(INTERMEDIATE_WIRE_LEN_MASK as usize + 1, false),
None
);
}
#[test]
fn secure_version_d_body_len_strips_only_non_word_tail() {
assert_eq!(secure_version_d_body_len_from_wire_len(3), None);
assert_eq!(secure_version_d_body_len_from_wire_len(8), Some(8));
assert_eq!(secure_version_d_body_len_from_wire_len(11), Some(8));
assert_eq!(secure_version_d_body_len_from_wire_len(12), Some(12));
}
}

View File

@@ -2,6 +2,7 @@
pub mod constants;
pub mod frame;
pub(crate) mod framing;
pub mod obfuscation;
pub mod tls;
pub mod tls_fingerprint;

View File

@@ -4,7 +4,6 @@
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use hmac::{Hmac, Mac};
#[cfg(test)]
use std::collections::HashSet;
use std::collections::hash_map::DefaultHasher;
@@ -33,8 +32,10 @@ use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
use crate::tls_front::{TlsFrontCache, emulator};
#[cfg(test)]
use rand::RngExt;
use sha2::Sha256;
use subtle::ConstantTimeEq;
mod tls_auth;
use self::tls_auth::{parse_tls_auth_material, validate_tls_secret_candidate};
const ACCESS_SECRET_BYTES: usize = 16;
const UNKNOWN_SNI_WARN_COOLDOWN_SECS: u64 = 5;
@@ -58,8 +59,6 @@ const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8;
const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64;
const RECENT_USER_RING_SCAN_LIMIT: usize = 32;
type HmacSha256 = Hmac<Sha256>;
#[cfg(test)]
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
#[cfg(not(test))]
@@ -104,23 +103,6 @@ fn should_emit_unknown_sni_warn_in(shared: &ProxySharedState, now: Instant) -> b
true
}
#[derive(Clone, Copy)]
struct ParsedTlsAuthMaterial {
digest: [u8; tls::TLS_DIGEST_LEN],
session_id: [u8; 32],
session_id_len: usize,
now: i64,
ignore_time_skew: bool,
boot_time_cap_secs: u32,
}
#[derive(Clone, Copy)]
struct TlsCandidateValidation {
digest: [u8; tls::TLS_DIGEST_LEN],
session_id: [u8; 32],
session_id_len: usize,
}
struct MtprotoCandidateValidation {
proto_tag: ProtoTag,
dc_idx: i16,
@@ -251,104 +233,6 @@ fn budget_for_validation(total_users: usize, overload: bool, has_hint: bool) ->
total_users.min(cap.max(1))
}
fn parse_tls_auth_material(
handshake: &[u8],
ignore_time_skew: bool,
replay_window_secs: u64,
) -> Option<ParsedTlsAuthMaterial> {
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
return None;
}
let digest: [u8; tls::TLS_DIGEST_LEN] = handshake
[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.try_into()
.ok()?;
let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN;
let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?);
if session_id_len > 32 {
return None;
}
let session_id_start = session_id_len_pos + 1;
if handshake.len() < session_id_start + session_id_len {
return None;
}
let mut session_id = [0u8; 32];
session_id[..session_id_len]
.copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]);
let now = if !ignore_time_skew {
let d = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.ok()?;
i64::try_from(d.as_secs()).ok()?
} else {
0_i64
};
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
let boot_time_cap_secs = if ignore_time_skew {
0
} else {
tls::BOOT_TIME_MAX_SECS
.min(replay_window_u32)
.min(tls::BOOT_TIME_COMPAT_MAX_SECS)
};
Some(ParsedTlsAuthMaterial {
digest,
session_id,
session_id_len,
now,
ignore_time_skew,
boot_time_cap_secs,
})
}
fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> [u8; 32] {
let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
mac.update(&handshake[..tls::TLS_DIGEST_POS]);
mac.update(&[0u8; tls::TLS_DIGEST_LEN]);
mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]);
mac.finalize().into_bytes().into()
}
fn validate_tls_secret_candidate(
parsed: &ParsedTlsAuthMaterial,
handshake: &[u8],
secret: &[u8],
) -> Option<TlsCandidateValidation> {
let computed = compute_tls_hmac_zeroed_digest(secret, handshake);
if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) {
return None;
}
let timestamp = u32::from_le_bytes([
parsed.digest[28] ^ computed[28],
parsed.digest[29] ^ computed[29],
parsed.digest[30] ^ computed[30],
parsed.digest[31] ^ computed[31],
]);
if !parsed.ignore_time_skew {
let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs;
if !is_boot_time {
let time_diff = parsed.now - i64::from(timestamp);
if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) {
return None;
}
}
}
Some(TlsCandidateValidation {
digest: parsed.digest,
session_id: parsed.session_id,
session_id_len: parsed.session_id_len,
})
}
fn validate_mtproto_secret_candidate(
handshake: &[u8; HANDSHAKE_LEN],
dec_prekey: &[u8; PREKEY_LEN],
@@ -1857,7 +1741,16 @@ where
return HandshakeResult::BadClient { reader, writer };
}
let validation = matched_validation.expect("validation must exist when matched");
let Some(validation) = matched_validation else {
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(
peer = %peer,
user = %matched_user,
"MTProto handshake matched user without validation material"
);
return HandshakeResult::BadClient { reader, writer };
};
if config
.access

View File

@@ -0,0 +1,126 @@
use hmac::{Hmac, Mac};
use sha2::Sha256;
use subtle::ConstantTimeEq;
use crate::protocol::tls;
type HmacSha256 = Hmac<Sha256>;
/// Parsed TLS authentication material extracted from a ClientHello candidate.
#[derive(Clone, Copy)]
pub(super) struct ParsedTlsAuthMaterial {
digest: [u8; tls::TLS_DIGEST_LEN],
session_id: [u8; 32],
session_id_len: usize,
now: i64,
ignore_time_skew: bool,
boot_time_cap_secs: u32,
}
/// Successful TLS secret validation output used by the handshake state machine.
#[derive(Clone, Copy)]
pub(super) struct TlsCandidateValidation {
pub(super) digest: [u8; tls::TLS_DIGEST_LEN],
pub(super) session_id: [u8; 32],
pub(super) session_id_len: usize,
}
/// Parse TLS auth digest and session-id material from a candidate handshake.
pub(super) fn parse_tls_auth_material(
handshake: &[u8],
ignore_time_skew: bool,
replay_window_secs: u64,
) -> Option<ParsedTlsAuthMaterial> {
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
return None;
}
let digest: [u8; tls::TLS_DIGEST_LEN] = handshake
[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.try_into()
.ok()?;
let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN;
let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?);
if session_id_len > 32 {
return None;
}
let session_id_start = session_id_len_pos + 1;
if handshake.len() < session_id_start + session_id_len {
return None;
}
let mut session_id = [0u8; 32];
session_id[..session_id_len]
.copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]);
let now = if !ignore_time_skew {
let d = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.ok()?;
i64::try_from(d.as_secs()).ok()?
} else {
0_i64
};
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
let boot_time_cap_secs = if ignore_time_skew {
0
} else {
tls::BOOT_TIME_MAX_SECS
.min(replay_window_u32)
.min(tls::BOOT_TIME_COMPAT_MAX_SECS)
};
Some(ParsedTlsAuthMaterial {
digest,
session_id,
session_id_len,
now,
ignore_time_skew,
boot_time_cap_secs,
})
}
fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> Option<[u8; 32]> {
let mut mac = HmacSha256::new_from_slice(secret).ok()?;
mac.update(&handshake[..tls::TLS_DIGEST_POS]);
mac.update(&[0u8; tls::TLS_DIGEST_LEN]);
mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]);
Some(mac.finalize().into_bytes().into())
}
/// Validate a candidate secret against parsed TLS authentication material.
pub(super) fn validate_tls_secret_candidate(
parsed: &ParsedTlsAuthMaterial,
handshake: &[u8],
secret: &[u8],
) -> Option<TlsCandidateValidation> {
let computed = compute_tls_hmac_zeroed_digest(secret, handshake)?;
if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) {
return None;
}
let timestamp = u32::from_le_bytes([
parsed.digest[28] ^ computed[28],
parsed.digest[29] ^ computed[29],
parsed.digest[30] ^ computed[30],
parsed.digest[31] ^ computed[31],
]);
if !parsed.ignore_time_skew {
let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs;
if !is_boot_time {
let time_diff = parsed.now - i64::from(timestamp);
if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) {
return None;
}
}
}
Some(TlsCandidateValidation {
digest: parsed.digest,
session_id: parsed.session_id,
session_id_len: parsed.session_id_len,
})
}

View File

@@ -276,20 +276,17 @@ pub(in crate::proxy::middle_relay) fn compute_intermediate_secure_wire_len(
let wire_len = data_len
.checked_add(padding_len)
.ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?;
if wire_len > 0x7fff_ffffusize {
return Err(ProxyError::Proxy(format!(
"Intermediate/Secure frame too large: {wire_len}"
)));
}
let len_val =
crate::protocol::framing::encode_intermediate_header(wire_len, quickack).ok_or_else(
|| {
ProxyError::Proxy(format!(
"Intermediate/Secure frame too large: {wire_len}"
))
},
)?;
let total = 4usize
.checked_add(wire_len)
.ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?;
let mut len_val = u32::try_from(wire_len)
.map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?;
if quickack {
len_val |= 0x8000_0000;
}
Ok((len_val, total))
}

View File

@@ -236,10 +236,10 @@ where
}
Err(e) => return Err(e),
}
let quickack = (len_buf[3] & 0x80) != 0;
let header = crate::protocol::framing::parse_intermediate_header(len_buf);
(
(u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize,
quickack,
header.wire_len,
header.quickack,
Some(len_buf),
)
}

View File

@@ -15,6 +15,7 @@ use crate::crypto::SecureRandom;
use crate::protocol::constants::{
ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len,
};
use crate::protocol::framing::{encode_intermediate_header, parse_intermediate_header};
// ============= Unified Codec =============
@@ -197,13 +198,9 @@ fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result<Option
}
let mut meta = FrameMeta::new();
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
// Check QuickACK flag
if len >= 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
let header = parse_intermediate_header([src[0], src[1], src[2], src[3]]);
let len = header.wire_len;
meta.quickack = header.quickack;
// Validate size
if len > max_size {
@@ -239,10 +236,12 @@ fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
dst.reserve(4 + data.len());
let mut len = data.len() as u32;
if frame.meta.quickack {
len |= 0x80000000;
}
let len = encode_intermediate_header(data.len(), frame.meta.quickack).ok_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
format!("frame too large: {} bytes", data.len()),
)
})?;
dst.extend_from_slice(&len.to_le_bytes());
dst.extend_from_slice(data);
@@ -258,13 +257,9 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
}
let mut meta = FrameMeta::new();
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
// Check QuickACK flag
if len >= 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
let header = parse_intermediate_header([src[0], src[1], src[2], src[3]]);
let len = header.wire_len;
meta.quickack = header.quickack;
// Validate size
if len > max_size {
@@ -323,10 +318,12 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
let total_len = data.len() + padding_len;
dst.reserve(4 + total_len);
let mut len = total_len as u32;
if frame.meta.quickack {
len |= 0x80000000;
}
let len = encode_intermediate_header(total_len, frame.meta.quickack).ok_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
format!("frame too large: {} bytes", total_len),
)
})?;
dst.extend_from_slice(&len.to_le_bytes());
dst.extend_from_slice(data);

View File

@@ -5,6 +5,7 @@
use super::traits::{FrameMeta, LayeredStream};
use crate::crypto::{SecureRandom, crc32};
use crate::protocol::constants::*;
use crate::protocol::framing::{encode_intermediate_header, parse_intermediate_header};
use bytes::Bytes;
use std::io::{Error, ErrorKind, Result};
use std::sync::Arc;
@@ -105,10 +106,17 @@ impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
if len_div_4 < 0x7f {
// Short length (1 byte)
self.upstream.write_all(&[len_div_4 as u8]).await?;
let mut first = len_div_4 as u8;
if meta.quickack {
first |= 0x80;
}
self.upstream.write_all(&[first]).await?;
} else if len_div_4 < (1 << 24) {
// Long length (4 bytes: 0x7f + 3 bytes)
let mut header = [0x7f, 0, 0, 0];
if meta.quickack {
header[0] |= 0x80;
}
header[1..4].copy_from_slice(&(len_div_4 as u32).to_le_bytes()[..3]);
self.upstream.write_all(&header).await?;
} else {
@@ -160,13 +168,9 @@ impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let mut len = u32::from_le_bytes(len_bytes) as usize;
// Check QuickACK flag (high bit)
if len > 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
let header = parse_intermediate_header(len_bytes);
let len = header.wire_len;
meta.quickack = header.quickack;
// Read data
let mut data = vec![0u8; len];
@@ -204,7 +208,13 @@ impl<W: AsyncWrite + Unpin> IntermediateFrameWriter<W> {
if meta.simple_ack {
self.upstream.write_all(data).await?;
} else {
let len_bytes = (data.len() as u32).to_le_bytes();
let len = encode_intermediate_header(data.len(), meta.quickack).ok_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
format!("Frame too large: {} bytes", data.len()),
)
})?;
let len_bytes = len.to_le_bytes();
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(data).await?;
}
@@ -249,13 +259,9 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let mut len = u32::from_le_bytes(len_bytes) as usize;
// Check QuickACK flag
if len > 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
let header = parse_intermediate_header(len_bytes);
let len = header.wire_len;
meta.quickack = header.quickack;
// Read data (including padding)
let mut data = vec![0u8; len];
@@ -316,7 +322,13 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
let padding = self.rng.bytes(padding_len);
let total_len = data.len() + padding_len;
let len_bytes = (total_len as u32).to_le_bytes();
let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
format!("Frame too large: {total_len} bytes"),
)
})?;
let len_bytes = len.to_le_bytes();
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(data).await?;
@@ -623,6 +635,43 @@ mod tests {
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_intermediate_quickack_zero_length_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = IntermediateFrameWriter::new(client);
let mut reader = IntermediateFrameReader::new(server);
writer
.write_frame(&[], &FrameMeta::new().with_quickack())
.await
.unwrap();
writer.flush().await.unwrap();
let (received, meta) = reader.read_frame().await.unwrap();
assert!(received.is_empty());
assert!(meta.quickack);
}
#[tokio::test]
async fn test_abridged_quickack_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = AbridgedFrameWriter::new(client);
let mut reader = AbridgedFrameReader::new(server);
let data = vec![1u8, 2, 3, 4];
writer
.write_frame(&data, &FrameMeta::new().with_quickack())
.await
.unwrap();
writer.flush().await.unwrap();
let (received, meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
assert!(meta.quickack);
}
#[tokio::test]
async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024);

View File

@@ -293,6 +293,8 @@ impl MePool {
};
record_bnd_status(bnd_addr_status, bnd_port_status, raw_socks_bound_addr);
let socks_bound_kdf_addr = socks_bound_addr.filter(|bound| bound.port() != 0);
// SOCKS BND is the only reflected source that can supply both KDF IP and
// port. Direct STUN reflection is IP-only and keeps the TCP local port.
let reflected = if let Some(bound) = socks_bound_kdf_addr {
Some(bound)
} else if is_socks_route {
@@ -417,6 +419,7 @@ impl MePool {
key_selector = format_args!("0x{ks:08x}"),
crypto_schema = format_args!("0x{schema:08x}"),
skew_secs = skew,
socks_kdf_policy = ?self.socks_kdf_policy(),
"ME key derivation parameters"
);

View File

@@ -464,8 +464,7 @@ impl MePool {
if !self.writer_accepts_new_binding(w) {
continue;
}
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
let (payload, meta) = build_routed_payload(effective_our_addr);
let (payload, meta) = build_routed_payload(our_addr);
match w.tx.clone().try_reserve_owned() {
Ok(permit) => {
if !self.registry.bind_writer(conn_id, w.id, meta).await {
@@ -520,8 +519,7 @@ impl MePool {
}
self.stats
.increment_me_writer_pick_blocking_fallback_total();
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
let (payload, meta) = build_routed_payload(effective_our_addr);
let (payload, meta) = build_routed_payload(our_addr);
let reserve_result =
if let Some(timeout) = self.route_runtime.me_route_blocking_send_timeout {
match tokio::time::timeout(timeout, w.tx.clone().reserve_owned()).await {

View File

@@ -177,6 +177,37 @@ async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duratio
data_count
}
async fn recv_first_data_payload(
rx: &mut mpsc::Receiver<WriterCommand>,
budget: Duration,
) -> Option<Vec<u8>> {
let start = Instant::now();
while Instant::now().duration_since(start) < budget {
let remaining = budget.saturating_sub(Instant::now().duration_since(start));
match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await {
Ok(Some(WriterCommand::Data(payload))) => return Some(payload.to_vec()),
Ok(Some(WriterCommand::DataAndFlush(payload))) => return Some(payload.to_vec()),
Ok(Some(_)) => {}
Ok(None) => break,
Err(_) => break,
}
}
None
}
fn proxy_req_our_addr_from_payload(payload: &[u8]) -> SocketAddr {
const CLIENT_ADDR_WIRE_LEN: usize = 20;
const OUR_ADDR_OFFSET: usize = 4 + 4 + 8 + CLIENT_ADDR_WIRE_LEN;
let our_addr = &payload[OUR_ADDR_OFFSET..OUR_ADDR_OFFSET + CLIENT_ADDR_WIRE_LEN];
let ip = Ipv4Addr::new(our_addr[12], our_addr[13], our_addr[14], our_addr[15]);
let port = u32::from_le_bytes([our_addr[16], our_addr[17], our_addr[18], our_addr[19]]);
SocketAddr::new(
IpAddr::V4(ip),
u16::try_from(port).expect("test port must fit u16"),
)
}
#[tokio::test]
async fn send_proxy_req_does_not_replay_when_first_bind_commit_fails() {
let (pool, _rng) = make_pool().await;
@@ -290,3 +321,47 @@ async fn send_proxy_req_prunes_iterative_stale_bind_failures_without_data_replay
drop(writers);
assert_eq!(writer_ids, vec![23]);
}
#[tokio::test]
async fn send_proxy_req_preserves_client_facing_our_addr_when_writer_source_ip_differs() {
let (pool, _rng) = make_pool().await;
pool.rr.store(0, Ordering::Relaxed);
let (conn_id, _rx) = pool.registry.register().await;
let mut live_rx = insert_writer(
&pool,
31,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 2, 31)), 443),
true,
)
.await;
{
let mut writers = pool.writers.write().await;
let writer = writers
.iter_mut()
.find(|writer| writer.id == 31)
.expect("test writer must exist");
writer.source_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 31));
}
let our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7)), 8443);
let result = pool
.send_proxy_req(
conn_id,
2,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 7)), 30002),
our_addr,
b"route",
0,
None,
)
.await;
assert!(result.is_ok());
let payload = recv_first_data_payload(&mut live_rx, Duration::from_millis(50))
.await
.expect("writer must receive routed payload");
assert_eq!(proxy_req_our_addr_from_payload(&payload), our_addr);
}