From 1357f3cc4c18f1b59e7bfd3cdabddd2d4e39d4aa Mon Sep 17 00:00:00 2001 From: David Osipov Date: Tue, 17 Mar 2026 18:16:17 +0400 Subject: [PATCH] bump version to 3.3.20 and implement connection lease management for direct and middle relays --- Cargo.lock | 2 +- src/proxy/direct_relay.rs | 4 +- src/proxy/direct_relay_security_tests.rs | 129 ++++++++++++++ src/proxy/middle_relay.rs | 4 +- src/proxy/middle_relay_security_tests.rs | 167 ++++++++++++++++++- src/stats/connection_lease_security_tests.rs | 114 +++++++++++++ src/stats/mod.rs | 54 ++++++ 7 files changed, 465 insertions(+), 9 deletions(-) create mode 100644 src/stats/connection_lease_security_tests.rs diff --git a/Cargo.lock b/Cargo.lock index 89eefd6..677ab84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2093,7 +2093,7 @@ dependencies = [ [[package]] name = "telemt" -version = "3.3.19" +version = "3.3.20" dependencies = [ "aes", "anyhow", diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 9c6116c..d7d5f64 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -105,7 +105,7 @@ where debug!(peer = %success.peer, "TG handshake complete, starting relay"); stats.increment_user_connects(user); - stats.increment_current_connections_direct(); + let _direct_connection_lease = stats.acquire_direct_connection_lease(); let relay_result = relay_bidirectional( client_reader, @@ -148,8 +148,6 @@ where } }; - stats.decrement_current_connections_direct(); - match &relay_result { Ok(()) => debug!(user = %user, "Direct relay completed"), Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"), diff --git a/src/proxy/direct_relay_security_tests.rs b/src/proxy/direct_relay_security_tests.rs index 3b3185a..1e2d673 100644 --- a/src/proxy/direct_relay_security_tests.rs +++ b/src/proxy/direct_relay_security_tests.rs @@ -1,4 +1,33 @@ use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::{AesCtr, SecureRandom}; +use crate::protocol::constants::ProtoTag; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::duplex; +use tokio::net::TcpListener; + +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} #[test] fn unknown_dc_log_is_deduplicated_per_dc_idx() { @@ -49,3 +78,103 @@ fn fallback_dc_never_panics_with_single_dc_list() { let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT); assert_eq!(addr, expected); } + +#[tokio::test] +async fn direct_relay_abort_midflight_releases_route_gauge() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "abort-direct-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50000".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_snapshot, + 0xabad1dea, + )); + + let started = tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await; + assert!(started.is_ok(), "direct relay must increment route gauge before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted direct relay task must return join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_direct(), + 0, + "route gauge must be released when direct relay task is aborted mid-flight" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 1acbdc1..affa4cd 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -306,7 +306,7 @@ where }; stats.increment_user_connects(&user); - stats.increment_current_connections_me(); + let _me_connection_lease = stats.acquire_me_connection_lease(); if let Some(cutover) = affected_cutover_state( &route_rx, @@ -324,7 +324,6 @@ where tokio::time::sleep(delay).await; let _ = me_pool.send_close(conn_id).await; me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); } @@ -672,7 +671,6 @@ where "ME relay cleanup" ); me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); result } diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index a2a6c3e..509ba95 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -2,8 +2,13 @@ use super::*; use bytes::Bytes; use crate::crypto::AesCtr; use crate::crypto::SecureRandom; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::network::probe::NetworkDecision; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; +use crate::transport::middle_proxy::MePool; +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::AtomicU64; @@ -229,18 +234,108 @@ fn make_forensics_state() -> RelayForensicsState { } } -fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader { +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ let key = [0u8; 32]; let iv = 0u128; CryptoReader::new(reader, AesCtr::new(&key, iv)) } -fn make_crypto_writer(writer: tokio::io::DuplexStream) -> CryptoWriter { +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ let key = [0u8; 32]; let iv = 0u128; CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +async fn make_me_pool_for_abort_test(stats: Arc) -> Arc { + let general = GeneralConfig::default(); + + MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + Arc::new(SecureRandom::new()), + stats, + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_pool_drain_threshold, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ) +} + fn encrypt_for_reader(plaintext: &[u8]) -> Vec { let key = [0u8; 32]; let iv = 0u128; @@ -779,3 +874,71 @@ async fn process_me_writer_response_data_updates_byte_accounting() { "ME->C byte accounting must increase by emitted payload size" ); } + +#[tokio::test] +async fn middle_relay_abort_midflight_releases_route_gauge() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "abort-middle-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50001".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool, + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_snapshot, + 0xdecafbad, + )); + + let started = tokio::time::timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_me() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await; + assert!(started.is_ok(), "middle relay must increment route gauge before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted middle relay task must return join error"); + + tokio::time::sleep(TokioDuration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_me(), + 0, + "route gauge must be released when middle relay task is aborted mid-flight" + ); + + drop(client_side); +} diff --git a/src/stats/connection_lease_security_tests.rs b/src/stats/connection_lease_security_tests.rs new file mode 100644 index 0000000..2d942c2 --- /dev/null +++ b/src/stats/connection_lease_security_tests.rs @@ -0,0 +1,114 @@ +use super::*; +use std::panic::{self, AssertUnwindSafe}; +use std::sync::Arc; +use std::time::Duration; + +#[test] +fn direct_connection_lease_balances_on_drop() { + let stats = Arc::new(Stats::new()); + assert_eq!(stats.get_current_connections_direct(), 0); + + { + let _lease = stats.acquire_direct_connection_lease(); + assert_eq!(stats.get_current_connections_direct(), 1); + } + + assert_eq!(stats.get_current_connections_direct(), 0); +} + +#[test] +fn middle_connection_lease_balances_on_drop() { + let stats = Arc::new(Stats::new()); + assert_eq!(stats.get_current_connections_me(), 0); + + { + let _lease = stats.acquire_me_connection_lease(); + assert_eq!(stats.get_current_connections_me(), 1); + } + + assert_eq!(stats.get_current_connections_me(), 0); +} + +#[test] +fn connection_lease_disarm_prevents_double_release() { + let stats = Arc::new(Stats::new()); + + let mut lease = stats.acquire_direct_connection_lease(); + assert_eq!(stats.get_current_connections_direct(), 1); + + stats.decrement_current_connections_direct(); + assert_eq!(stats.get_current_connections_direct(), 0); + + lease.disarm(); + drop(lease); + + assert_eq!(stats.get_current_connections_direct(), 0); +} + +#[test] +fn direct_connection_lease_balances_on_panic_unwind() { + let stats = Arc::new(Stats::new()); + let stats_for_panic = stats.clone(); + + let panic_result = panic::catch_unwind(AssertUnwindSafe(move || { + let _lease = stats_for_panic.acquire_direct_connection_lease(); + panic!("intentional panic to verify lease drop path"); + })); + + assert!(panic_result.is_err(), "panic must propagate from test closure"); + assert_eq!( + stats.get_current_connections_direct(), + 0, + "panic unwind must release direct route gauge" + ); +} + +#[tokio::test] +async fn direct_connection_lease_balances_on_task_abort() { + let stats = Arc::new(Stats::new()); + let stats_for_task = stats.clone(); + + let task = tokio::spawn(async move { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!(stats.get_current_connections_direct(), 1); + + task.abort(); + let joined = task.await; + assert!(joined.is_err(), "aborted task must return a join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_direct(), + 0, + "aborted task must release direct route gauge" + ); +} + +#[tokio::test] +async fn middle_connection_lease_balances_on_task_abort() { + let stats = Arc::new(Stats::new()); + let stats_for_task = stats.clone(); + + let task = tokio::spawn(async move { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!(stats.get_current_connections_me(), 1); + + task.abort(); + let joined = task.await; + assert!(joined.is_err(), "aborted task must return a join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_me(), + 0, + "aborted task must release middle route gauge" + ); +} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 603552d..36241af 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -6,6 +6,7 @@ pub mod beobachten; pub mod telemetry; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; +use std::sync::Arc; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use dashmap::DashMap; use parking_lot::Mutex; @@ -19,6 +20,45 @@ use tracing::debug; use crate::config::{MeTelemetryLevel, MeWriterPickMode}; use self::telemetry::TelemetryPolicy; +#[derive(Clone, Copy)] +enum RouteConnectionGauge { + Direct, + Middle, +} + +pub struct RouteConnectionLease { + stats: Arc, + gauge: RouteConnectionGauge, + active: bool, +} + +impl RouteConnectionLease { + fn new(stats: Arc, gauge: RouteConnectionGauge) -> Self { + Self { + stats, + gauge, + active: true, + } + } + + #[cfg(test)] + fn disarm(&mut self) { + self.active = false; + } +} + +impl Drop for RouteConnectionLease { + fn drop(&mut self) { + if !self.active { + return; + } + match self.gauge { + RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(), + RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(), + } + } +} + // ============= Stats ============= #[derive(Default)] @@ -285,6 +325,16 @@ impl Stats { pub fn decrement_current_connections_me(&self) { Self::decrement_atomic_saturating(&self.current_connections_me); } + + pub fn acquire_direct_connection_lease(self: &Arc) -> RouteConnectionLease { + self.increment_current_connections_direct(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct) + } + + pub fn acquire_me_connection_lease(self: &Arc) -> RouteConnectionLease { + self.increment_current_connections_me(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle) + } pub fn increment_handshake_timeouts(&self) { if self.telemetry_core_enabled() { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); @@ -1772,3 +1822,7 @@ mod tests { assert_eq!(checker.stats().total_entries, 500); } } + +#[cfg(test)] +#[path = "connection_lease_security_tests.rs"] +mod connection_lease_security_tests;