mirror of
https://github.com/telemt/telemt.git
synced 2026-06-18 17:08:29 +03:00
Harden masking fallback and frame readers after flow sync
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
@@ -1940,6 +1940,13 @@ impl ProxyConfig {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.server.listen_backlog == 0 || config.server.listen_backlog > i32::MAX as u32 {
|
||||||
|
return Err(ProxyError::Config(format!(
|
||||||
|
"server.listen_backlog must be within [1, {}]",
|
||||||
|
i32::MAX
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
config
|
config
|
||||||
.server
|
.server
|
||||||
.client_mss_value()
|
.client_mss_value()
|
||||||
|
|||||||
@@ -95,6 +95,44 @@ max_client_frame = 16777217
|
|||||||
remove_temp_config(&path);
|
remove_temp_config(&path);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_rejects_listen_backlog_above_i32_upper_bound() {
|
||||||
|
let path = write_temp_config(
|
||||||
|
r#"
|
||||||
|
[server]
|
||||||
|
listen_backlog = 2147483648
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let err = ProxyConfig::load(&path).expect_err("listen_backlog above socket cap must fail");
|
||||||
|
let msg = err.to_string();
|
||||||
|
assert!(
|
||||||
|
msg.contains("server.listen_backlog must be within [1, 2147483647]"),
|
||||||
|
"error must explain listen_backlog hard cap, got: {msg}"
|
||||||
|
);
|
||||||
|
|
||||||
|
remove_temp_config(&path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_rejects_zero_listen_backlog() {
|
||||||
|
let path = write_temp_config(
|
||||||
|
r#"
|
||||||
|
[server]
|
||||||
|
listen_backlog = 0
|
||||||
|
"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let err = ProxyConfig::load(&path).expect_err("zero listen_backlog must fail");
|
||||||
|
let msg = err.to_string();
|
||||||
|
assert!(
|
||||||
|
msg.contains("server.listen_backlog must be within [1, 2147483647]"),
|
||||||
|
"error must explain listen_backlog lower bound, got: {msg}"
|
||||||
|
);
|
||||||
|
|
||||||
|
remove_temp_config(&path);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn load_accepts_memory_limits_at_hard_upper_bounds() {
|
fn load_accepts_memory_limits_at_hard_upper_bounds() {
|
||||||
let path = write_temp_config(
|
let path = write_temp_config(
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ use crate::proxy::handshake::{
|
|||||||
};
|
};
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use crate::proxy::handshake::{handle_mtproto_handshake, handle_tls_handshake};
|
use crate::proxy::handshake::{handle_mtproto_handshake, handle_tls_handshake};
|
||||||
use crate::proxy::masking::handle_bad_client;
|
use crate::proxy::masking::handle_bad_client_with_shared;
|
||||||
use crate::proxy::middle_relay::handle_via_middle_proxy;
|
use crate::proxy::middle_relay::handle_via_middle_proxy;
|
||||||
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||||
use crate::proxy::shared_state::ProxySharedState;
|
use crate::proxy::shared_state::ProxySharedState;
|
||||||
@@ -310,6 +310,7 @@ fn masking_outcome<R, W>(
|
|||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
config: Arc<ProxyConfig>,
|
config: Arc<ProxyConfig>,
|
||||||
beobachten: Arc<BeobachtenStore>,
|
beobachten: Arc<BeobachtenStore>,
|
||||||
|
shared: Arc<ProxySharedState>,
|
||||||
) -> HandshakeOutcome
|
) -> HandshakeOutcome
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin + Send + 'static,
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
@@ -325,7 +326,7 @@ where
|
|||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
handle_bad_client(
|
handle_bad_client_with_shared(
|
||||||
reader,
|
reader,
|
||||||
writer,
|
writer,
|
||||||
&initial_data,
|
&initial_data,
|
||||||
@@ -333,6 +334,7 @@ where
|
|||||||
local_addr,
|
local_addr,
|
||||||
&config,
|
&config,
|
||||||
&beobachten,
|
&beobachten,
|
||||||
|
shared.as_ref(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -718,6 +720,7 @@ where
|
|||||||
local_addr,
|
local_addr,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
beobachten.clone(),
|
beobachten.clone(),
|
||||||
|
shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -739,6 +742,7 @@ where
|
|||||||
local_addr,
|
local_addr,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
beobachten.clone(),
|
beobachten.clone(),
|
||||||
|
shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -757,6 +761,7 @@ where
|
|||||||
local_addr,
|
local_addr,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
beobachten.clone(),
|
beobachten.clone(),
|
||||||
|
shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -787,6 +792,7 @@ where
|
|||||||
local_addr,
|
local_addr,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
beobachten.clone(),
|
beobachten.clone(),
|
||||||
|
shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => {
|
HandshakeResult::Error(e) => {
|
||||||
@@ -844,6 +850,7 @@ where
|
|||||||
local_addr,
|
local_addr,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
beobachten.clone(),
|
beobachten.clone(),
|
||||||
|
shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
@@ -873,6 +880,7 @@ where
|
|||||||
local_addr,
|
local_addr,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
beobachten.clone(),
|
beobachten.clone(),
|
||||||
|
shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -898,6 +906,7 @@ where
|
|||||||
local_addr,
|
local_addr,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
beobachten.clone(),
|
beobachten.clone(),
|
||||||
|
shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
@@ -1329,6 +1338,7 @@ impl RunningClientHandler {
|
|||||||
local_addr,
|
local_addr,
|
||||||
self.config.clone(),
|
self.config.clone(),
|
||||||
self.beobachten.clone(),
|
self.beobachten.clone(),
|
||||||
|
self.shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1350,6 +1360,7 @@ impl RunningClientHandler {
|
|||||||
local_addr,
|
local_addr,
|
||||||
self.config.clone(),
|
self.config.clone(),
|
||||||
self.beobachten.clone(),
|
self.beobachten.clone(),
|
||||||
|
self.shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -1369,6 +1380,7 @@ impl RunningClientHandler {
|
|||||||
local_addr,
|
local_addr,
|
||||||
self.config.clone(),
|
self.config.clone(),
|
||||||
self.beobachten.clone(),
|
self.beobachten.clone(),
|
||||||
|
self.shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1416,6 +1428,7 @@ impl RunningClientHandler {
|
|||||||
local_addr,
|
local_addr,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
self.beobachten.clone(),
|
self.beobachten.clone(),
|
||||||
|
self.shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => {
|
HandshakeResult::Error(e) => {
|
||||||
@@ -1483,6 +1496,7 @@ impl RunningClientHandler {
|
|||||||
local_addr,
|
local_addr,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
self.beobachten.clone(),
|
self.beobachten.clone(),
|
||||||
|
self.shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
@@ -1530,6 +1544,7 @@ impl RunningClientHandler {
|
|||||||
local_addr,
|
local_addr,
|
||||||
self.config.clone(),
|
self.config.clone(),
|
||||||
self.beobachten.clone(),
|
self.beobachten.clone(),
|
||||||
|
self.shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1568,6 +1583,7 @@ impl RunningClientHandler {
|
|||||||
local_addr,
|
local_addr,
|
||||||
config.clone(),
|
config.clone(),
|
||||||
self.beobachten.clone(),
|
self.beobachten.clone(),
|
||||||
|
self.shared.clone(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::network::dns_overrides::resolve_socket_addr;
|
use crate::network::dns_overrides::resolve_socket_addr;
|
||||||
use crate::protocol::tls;
|
use crate::protocol::tls;
|
||||||
|
use crate::proxy::shared_state::ProxySharedState;
|
||||||
use crate::stats::beobachten::BeobachtenStore;
|
use crate::stats::beobachten::BeobachtenStore;
|
||||||
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
|
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
|
||||||
use crate::transport::socket::configure_tcp_socket;
|
use crate::transport::socket::configure_tcp_socket;
|
||||||
@@ -10,6 +11,7 @@ use crate::transport::socket::configure_tcp_socket;
|
|||||||
use nix::ifaddrs::getifaddrs;
|
use nix::ifaddrs::getifaddrs;
|
||||||
use rand::rngs::StdRng;
|
use rand::rngs::StdRng;
|
||||||
use rand::{Rng, RngExt, SeedableRng};
|
use rand::{Rng, RngExt, SeedableRng};
|
||||||
|
use std::io::{Error as IoError, ErrorKind};
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::str;
|
use std::str;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -18,7 +20,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
|
|||||||
use std::sync::{Mutex, OnceLock};
|
use std::sync::{Mutex, OnceLock};
|
||||||
use std::time::{Duration, Instant as StdInstant};
|
use std::time::{Duration, Instant as StdInstant};
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::{TcpStream, lookup_host};
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use tokio::net::UnixStream;
|
use tokio::net::UnixStream;
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
@@ -271,6 +273,32 @@ async fn consume_client_data_with_timeout_and_cap<R>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mask_failure_drain_cap(config: &ProxyConfig) -> usize {
|
||||||
|
let configured_cap = config.censorship.mask_relay_max_bytes;
|
||||||
|
if configured_cap == 0 {
|
||||||
|
return MASK_BUFFER_SIZE;
|
||||||
|
}
|
||||||
|
|
||||||
|
configured_cap.min(MASK_BUFFER_SIZE)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn consume_mask_failure_path<R>(
|
||||||
|
reader: R,
|
||||||
|
config: &ProxyConfig,
|
||||||
|
relay_timeout: Duration,
|
||||||
|
idle_timeout: Duration,
|
||||||
|
) where
|
||||||
|
R: AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
consume_client_data_with_timeout_and_cap(
|
||||||
|
reader,
|
||||||
|
mask_failure_drain_cap(config),
|
||||||
|
relay_timeout,
|
||||||
|
idle_timeout,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
async fn wait_mask_connect_budget(started: Instant) {
|
async fn wait_mask_connect_budget(started: Instant) {
|
||||||
let elapsed = started.elapsed();
|
let elapsed = started.elapsed();
|
||||||
if elapsed < MASK_TIMEOUT {
|
if elapsed < MASK_TIMEOUT {
|
||||||
@@ -501,6 +529,32 @@ fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
|
|||||||
host.parse::<IpAddr>().ok()
|
host.parse::<IpAddr>().ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn resolve_mask_target_addrs(
|
||||||
|
mask_host: &str,
|
||||||
|
mask_port: u16,
|
||||||
|
) -> std::io::Result<Vec<SocketAddr>> {
|
||||||
|
if let Some(addr) = resolve_socket_addr(mask_host, mask_port) {
|
||||||
|
return Ok(vec![addr]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ip) = parse_mask_host_ip_literal(mask_host) {
|
||||||
|
return Ok(vec![SocketAddr::new(ip, mask_port)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let addrs = timeout(MASK_TIMEOUT, lookup_host((mask_host, mask_port)))
|
||||||
|
.await
|
||||||
|
.map_err(|_| IoError::new(ErrorKind::TimedOut, "mask target DNS lookup timed out"))??;
|
||||||
|
let addrs = addrs.collect::<Vec<_>>();
|
||||||
|
if addrs.is_empty() {
|
||||||
|
return Err(IoError::new(
|
||||||
|
ErrorKind::NotFound,
|
||||||
|
"mask target DNS lookup returned no addresses",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(addrs)
|
||||||
|
}
|
||||||
|
|
||||||
fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> {
|
fn matching_tls_domain_for_sni<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> {
|
||||||
if config.censorship.tls_domain.eq_ignore_ascii_case(sni) {
|
if config.censorship.tls_domain.eq_ignore_ascii_case(sni) {
|
||||||
return Some(config.censorship.tls_domain.as_str());
|
return Some(config.censorship.tls_domain.as_str());
|
||||||
@@ -782,7 +836,7 @@ fn is_mask_target_local_listener_with_interfaces(
|
|||||||
mask_host: &str,
|
mask_host: &str,
|
||||||
mask_port: u16,
|
mask_port: u16,
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
resolved_override: Option<SocketAddr>,
|
resolved_addrs: &[SocketAddr],
|
||||||
interface_ips: &[IpAddr],
|
interface_ips: &[IpAddr],
|
||||||
) -> bool {
|
) -> bool {
|
||||||
if mask_port != local_addr.port() {
|
if mask_port != local_addr.port() {
|
||||||
@@ -792,7 +846,7 @@ fn is_mask_target_local_listener_with_interfaces(
|
|||||||
let local_ip = canonical_ip(local_addr.ip());
|
let local_ip = canonical_ip(local_addr.ip());
|
||||||
let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
|
let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
|
||||||
|
|
||||||
if let Some(addr) = resolved_override {
|
for addr in resolved_addrs {
|
||||||
let resolved_ip = canonical_ip(addr.ip());
|
let resolved_ip = canonical_ip(addr.ip());
|
||||||
if resolved_ip == local_ip {
|
if resolved_ip == local_ip {
|
||||||
return true;
|
return true;
|
||||||
@@ -829,7 +883,7 @@ fn is_mask_target_local_listener(
|
|||||||
mask_host: &str,
|
mask_host: &str,
|
||||||
mask_port: u16,
|
mask_port: u16,
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
resolved_override: Option<SocketAddr>,
|
resolved_addrs: &[SocketAddr],
|
||||||
) -> bool {
|
) -> bool {
|
||||||
if mask_port != local_addr.port() {
|
if mask_port != local_addr.port() {
|
||||||
return false;
|
return false;
|
||||||
@@ -840,7 +894,7 @@ fn is_mask_target_local_listener(
|
|||||||
mask_host,
|
mask_host,
|
||||||
mask_port,
|
mask_port,
|
||||||
local_addr,
|
local_addr,
|
||||||
resolved_override,
|
resolved_addrs,
|
||||||
&interfaces,
|
&interfaces,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -849,7 +903,7 @@ async fn is_mask_target_local_listener_async(
|
|||||||
mask_host: &str,
|
mask_host: &str,
|
||||||
mask_port: u16,
|
mask_port: u16,
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
resolved_override: Option<SocketAddr>,
|
resolved_addrs: &[SocketAddr],
|
||||||
) -> bool {
|
) -> bool {
|
||||||
if mask_port != local_addr.port() {
|
if mask_port != local_addr.port() {
|
||||||
return false;
|
return false;
|
||||||
@@ -860,7 +914,7 @@ async fn is_mask_target_local_listener_async(
|
|||||||
mask_host,
|
mask_host,
|
||||||
mask_port,
|
mask_port,
|
||||||
local_addr,
|
local_addr,
|
||||||
resolved_override,
|
resolved_addrs,
|
||||||
&interfaces,
|
&interfaces,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -904,7 +958,7 @@ fn configure_mask_backend_socket(stream: &TcpStream) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle a bad client by forwarding to mask host
|
/// Handles a bad client by forwarding it to the configured mask target.
|
||||||
pub async fn handle_bad_client<R, W>(
|
pub async fn handle_bad_client<R, W>(
|
||||||
reader: R,
|
reader: R,
|
||||||
writer: W,
|
writer: W,
|
||||||
@@ -916,6 +970,34 @@ pub async fn handle_bad_client<R, W>(
|
|||||||
) where
|
) where
|
||||||
R: AsyncRead + Unpin + Send + 'static,
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
W: AsyncWrite + Unpin + Send + 'static,
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
let shared = ProxySharedState::new();
|
||||||
|
handle_bad_client_with_shared(
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
initial_data,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
config,
|
||||||
|
beobachten,
|
||||||
|
shared.as_ref(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handles a bad client with shared pre-auth fallback admission state.
|
||||||
|
pub(crate) async fn handle_bad_client_with_shared<R, W>(
|
||||||
|
reader: R,
|
||||||
|
writer: W,
|
||||||
|
initial_data: &[u8],
|
||||||
|
peer: SocketAddr,
|
||||||
|
local_addr: SocketAddr,
|
||||||
|
config: &ProxyConfig,
|
||||||
|
beobachten: &BeobachtenStore,
|
||||||
|
shared: &ProxySharedState,
|
||||||
|
) where
|
||||||
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
let client_type = detect_client_type(initial_data);
|
let client_type = detect_client_type(initial_data);
|
||||||
if config.general.beobachten {
|
if config.general.beobachten {
|
||||||
@@ -938,6 +1020,17 @@ pub async fn handle_bad_client<R, W>(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let Some(_masking_permit) = shared.try_acquire_masking_fallback_permit() else {
|
||||||
|
let outcome_started = Instant::now();
|
||||||
|
debug!(
|
||||||
|
client_type = client_type,
|
||||||
|
"Masking fallback concurrency limit reached"
|
||||||
|
);
|
||||||
|
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||||
|
wait_mask_outcome_budget(outcome_started, config).await;
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
let client_sni = tls::extract_sni_from_client_hello(initial_data);
|
let client_sni = tls::extract_sni_from_client_hello(initial_data);
|
||||||
let exclusive_tcp_target = client_sni
|
let exclusive_tcp_target = client_sni
|
||||||
.as_deref()
|
.as_deref()
|
||||||
@@ -1000,24 +1093,12 @@ pub async fn handle_bad_client<R, W>(
|
|||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
wait_mask_connect_budget_if_needed(connect_started, config).await;
|
wait_mask_connect_budget_if_needed(connect_started, config).await;
|
||||||
debug!(error = %e, "Failed to connect to mask unix socket");
|
debug!(error = %e, "Failed to connect to mask unix socket");
|
||||||
consume_client_data_with_timeout_and_cap(
|
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||||
reader,
|
|
||||||
config.censorship.mask_relay_max_bytes,
|
|
||||||
relay_timeout,
|
|
||||||
idle_timeout,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
wait_mask_outcome_budget(outcome_started, config).await;
|
wait_mask_outcome_budget(outcome_started, config).await;
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
debug!("Timeout connecting to mask unix socket");
|
debug!("Timeout connecting to mask unix socket");
|
||||||
consume_client_data_with_timeout_and_cap(
|
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||||
reader,
|
|
||||||
config.censorship.mask_relay_max_bytes,
|
|
||||||
relay_timeout,
|
|
||||||
idle_timeout,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
wait_mask_outcome_budget(outcome_started, config).await;
|
wait_mask_outcome_budget(outcome_started, config).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1030,11 +1111,27 @@ pub async fn handle_bad_client<R, W>(
|
|||||||
let mask_host = mask_target.host;
|
let mask_host = mask_target.host;
|
||||||
let mask_port = mask_target.port;
|
let mask_port = mask_target.port;
|
||||||
|
|
||||||
|
let resolved_mask_addrs = match resolve_mask_target_addrs(mask_host, mask_port).await {
|
||||||
|
Ok(addrs) => addrs,
|
||||||
|
Err(e) => {
|
||||||
|
let outcome_started = Instant::now();
|
||||||
|
debug!(
|
||||||
|
client_type = client_type,
|
||||||
|
host = %mask_host,
|
||||||
|
port = mask_port,
|
||||||
|
error = %e,
|
||||||
|
"Failed to resolve mask target"
|
||||||
|
);
|
||||||
|
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||||
|
wait_mask_outcome_budget(outcome_started, config).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Fail closed when fallback points at our own listener endpoint.
|
// Fail closed when fallback points at our own listener endpoint.
|
||||||
// Self-referential masking can create recursive proxy loops under
|
// Self-referential masking can create recursive proxy loops under
|
||||||
// misconfiguration and leak distinguishable load spikes to adversaries.
|
// misconfiguration and leak distinguishable load spikes to adversaries.
|
||||||
let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port);
|
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, &resolved_mask_addrs)
|
||||||
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr)
|
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
let outcome_started = Instant::now();
|
let outcome_started = Instant::now();
|
||||||
@@ -1045,13 +1142,7 @@ pub async fn handle_bad_client<R, W>(
|
|||||||
local = %local_addr,
|
local = %local_addr,
|
||||||
"Mask target resolves to local listener; refusing self-referential masking fallback"
|
"Mask target resolves to local listener; refusing self-referential masking fallback"
|
||||||
);
|
);
|
||||||
consume_client_data_with_timeout_and_cap(
|
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||||
reader,
|
|
||||||
config.censorship.mask_relay_max_bytes,
|
|
||||||
relay_timeout,
|
|
||||||
idle_timeout,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
wait_mask_outcome_budget(outcome_started, config).await;
|
wait_mask_outcome_budget(outcome_started, config).await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -1066,12 +1157,12 @@ pub async fn handle_bad_client<R, W>(
|
|||||||
"Forwarding bad client to mask host"
|
"Forwarding bad client to mask host"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Apply runtime DNS override for mask target when configured.
|
|
||||||
let mask_addr = resolved_mask_addr
|
|
||||||
.map(|addr| addr.to_string())
|
|
||||||
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
|
|
||||||
let connect_started = Instant::now();
|
let connect_started = Instant::now();
|
||||||
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
|
let connect_result = timeout(
|
||||||
|
MASK_TIMEOUT,
|
||||||
|
TcpStream::connect(resolved_mask_addrs.as_slice()),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
match connect_result {
|
match connect_result {
|
||||||
Ok(Ok(stream)) => {
|
Ok(Ok(stream)) => {
|
||||||
configure_mask_backend_socket(&stream);
|
configure_mask_backend_socket(&stream);
|
||||||
@@ -1113,24 +1204,12 @@ pub async fn handle_bad_client<R, W>(
|
|||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
wait_mask_connect_budget_if_needed(connect_started, config).await;
|
wait_mask_connect_budget_if_needed(connect_started, config).await;
|
||||||
debug!(error = %e, "Failed to connect to mask host");
|
debug!(error = %e, "Failed to connect to mask host");
|
||||||
consume_client_data_with_timeout_and_cap(
|
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||||
reader,
|
|
||||||
config.censorship.mask_relay_max_bytes,
|
|
||||||
relay_timeout,
|
|
||||||
idle_timeout,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
wait_mask_outcome_budget(outcome_started, config).await;
|
wait_mask_outcome_budget(outcome_started, config).await;
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
debug!("Timeout connecting to mask host");
|
debug!("Timeout connecting to mask host");
|
||||||
consume_client_data_with_timeout_and_cap(
|
consume_mask_failure_path(reader, config, relay_timeout, idle_timeout).await;
|
||||||
reader,
|
|
||||||
config.censorship.mask_relay_max_bytes,
|
|
||||||
relay_timeout,
|
|
||||||
idle_timeout,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
wait_mask_outcome_budget(outcome_started, config).await;
|
wait_mask_outcome_budget(outcome_started, config).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,8 @@ use self::c2me::{
|
|||||||
};
|
};
|
||||||
use self::d2c::{
|
use self::d2c::{
|
||||||
MeD2cFlushPolicy, MeWriterResponseOutcome, classify_me_d2c_flush_reason,
|
MeD2cFlushPolicy, MeWriterResponseOutcome, classify_me_d2c_flush_reason,
|
||||||
flush_client_or_cancel, observe_me_d2c_flush_event,
|
flush_client_or_cancel, me_d2c_flush_reason_requires_client_flush,
|
||||||
|
observe_me_d2c_flush_event,
|
||||||
process_me_writer_response_with_traffic_lease,
|
process_me_writer_response_with_traffic_lease,
|
||||||
};
|
};
|
||||||
use self::desync::{RelayForensicsState, hash_ip_in, report_desync_frame_too_large_in};
|
use self::desync::{RelayForensicsState, hash_ip_in, report_desync_frame_too_large_in};
|
||||||
|
|||||||
@@ -55,6 +55,37 @@ pub(super) fn classify_me_d2c_flush_reason(
|
|||||||
MeD2cFlushReason::QueueDrain
|
MeD2cFlushReason::QueueDrain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) fn me_d2c_flush_reason_requires_client_flush(reason: MeD2cFlushReason) -> bool {
|
||||||
|
!matches!(reason, MeD2cFlushReason::QueueDrain)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn queue_drain_is_not_a_physical_flush_trigger() {
|
||||||
|
assert!(!me_d2c_flush_reason_requires_client_flush(
|
||||||
|
MeD2cFlushReason::QueueDrain
|
||||||
|
));
|
||||||
|
assert!(me_d2c_flush_reason_requires_client_flush(
|
||||||
|
MeD2cFlushReason::AckImmediate
|
||||||
|
));
|
||||||
|
assert!(me_d2c_flush_reason_requires_client_flush(
|
||||||
|
MeD2cFlushReason::BatchFrames
|
||||||
|
));
|
||||||
|
assert!(me_d2c_flush_reason_requires_client_flush(
|
||||||
|
MeD2cFlushReason::BatchBytes
|
||||||
|
));
|
||||||
|
assert!(me_d2c_flush_reason_requires_client_flush(
|
||||||
|
MeD2cFlushReason::MaxDelay
|
||||||
|
));
|
||||||
|
assert!(me_d2c_flush_reason_requires_client_flush(
|
||||||
|
MeD2cFlushReason::Close
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(super) fn observe_me_d2c_flush_event(
|
pub(super) fn observe_me_d2c_flush_event(
|
||||||
stats: &Stats,
|
stats: &Stats,
|
||||||
reason: MeD2cFlushReason,
|
reason: MeD2cFlushReason,
|
||||||
|
|||||||
@@ -491,12 +491,18 @@ where
|
|||||||
d2c_flush_policy.max_bytes,
|
d2c_flush_policy.max_bytes,
|
||||||
max_delay_fired,
|
max_delay_fired,
|
||||||
);
|
);
|
||||||
let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() {
|
let physical_flush =
|
||||||
|
me_d2c_flush_reason_requires_client_flush(flush_reason);
|
||||||
|
let flush_started_at = if physical_flush
|
||||||
|
&& stats_clone.telemetry_policy().me_level.allows_debug()
|
||||||
|
{
|
||||||
Some(Instant::now())
|
Some(Instant::now())
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
|
if physical_flush {
|
||||||
|
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
|
||||||
|
}
|
||||||
let flush_duration_us = flush_started_at.map(|started| {
|
let flush_duration_us = flush_started_at.map(|started| {
|
||||||
started
|
started
|
||||||
.elapsed()
|
.elapsed()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
|
|||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
|
use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
|
||||||
@@ -14,6 +14,7 @@ use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateReg
|
|||||||
use crate::proxy::traffic_limiter::TrafficLimiter;
|
use crate::proxy::traffic_limiter::TrafficLimiter;
|
||||||
|
|
||||||
const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64;
|
const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64;
|
||||||
|
const MASKING_FALLBACK_MAX_CONCURRENT: usize = 512;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
pub(crate) enum ConntrackCloseReason {
|
pub(crate) enum ConntrackCloseReason {
|
||||||
@@ -72,6 +73,7 @@ pub(crate) struct ProxySharedState {
|
|||||||
active_user_sessions: DashMap<(String, u64), CancellationToken>,
|
active_user_sessions: DashMap<(String, u64), CancellationToken>,
|
||||||
pub(crate) conntrack_pressure_active: AtomicBool,
|
pub(crate) conntrack_pressure_active: AtomicBool,
|
||||||
pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>,
|
pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>,
|
||||||
|
masking_fallback_permits: Arc<Semaphore>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use = "registered user sessions must be kept alive until relay completion"]
|
#[must_use = "registered user sessions must be kept alive until relay completion"]
|
||||||
@@ -131,9 +133,18 @@ impl ProxySharedState {
|
|||||||
active_user_sessions: DashMap::new(),
|
active_user_sessions: DashMap::new(),
|
||||||
conntrack_pressure_active: AtomicBool::new(false),
|
conntrack_pressure_active: AtomicBool::new(false),
|
||||||
conntrack_close_tx: Mutex::new(None),
|
conntrack_close_tx: Mutex::new(None),
|
||||||
|
masking_fallback_permits: Arc::new(Semaphore::new(MASKING_FALLBACK_MAX_CONCURRENT)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Attempts to reserve one masking fallback slot for a pre-auth connection.
|
||||||
|
pub(crate) fn try_acquire_masking_fallback_permit(&self) -> Option<OwnedSemaphorePermit> {
|
||||||
|
self.masking_fallback_permits
|
||||||
|
.clone()
|
||||||
|
.try_acquire_owned()
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn is_user_enabled(&self, user: &str) -> bool {
|
pub(crate) fn is_user_enabled(&self, user: &str) -> bool {
|
||||||
!self.disabled_users.contains_key(user)
|
!self.disabled_users.contains_key(user)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ fn loop_guard_unspecified_bind_uses_interface_inventory() {
|
|||||||
"mask.example",
|
"mask.example",
|
||||||
443,
|
443,
|
||||||
local,
|
local,
|
||||||
Some(resolved),
|
&[resolved],
|
||||||
&interfaces,
|
&interfaces,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() {
|
|||||||
let barrier = std::sync::Arc::clone(&barrier);
|
let barrier = std::sync::Arc::clone(&barrier);
|
||||||
tasks.push(tokio::spawn(async move {
|
tasks.push(tokio::spawn(async move {
|
||||||
barrier.wait().await;
|
barrier.wait().await;
|
||||||
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await
|
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_
|
|||||||
|
|
||||||
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
||||||
|
|
||||||
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
|
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await;
|
||||||
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
|
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, &[]).await;
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
local_interface_enumerations_for_tests(),
|
local_interface_enumerations_for_tests(),
|
||||||
@@ -35,7 +35,7 @@ async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
|
|||||||
reset_local_interface_enumerations_for_tests();
|
reset_local_interface_enumerations_for_tests();
|
||||||
|
|
||||||
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
||||||
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await;
|
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, &[]).await;
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
!is_local,
|
!is_local,
|
||||||
|
|||||||
@@ -15,38 +15,49 @@ fn closed_local_port() -> u16 {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn self_target_detection_matches_literal_ipv4_listener() {
|
async fn self_target_detection_matches_literal_ipv4_listener() {
|
||||||
let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
|
let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
|
||||||
assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await);
|
assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, &[],).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn self_target_detection_matches_bracketed_ipv6_listener() {
|
async fn self_target_detection_matches_bracketed_ipv6_listener() {
|
||||||
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
|
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
|
||||||
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await);
|
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, &[],).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
|
async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
|
||||||
let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
|
let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
|
||||||
assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await);
|
assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, &[],).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
|
async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
|
||||||
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await);
|
assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, &[],).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
|
async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
|
||||||
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||||
assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await);
|
assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, &[],).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
|
async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
|
||||||
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||||
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
|
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
|
||||||
assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await);
|
assert!(!is_mask_target_local_listener_async("mask.example", 443, local, &[remote],).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn self_target_detection_checks_all_resolved_addresses() {
|
||||||
|
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
|
||||||
|
let loopback: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
is_mask_target_local_listener_async("mask.example", 443, local, &[remote, loopback],).await
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
@@ -11,16 +11,41 @@ use std::io::{Error, ErrorKind, Result};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
|
const DEFAULT_MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
|
||||||
|
|
||||||
|
fn reject_oversize_frame(len: usize, max_frame_size: usize, protocol: &str) -> Result<()> {
|
||||||
|
if len > max_frame_size {
|
||||||
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidData,
|
||||||
|
format!("{protocol} frame too large: {len} bytes (max {max_frame_size})"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
// ============= Abridged (Compact) Frame =============
|
// ============= Abridged (Compact) Frame =============
|
||||||
|
|
||||||
/// Reader for abridged MTProto framing
|
/// Reader for abridged MTProto framing
|
||||||
pub struct AbridgedFrameReader<R> {
|
pub struct AbridgedFrameReader<R> {
|
||||||
upstream: R,
|
upstream: R,
|
||||||
|
max_frame_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> AbridgedFrameReader<R> {
|
impl<R> AbridgedFrameReader<R> {
|
||||||
|
/// Creates a reader with the default maximum frame size.
|
||||||
pub fn new(upstream: R) -> Self {
|
pub fn new(upstream: R) -> Self {
|
||||||
Self { upstream }
|
Self {
|
||||||
|
upstream,
|
||||||
|
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
upstream,
|
||||||
|
max_frame_size,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,10 +73,12 @@ impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
|
|||||||
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
|
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length is in 4-byte words
|
// Length is in 4-byte words.
|
||||||
let byte_len = len * 4;
|
let byte_len = len
|
||||||
|
.checked_mul(4)
|
||||||
|
.ok_or_else(|| Error::new(ErrorKind::InvalidData, "abridged frame length overflow"))?;
|
||||||
|
reject_oversize_frame(byte_len, self.max_frame_size, "abridged")?;
|
||||||
|
|
||||||
// Read data
|
|
||||||
let mut data = vec![0u8; byte_len];
|
let mut data = vec![0u8; byte_len];
|
||||||
self.upstream.read_exact(&mut data).await?;
|
self.upstream.read_exact(&mut data).await?;
|
||||||
|
|
||||||
@@ -152,11 +179,23 @@ impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
|
|||||||
/// Reader for intermediate MTProto framing
|
/// Reader for intermediate MTProto framing
|
||||||
pub struct IntermediateFrameReader<R> {
|
pub struct IntermediateFrameReader<R> {
|
||||||
upstream: R,
|
upstream: R,
|
||||||
|
max_frame_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> IntermediateFrameReader<R> {
|
impl<R> IntermediateFrameReader<R> {
|
||||||
|
/// Creates a reader with the default maximum frame size.
|
||||||
pub fn new(upstream: R) -> Self {
|
pub fn new(upstream: R) -> Self {
|
||||||
Self { upstream }
|
Self {
|
||||||
|
upstream,
|
||||||
|
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
upstream,
|
||||||
|
max_frame_size,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,8 +210,8 @@ impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
|
|||||||
let header = parse_intermediate_header(len_bytes);
|
let header = parse_intermediate_header(len_bytes);
|
||||||
let len = header.wire_len;
|
let len = header.wire_len;
|
||||||
meta.quickack = header.quickack;
|
meta.quickack = header.quickack;
|
||||||
|
reject_oversize_frame(len, self.max_frame_size, "intermediate")?;
|
||||||
|
|
||||||
// Read data
|
|
||||||
let mut data = vec![0u8; len];
|
let mut data = vec![0u8; len];
|
||||||
self.upstream.read_exact(&mut data).await?;
|
self.upstream.read_exact(&mut data).await?;
|
||||||
|
|
||||||
@@ -243,11 +282,23 @@ impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
|
|||||||
/// Reader for secure intermediate MTProto framing (with padding)
|
/// Reader for secure intermediate MTProto framing (with padding)
|
||||||
pub struct SecureIntermediateFrameReader<R> {
|
pub struct SecureIntermediateFrameReader<R> {
|
||||||
upstream: R,
|
upstream: R,
|
||||||
|
max_frame_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> SecureIntermediateFrameReader<R> {
|
impl<R> SecureIntermediateFrameReader<R> {
|
||||||
|
/// Creates a reader with the default maximum frame size.
|
||||||
pub fn new(upstream: R) -> Self {
|
pub fn new(upstream: R) -> Self {
|
||||||
Self { upstream }
|
Self {
|
||||||
|
upstream,
|
||||||
|
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
upstream,
|
||||||
|
max_frame_size,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -262,17 +313,16 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
|
|||||||
let header = parse_intermediate_header(len_bytes);
|
let header = parse_intermediate_header(len_bytes);
|
||||||
let len = header.wire_len;
|
let len = header.wire_len;
|
||||||
meta.quickack = header.quickack;
|
meta.quickack = header.quickack;
|
||||||
|
reject_oversize_frame(len, self.max_frame_size, "secure intermediate")?;
|
||||||
// Read data (including padding)
|
|
||||||
let mut data = vec![0u8; len];
|
|
||||||
self.upstream.read_exact(&mut data).await?;
|
|
||||||
|
|
||||||
let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
|
let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
|
||||||
Error::new(
|
Error::new(
|
||||||
ErrorKind::InvalidData,
|
ErrorKind::InvalidData,
|
||||||
format!("Invalid secure frame length: {len}"),
|
format!("Invalid secure frame length: {len}"),
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
let mut data = vec![0u8; len];
|
||||||
|
self.upstream.read_exact(&mut data).await?;
|
||||||
data.truncate(payload_len);
|
data.truncate(payload_len);
|
||||||
|
|
||||||
Ok((Bytes::from(data), meta))
|
Ok((Bytes::from(data), meta))
|
||||||
@@ -321,7 +371,9 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
|||||||
let padding_len = secure_padding_len(data.len(), &self.rng);
|
let padding_len = secure_padding_len(data.len(), &self.rng);
|
||||||
let padding = self.rng.bytes(padding_len);
|
let padding = self.rng.bytes(padding_len);
|
||||||
|
|
||||||
let total_len = data.len() + padding_len;
|
let total_len = data.len().checked_add(padding_len).ok_or_else(|| {
|
||||||
|
Error::new(ErrorKind::InvalidInput, "secure frame length overflow")
|
||||||
|
})?;
|
||||||
let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| {
|
let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| {
|
||||||
Error::new(
|
Error::new(
|
||||||
ErrorKind::InvalidInput,
|
ErrorKind::InvalidInput,
|
||||||
@@ -507,15 +559,26 @@ pub enum FrameReaderKind<R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
|
impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
|
||||||
|
/// Creates a frame reader with the default maximum frame size.
|
||||||
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
|
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
|
||||||
|
Self::with_max_frame_size(upstream, proto_tag, DEFAULT_MAX_FRAME_SIZE)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_max_frame_size(
|
||||||
|
upstream: R,
|
||||||
|
proto_tag: ProtoTag,
|
||||||
|
max_frame_size: usize,
|
||||||
|
) -> Self {
|
||||||
match proto_tag {
|
match proto_tag {
|
||||||
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
|
ProtoTag::Abridged => FrameReaderKind::Abridged(
|
||||||
ProtoTag::Intermediate => {
|
AbridgedFrameReader::with_max_frame_size(upstream, max_frame_size),
|
||||||
FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream))
|
),
|
||||||
}
|
ProtoTag::Intermediate => FrameReaderKind::Intermediate(
|
||||||
ProtoTag::Secure => {
|
IntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
|
||||||
FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream))
|
),
|
||||||
}
|
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(
|
||||||
|
SecureIntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -569,7 +632,8 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::crypto::SecureRandom;
|
use crate::crypto::SecureRandom;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::duplex;
|
use tokio::io::{AsyncWriteExt, duplex};
|
||||||
|
use tokio::time::{Duration, timeout};
|
||||||
|
|
||||||
fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) {
|
fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) {
|
||||||
assert!(decoded.starts_with(original));
|
assert!(decoded.starts_with(original));
|
||||||
@@ -672,6 +736,55 @@ mod tests {
|
|||||||
assert!(meta.quickack);
|
assert!(meta.quickack);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn abridged_reader_rejects_oversize_frame_before_body_read() {
|
||||||
|
let (mut client, server) = duplex(1024);
|
||||||
|
let mut reader = AbridgedFrameReader::new(server);
|
||||||
|
let len_words = (DEFAULT_MAX_FRAME_SIZE / 4) + 1;
|
||||||
|
let encoded = (len_words as u32).to_le_bytes();
|
||||||
|
|
||||||
|
client
|
||||||
|
.write_all(&[0x7f, encoded[0], encoded[1], encoded[2]])
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let err = timeout(Duration::from_millis(50), reader.read_frame())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap_err();
|
||||||
|
|
||||||
|
assert_eq!(err.kind(), ErrorKind::InvalidData);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn intermediate_reader_rejects_oversize_frame_before_body_read() {
|
||||||
|
let (mut client, server) = duplex(1024);
|
||||||
|
let mut reader = IntermediateFrameReader::new(server);
|
||||||
|
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 1, false).unwrap();
|
||||||
|
|
||||||
|
client.write_all(&len.to_le_bytes()).await.unwrap();
|
||||||
|
let err = timeout(Duration::from_millis(50), reader.read_frame())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap_err();
|
||||||
|
|
||||||
|
assert_eq!(err.kind(), ErrorKind::InvalidData);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn secure_reader_rejects_oversize_frame_before_body_read() {
|
||||||
|
let (mut client, server) = duplex(1024);
|
||||||
|
let mut reader = SecureIntermediateFrameReader::new(server);
|
||||||
|
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 4, false).unwrap();
|
||||||
|
|
||||||
|
client.write_all(&len.to_le_bytes()).await.unwrap();
|
||||||
|
let err = timeout(Duration::from_millis(50), reader.read_frame())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap_err();
|
||||||
|
|
||||||
|
assert_eq!(err.kind(), ErrorKind::InvalidData);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_secure_intermediate_padding() {
|
async fn test_secure_intermediate_padding() {
|
||||||
let (client, server) = duplex(1024);
|
let (client, server) = duplex(1024);
|
||||||
|
|||||||
Reference in New Issue
Block a user