mirror of https://github.com/telemt/telemt.git
bump version to 3.3.20 and implement connection lease management for direct and middle relays
This commit is contained in:
parent
d9aa6f4956
commit
1357f3cc4c
|
|
@ -2093,7 +2093,7 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "telemt"
|
name = "telemt"
|
||||||
version = "3.3.19"
|
version = "3.3.20"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aes",
|
"aes",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
|
|
||||||
|
|
@ -105,7 +105,7 @@ where
|
||||||
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
||||||
|
|
||||||
stats.increment_user_connects(user);
|
stats.increment_user_connects(user);
|
||||||
stats.increment_current_connections_direct();
|
let _direct_connection_lease = stats.acquire_direct_connection_lease();
|
||||||
|
|
||||||
let relay_result = relay_bidirectional(
|
let relay_result = relay_bidirectional(
|
||||||
client_reader,
|
client_reader,
|
||||||
|
|
@ -148,8 +148,6 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
stats.decrement_current_connections_direct();
|
|
||||||
|
|
||||||
match &relay_result {
|
match &relay_result {
|
||||||
Ok(()) => debug!(user = %user, "Direct relay completed"),
|
Ok(()) => debug!(user = %user, "Direct relay completed"),
|
||||||
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
|
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,33 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::config::{UpstreamConfig, UpstreamType};
|
||||||
|
use crate::crypto::{AesCtr, SecureRandom};
|
||||||
|
use crate::protocol::constants::ProtoTag;
|
||||||
|
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||||
|
use crate::stats::Stats;
|
||||||
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||||
|
use crate::transport::UpstreamManager;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::io::duplex;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
|
||||||
|
where
|
||||||
|
R: tokio::io::AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||||
|
where
|
||||||
|
W: tokio::io::AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn unknown_dc_log_is_deduplicated_per_dc_idx() {
|
fn unknown_dc_log_is_deduplicated_per_dc_idx() {
|
||||||
|
|
@ -49,3 +78,103 @@ fn fallback_dc_never_panics_with_single_dc_list() {
|
||||||
let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT);
|
let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT);
|
||||||
assert_eq!(addr, expected);
|
assert_eq!(addr, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn direct_relay_abort_midflight_releases_route_gauge() {
|
||||||
|
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let tg_addr = tg_listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let tg_accept_task = tokio::spawn(async move {
|
||||||
|
let (stream, _) = tg_listener.accept().await.unwrap();
|
||||||
|
let _hold_stream = stream;
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config
|
||||||
|
.dc_overrides
|
||||||
|
.insert("2".to_string(), vec![tg_addr.to_string()]);
|
||||||
|
let config = Arc::new(config);
|
||||||
|
|
||||||
|
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||||
|
vec![UpstreamConfig {
|
||||||
|
upstream_type: UpstreamType::Direct {
|
||||||
|
interface: None,
|
||||||
|
bind_addresses: None,
|
||||||
|
},
|
||||||
|
weight: 1,
|
||||||
|
enabled: true,
|
||||||
|
scopes: String::new(),
|
||||||
|
selected_scope: String::new(),
|
||||||
|
}],
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
false,
|
||||||
|
stats.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||||
|
let route_snapshot = route_runtime.snapshot();
|
||||||
|
|
||||||
|
let (server_side, client_side) = duplex(64 * 1024);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||||
|
let client_reader = make_crypto_reader(server_reader);
|
||||||
|
let client_writer = make_crypto_writer(server_writer);
|
||||||
|
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: "abort-direct-user".to_string(),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Intermediate,
|
||||||
|
dec_key: [0u8; 32],
|
||||||
|
dec_iv: 0,
|
||||||
|
enc_key: [0u8; 32],
|
||||||
|
enc_iv: 0,
|
||||||
|
peer: "127.0.0.1:50000".parse().unwrap(),
|
||||||
|
is_tls: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(handle_via_direct(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
success,
|
||||||
|
upstream_manager,
|
||||||
|
stats.clone(),
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
rng,
|
||||||
|
route_runtime.subscribe(),
|
||||||
|
route_snapshot,
|
||||||
|
0xabad1dea,
|
||||||
|
));
|
||||||
|
|
||||||
|
let started = tokio::time::timeout(Duration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_current_connections_direct() == 1 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
assert!(started.is_ok(), "direct relay must increment route gauge before abort");
|
||||||
|
|
||||||
|
relay_task.abort();
|
||||||
|
let joined = relay_task.await;
|
||||||
|
assert!(joined.is_err(), "aborted direct relay task must return join error");
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_direct(),
|
||||||
|
0,
|
||||||
|
"route gauge must be released when direct relay task is aborted mid-flight"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(client_side);
|
||||||
|
tg_accept_task.abort();
|
||||||
|
let _ = tg_accept_task.await;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -306,7 +306,7 @@ where
|
||||||
};
|
};
|
||||||
|
|
||||||
stats.increment_user_connects(&user);
|
stats.increment_user_connects(&user);
|
||||||
stats.increment_current_connections_me();
|
let _me_connection_lease = stats.acquire_me_connection_lease();
|
||||||
|
|
||||||
if let Some(cutover) = affected_cutover_state(
|
if let Some(cutover) = affected_cutover_state(
|
||||||
&route_rx,
|
&route_rx,
|
||||||
|
|
@ -324,7 +324,6 @@ where
|
||||||
tokio::time::sleep(delay).await;
|
tokio::time::sleep(delay).await;
|
||||||
let _ = me_pool.send_close(conn_id).await;
|
let _ = me_pool.send_close(conn_id).await;
|
||||||
me_pool.registry().unregister(conn_id).await;
|
me_pool.registry().unregister(conn_id).await;
|
||||||
stats.decrement_current_connections_me();
|
|
||||||
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
|
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -672,7 +671,6 @@ where
|
||||||
"ME relay cleanup"
|
"ME relay cleanup"
|
||||||
);
|
);
|
||||||
me_pool.registry().unregister(conn_id).await;
|
me_pool.registry().unregister(conn_id).await;
|
||||||
stats.decrement_current_connections_me();
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,13 @@ use super::*;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use crate::crypto::AesCtr;
|
use crate::crypto::AesCtr;
|
||||||
use crate::crypto::SecureRandom;
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
|
||||||
|
use crate::network::probe::NetworkDecision;
|
||||||
|
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||||
|
use crate::transport::middle_proxy::MePool;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::AtomicU64;
|
use std::sync::atomic::AtomicU64;
|
||||||
|
|
@ -229,18 +234,108 @@ fn make_forensics_state() -> RelayForensicsState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader<tokio::io::DuplexStream> {
|
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
|
||||||
|
where
|
||||||
|
R: tokio::io::AsyncRead + Unpin,
|
||||||
|
{
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = 0u128;
|
let iv = 0u128;
|
||||||
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_crypto_writer(writer: tokio::io::DuplexStream) -> CryptoWriter<tokio::io::DuplexStream> {
|
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||||
|
where
|
||||||
|
W: tokio::io::AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = 0u128;
|
let iv = 0u128;
|
||||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn make_me_pool_for_abort_test(stats: Arc<Stats>) -> Arc<MePool> {
|
||||||
|
let general = GeneralConfig::default();
|
||||||
|
|
||||||
|
MePool::new(
|
||||||
|
None,
|
||||||
|
vec![1u8; 32],
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
None,
|
||||||
|
Vec::new(),
|
||||||
|
1,
|
||||||
|
None,
|
||||||
|
12,
|
||||||
|
1200,
|
||||||
|
HashMap::new(),
|
||||||
|
HashMap::new(),
|
||||||
|
None,
|
||||||
|
NetworkDecision::default(),
|
||||||
|
None,
|
||||||
|
Arc::new(SecureRandom::new()),
|
||||||
|
stats,
|
||||||
|
general.me_keepalive_enabled,
|
||||||
|
general.me_keepalive_interval_secs,
|
||||||
|
general.me_keepalive_jitter_secs,
|
||||||
|
general.me_keepalive_payload_random,
|
||||||
|
general.rpc_proxy_req_every,
|
||||||
|
general.me_warmup_stagger_enabled,
|
||||||
|
general.me_warmup_step_delay_ms,
|
||||||
|
general.me_warmup_step_jitter_ms,
|
||||||
|
general.me_reconnect_max_concurrent_per_dc,
|
||||||
|
general.me_reconnect_backoff_base_ms,
|
||||||
|
general.me_reconnect_backoff_cap_ms,
|
||||||
|
general.me_reconnect_fast_retry_count,
|
||||||
|
general.me_single_endpoint_shadow_writers,
|
||||||
|
general.me_single_endpoint_outage_mode_enabled,
|
||||||
|
general.me_single_endpoint_outage_disable_quarantine,
|
||||||
|
general.me_single_endpoint_outage_backoff_min_ms,
|
||||||
|
general.me_single_endpoint_outage_backoff_max_ms,
|
||||||
|
general.me_single_endpoint_shadow_rotate_every_secs,
|
||||||
|
general.me_floor_mode,
|
||||||
|
general.me_adaptive_floor_idle_secs,
|
||||||
|
general.me_adaptive_floor_min_writers_single_endpoint,
|
||||||
|
general.me_adaptive_floor_min_writers_multi_endpoint,
|
||||||
|
general.me_adaptive_floor_recover_grace_secs,
|
||||||
|
general.me_adaptive_floor_writers_per_core_total,
|
||||||
|
general.me_adaptive_floor_cpu_cores_override,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_single_per_core,
|
||||||
|
general.me_adaptive_floor_max_extra_writers_multi_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_per_core,
|
||||||
|
general.me_adaptive_floor_max_active_writers_global,
|
||||||
|
general.me_adaptive_floor_max_warm_writers_global,
|
||||||
|
general.hardswap,
|
||||||
|
general.me_pool_drain_ttl_secs,
|
||||||
|
general.me_pool_drain_threshold,
|
||||||
|
general.effective_me_pool_force_close_secs(),
|
||||||
|
general.me_pool_min_fresh_ratio,
|
||||||
|
general.me_hardswap_warmup_delay_min_ms,
|
||||||
|
general.me_hardswap_warmup_delay_max_ms,
|
||||||
|
general.me_hardswap_warmup_extra_passes,
|
||||||
|
general.me_hardswap_warmup_pass_backoff_base_ms,
|
||||||
|
general.me_bind_stale_mode,
|
||||||
|
general.me_bind_stale_ttl_secs,
|
||||||
|
general.me_secret_atomic_snapshot,
|
||||||
|
general.me_deterministic_writer_sort,
|
||||||
|
MeWriterPickMode::default(),
|
||||||
|
general.me_writer_pick_sample_size,
|
||||||
|
MeSocksKdfPolicy::default(),
|
||||||
|
general.me_writer_cmd_channel_capacity,
|
||||||
|
general.me_route_channel_capacity,
|
||||||
|
general.me_route_backpressure_base_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_timeout_ms,
|
||||||
|
general.me_route_backpressure_high_watermark_pct,
|
||||||
|
general.me_reader_route_data_wait_ms,
|
||||||
|
general.me_health_interval_ms_unhealthy,
|
||||||
|
general.me_health_interval_ms_healthy,
|
||||||
|
general.me_warn_rate_limit_ms,
|
||||||
|
MeRouteNoWriterMode::default(),
|
||||||
|
general.me_route_no_writer_wait_ms,
|
||||||
|
general.me_route_inline_recovery_attempts,
|
||||||
|
general.me_route_inline_recovery_wait_ms,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = 0u128;
|
let iv = 0u128;
|
||||||
|
|
@ -779,3 +874,71 @@ async fn process_me_writer_response_data_updates_byte_accounting() {
|
||||||
"ME->C byte accounting must increase by emitted payload size"
|
"ME->C byte accounting must increase by emitted payload size"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn middle_relay_abort_midflight_releases_route_gauge() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let me_pool = make_me_pool_for_abort_test(stats.clone()).await;
|
||||||
|
let config = Arc::new(ProxyConfig::default());
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
|
||||||
|
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle));
|
||||||
|
let route_snapshot = route_runtime.snapshot();
|
||||||
|
|
||||||
|
let (server_side, client_side) = duplex(64 * 1024);
|
||||||
|
let (server_reader, server_writer) = tokio::io::split(server_side);
|
||||||
|
let crypto_reader = make_crypto_reader(server_reader);
|
||||||
|
let crypto_writer = make_crypto_writer(server_writer);
|
||||||
|
|
||||||
|
let success = HandshakeSuccess {
|
||||||
|
user: "abort-middle-user".to_string(),
|
||||||
|
dc_idx: 2,
|
||||||
|
proto_tag: ProtoTag::Intermediate,
|
||||||
|
dec_key: [0u8; 32],
|
||||||
|
dec_iv: 0,
|
||||||
|
enc_key: [0u8; 32],
|
||||||
|
enc_iv: 0,
|
||||||
|
peer: "127.0.0.1:50001".parse().unwrap(),
|
||||||
|
is_tls: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(handle_via_middle_proxy(
|
||||||
|
crypto_reader,
|
||||||
|
crypto_writer,
|
||||||
|
success,
|
||||||
|
me_pool,
|
||||||
|
stats.clone(),
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
"127.0.0.1:443".parse().unwrap(),
|
||||||
|
rng,
|
||||||
|
route_runtime.subscribe(),
|
||||||
|
route_snapshot,
|
||||||
|
0xdecafbad,
|
||||||
|
));
|
||||||
|
|
||||||
|
let started = tokio::time::timeout(TokioDuration::from_secs(2), async {
|
||||||
|
loop {
|
||||||
|
if stats.get_current_connections_me() == 1 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
assert!(started.is_ok(), "middle relay must increment route gauge before abort");
|
||||||
|
|
||||||
|
relay_task.abort();
|
||||||
|
let joined = relay_task.await;
|
||||||
|
assert!(joined.is_err(), "aborted middle relay task must return join error");
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(20)).await;
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"route gauge must be released when middle relay task is aborted mid-flight"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(client_side);
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
use super::*;
|
||||||
|
use std::panic::{self, AssertUnwindSafe};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn direct_connection_lease_balances_on_drop() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
assert_eq!(stats.get_current_connections_direct(), 0);
|
||||||
|
|
||||||
|
{
|
||||||
|
let _lease = stats.acquire_direct_connection_lease();
|
||||||
|
assert_eq!(stats.get_current_connections_direct(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(stats.get_current_connections_direct(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn middle_connection_lease_balances_on_drop() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
assert_eq!(stats.get_current_connections_me(), 0);
|
||||||
|
|
||||||
|
{
|
||||||
|
let _lease = stats.acquire_me_connection_lease();
|
||||||
|
assert_eq!(stats.get_current_connections_me(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(stats.get_current_connections_me(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn connection_lease_disarm_prevents_double_release() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
|
||||||
|
let mut lease = stats.acquire_direct_connection_lease();
|
||||||
|
assert_eq!(stats.get_current_connections_direct(), 1);
|
||||||
|
|
||||||
|
stats.decrement_current_connections_direct();
|
||||||
|
assert_eq!(stats.get_current_connections_direct(), 0);
|
||||||
|
|
||||||
|
lease.disarm();
|
||||||
|
drop(lease);
|
||||||
|
|
||||||
|
assert_eq!(stats.get_current_connections_direct(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn direct_connection_lease_balances_on_panic_unwind() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let stats_for_panic = stats.clone();
|
||||||
|
|
||||||
|
let panic_result = panic::catch_unwind(AssertUnwindSafe(move || {
|
||||||
|
let _lease = stats_for_panic.acquire_direct_connection_lease();
|
||||||
|
panic!("intentional panic to verify lease drop path");
|
||||||
|
}));
|
||||||
|
|
||||||
|
assert!(panic_result.is_err(), "panic must propagate from test closure");
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_direct(),
|
||||||
|
0,
|
||||||
|
"panic unwind must release direct route gauge"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn direct_connection_lease_balances_on_task_abort() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let stats_for_task = stats.clone();
|
||||||
|
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
let _lease = stats_for_task.acquire_direct_connection_lease();
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||||
|
assert_eq!(stats.get_current_connections_direct(), 1);
|
||||||
|
|
||||||
|
task.abort();
|
||||||
|
let joined = task.await;
|
||||||
|
assert!(joined.is_err(), "aborted task must return a join error");
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_direct(),
|
||||||
|
0,
|
||||||
|
"aborted task must release direct route gauge"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn middle_connection_lease_balances_on_task_abort() {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
let stats_for_task = stats.clone();
|
||||||
|
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
let _lease = stats_for_task.acquire_me_connection_lease();
|
||||||
|
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||||
|
assert_eq!(stats.get_current_connections_me(), 1);
|
||||||
|
|
||||||
|
task.abort();
|
||||||
|
let joined = task.await;
|
||||||
|
assert!(joined.is_err(), "aborted task must return a join error");
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_current_connections_me(),
|
||||||
|
0,
|
||||||
|
"aborted task must release middle route gauge"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
@ -6,6 +6,7 @@ pub mod beobachten;
|
||||||
pub mod telemetry;
|
pub mod telemetry;
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
|
@ -19,6 +20,45 @@ use tracing::debug;
|
||||||
use crate::config::{MeTelemetryLevel, MeWriterPickMode};
|
use crate::config::{MeTelemetryLevel, MeWriterPickMode};
|
||||||
use self::telemetry::TelemetryPolicy;
|
use self::telemetry::TelemetryPolicy;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
enum RouteConnectionGauge {
|
||||||
|
Direct,
|
||||||
|
Middle,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RouteConnectionLease {
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
gauge: RouteConnectionGauge,
|
||||||
|
active: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RouteConnectionLease {
|
||||||
|
fn new(stats: Arc<Stats>, gauge: RouteConnectionGauge) -> Self {
|
||||||
|
Self {
|
||||||
|
stats,
|
||||||
|
gauge,
|
||||||
|
active: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn disarm(&mut self) {
|
||||||
|
self.active = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for RouteConnectionLease {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if !self.active {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
match self.gauge {
|
||||||
|
RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(),
|
||||||
|
RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ============= Stats =============
|
// ============= Stats =============
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
|
@ -285,6 +325,16 @@ impl Stats {
|
||||||
pub fn decrement_current_connections_me(&self) {
|
pub fn decrement_current_connections_me(&self) {
|
||||||
Self::decrement_atomic_saturating(&self.current_connections_me);
|
Self::decrement_atomic_saturating(&self.current_connections_me);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn acquire_direct_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
|
||||||
|
self.increment_current_connections_direct();
|
||||||
|
RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn acquire_me_connection_lease(self: &Arc<Self>) -> RouteConnectionLease {
|
||||||
|
self.increment_current_connections_me();
|
||||||
|
RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle)
|
||||||
|
}
|
||||||
pub fn increment_handshake_timeouts(&self) {
|
pub fn increment_handshake_timeouts(&self) {
|
||||||
if self.telemetry_core_enabled() {
|
if self.telemetry_core_enabled() {
|
||||||
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
@ -1772,3 +1822,7 @@ mod tests {
|
||||||
assert_eq!(checker.stats().total_entries, 500);
|
assert_eq!(checker.stats().total_entries, 500);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[path = "connection_lease_security_tests.rs"]
|
||||||
|
mod connection_lease_security_tests;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue