mirror of
https://github.com/telemt/telemt.git
synced 2026-06-24 11:51:10 +03:00
@@ -1558,6 +1558,11 @@ impl RunningClientHandler {
|
||||
{
|
||||
let user = success.user.clone();
|
||||
|
||||
if !shared.is_user_enabled(&user) {
|
||||
warn!(user = %user, "Disabled user rejected");
|
||||
return Err(ProxyError::UserDisabled { user });
|
||||
}
|
||||
|
||||
let user_limit_reservation = match Self::acquire_user_connection_reservation_static(
|
||||
&user,
|
||||
&config,
|
||||
@@ -1576,6 +1581,8 @@ impl RunningClientHandler {
|
||||
|
||||
let route_snapshot = route_runtime.snapshot();
|
||||
let session_id = rng.u64();
|
||||
let _user_session = shared.register_user_session(&user, session_id);
|
||||
let session_cancel = _user_session.token();
|
||||
let selected_me_pool = if config.general.use_middle_proxy
|
||||
&& matches!(route_snapshot.mode, RelayRouteMode::Middle)
|
||||
{
|
||||
@@ -1607,6 +1614,7 @@ impl RunningClientHandler {
|
||||
route_runtime.subscribe(),
|
||||
route_snapshot,
|
||||
session_id,
|
||||
session_cancel.clone(),
|
||||
shared.clone(),
|
||||
)
|
||||
.await
|
||||
@@ -1625,6 +1633,7 @@ impl RunningClientHandler {
|
||||
route_snapshot,
|
||||
session_id,
|
||||
local_addr,
|
||||
session_cancel.clone(),
|
||||
shared.clone(),
|
||||
)
|
||||
.await
|
||||
@@ -1644,6 +1653,7 @@ impl RunningClientHandler {
|
||||
route_snapshot,
|
||||
session_id,
|
||||
local_addr,
|
||||
session_cancel,
|
||||
shared.clone(),
|
||||
)
|
||||
.await
|
||||
|
||||
+40
-19
@@ -10,6 +10,7 @@ use std::time::Duration;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split};
|
||||
use tokio::sync::watch;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::config::ProxyConfig;
|
||||
@@ -258,6 +259,7 @@ where
|
||||
route_snapshot,
|
||||
session_id,
|
||||
SocketAddr::from(([0, 0, 0, 0], config.server.port)),
|
||||
CancellationToken::new(),
|
||||
ProxySharedState::new(),
|
||||
)
|
||||
.await
|
||||
@@ -276,6 +278,7 @@ pub(crate) async fn handle_via_direct_with_shared<R, W>(
|
||||
route_snapshot: RouteCutoverState,
|
||||
session_id: u64,
|
||||
local_addr: SocketAddr,
|
||||
session_cancel: CancellationToken,
|
||||
shared: Arc<ProxySharedState>,
|
||||
) -> Result<()>
|
||||
where
|
||||
@@ -302,14 +305,25 @@ where
|
||||
"Ignoring invalid scope hint and falling back to default upstream selection"
|
||||
);
|
||||
}
|
||||
let tg_stream = upstream_manager
|
||||
.connect(dc_addr, Some(success.dc_idx), scope_hint)
|
||||
.await?;
|
||||
let tg_stream = tokio::select! {
|
||||
result = upstream_manager.connect(dc_addr, Some(success.dc_idx), scope_hint) => result?,
|
||||
_ = session_cancel.cancelled() => {
|
||||
return Err(ProxyError::UserDisabled {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
|
||||
|
||||
let (tg_reader, tg_writer) =
|
||||
do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()).await?;
|
||||
let (tg_reader, tg_writer) = tokio::select! {
|
||||
result = do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()) => result?,
|
||||
_ = session_cancel.cancelled() => {
|
||||
return Err(ProxyError::UserDisabled {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
||||
|
||||
@@ -331,20 +345,22 @@ where
|
||||
} else {
|
||||
Duration::from_secs(1800)
|
||||
};
|
||||
let relay_result = crate::proxy::relay::relay_bidirectional_with_activity_timeout_and_lease(
|
||||
client_reader,
|
||||
client_writer,
|
||||
tg_reader,
|
||||
tg_writer,
|
||||
config.general.direct_relay_copy_buf_c2s_bytes,
|
||||
config.general.direct_relay_copy_buf_s2c_bytes,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
config.access.user_data_quota.get(user).copied(),
|
||||
buffer_pool,
|
||||
traffic_lease,
|
||||
relay_activity_timeout,
|
||||
);
|
||||
let relay_result =
|
||||
crate::proxy::relay::relay_bidirectional_with_activity_timeout_lease_and_cancel(
|
||||
client_reader,
|
||||
client_writer,
|
||||
tg_reader,
|
||||
tg_writer,
|
||||
config.general.direct_relay_copy_buf_c2s_bytes,
|
||||
config.general.direct_relay_copy_buf_s2c_bytes,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
config.access.user_data_quota.get(user).copied(),
|
||||
buffer_pool,
|
||||
traffic_lease,
|
||||
relay_activity_timeout,
|
||||
session_cancel.clone(),
|
||||
);
|
||||
tokio::pin!(relay_result);
|
||||
let relay_result = loop {
|
||||
if let Some(cutover) =
|
||||
@@ -371,6 +387,11 @@ where
|
||||
break relay_result.await;
|
||||
}
|
||||
}
|
||||
_ = session_cancel.cancelled() => {
|
||||
break Err(ProxyError::UserDisabled {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ pub(crate) async fn handle_via_middle_proxy<R, W>(
|
||||
mut route_rx: watch::Receiver<RouteCutoverState>,
|
||||
route_snapshot: RouteCutoverState,
|
||||
session_id: u64,
|
||||
session_cancel: CancellationToken,
|
||||
shared: Arc<ProxySharedState>,
|
||||
) -> Result<()>
|
||||
where
|
||||
@@ -20,6 +21,10 @@ where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let user = success.user.clone();
|
||||
if session_cancel.is_cancelled() {
|
||||
return Err(ProxyError::UserDisabled { user });
|
||||
}
|
||||
|
||||
let quota_limit = config.access.user_data_quota.get(&user).copied();
|
||||
let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
|
||||
let peer = success.peer;
|
||||
@@ -590,6 +595,25 @@ where
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
_ = session_cancel.cancelled() => {
|
||||
warn!(
|
||||
user = %user,
|
||||
conn_id,
|
||||
"Disabled user middle session cancelled"
|
||||
);
|
||||
let _ = enqueue_c2me_command_in(
|
||||
shared.as_ref(),
|
||||
&c2me_tx,
|
||||
C2MeCommand::Close,
|
||||
c2me_send_timeout,
|
||||
stats.as_ref(),
|
||||
)
|
||||
.await;
|
||||
main_result = Err(ProxyError::UserDisabled {
|
||||
user: user.clone(),
|
||||
});
|
||||
break;
|
||||
}
|
||||
changed = route_rx.changed(), if route_watch_open => {
|
||||
if changed.is_err() {
|
||||
route_watch_open = false;
|
||||
|
||||
+119
-8
@@ -55,11 +55,13 @@ use crate::error::{ProxyError, Result};
|
||||
use crate::proxy::traffic_limiter::TrafficLease;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use std::future::pending;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, copy_bidirectional_with_sizes};
|
||||
use tokio::time::Instant;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
// ============= Constants =============
|
||||
@@ -191,6 +193,84 @@ pub async fn relay_bidirectional_with_activity_timeout_and_lease<CR, CW, SR, SW>
|
||||
traffic_lease: Option<Arc<TrafficLease>>,
|
||||
activity_timeout: Duration,
|
||||
) -> Result<()>
|
||||
where
|
||||
CR: AsyncRead + Unpin + Send + 'static,
|
||||
CW: AsyncWrite + Unpin + Send + 'static,
|
||||
SR: AsyncRead + Unpin + Send + 'static,
|
||||
SW: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
relay_bidirectional_with_activity_timeout_lease_cancel_inner(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
c2s_buf_size,
|
||||
s2c_buf_size,
|
||||
user,
|
||||
stats,
|
||||
quota_limit,
|
||||
_buffer_pool,
|
||||
traffic_lease,
|
||||
activity_timeout,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn relay_bidirectional_with_activity_timeout_lease_and_cancel<CR, CW, SR, SW>(
|
||||
client_reader: CR,
|
||||
client_writer: CW,
|
||||
server_reader: SR,
|
||||
server_writer: SW,
|
||||
c2s_buf_size: usize,
|
||||
s2c_buf_size: usize,
|
||||
user: &str,
|
||||
stats: Arc<Stats>,
|
||||
quota_limit: Option<u64>,
|
||||
_buffer_pool: Arc<BufferPool>,
|
||||
traffic_lease: Option<Arc<TrafficLease>>,
|
||||
activity_timeout: Duration,
|
||||
session_cancel: CancellationToken,
|
||||
) -> Result<()>
|
||||
where
|
||||
CR: AsyncRead + Unpin + Send + 'static,
|
||||
CW: AsyncWrite + Unpin + Send + 'static,
|
||||
SR: AsyncRead + Unpin + Send + 'static,
|
||||
SW: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
relay_bidirectional_with_activity_timeout_lease_cancel_inner(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
c2s_buf_size,
|
||||
s2c_buf_size,
|
||||
user,
|
||||
stats,
|
||||
quota_limit,
|
||||
_buffer_pool,
|
||||
traffic_lease,
|
||||
activity_timeout,
|
||||
Some(session_cancel),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn relay_bidirectional_with_activity_timeout_lease_cancel_inner<CR, CW, SR, SW>(
|
||||
client_reader: CR,
|
||||
client_writer: CW,
|
||||
server_reader: SR,
|
||||
server_writer: SW,
|
||||
c2s_buf_size: usize,
|
||||
s2c_buf_size: usize,
|
||||
user: &str,
|
||||
stats: Arc<Stats>,
|
||||
quota_limit: Option<u64>,
|
||||
_buffer_pool: Arc<BufferPool>,
|
||||
traffic_lease: Option<Arc<TrafficLease>>,
|
||||
activity_timeout: Duration,
|
||||
session_cancel: Option<CancellationToken>,
|
||||
) -> Result<()>
|
||||
where
|
||||
CR: AsyncRead + Unpin + Send + 'static,
|
||||
CW: AsyncWrite + Unpin + Send + 'static,
|
||||
@@ -287,14 +367,29 @@ where
|
||||
//
|
||||
// When the watchdog fires, select! drops the copy future,
|
||||
// releasing the &mut borrows on client and server.
|
||||
let copy_result = tokio::select! {
|
||||
enum RelayOutcome {
|
||||
Copy(std::io::Result<(u64, u64)>),
|
||||
ActivityTimeout,
|
||||
UserDisabled,
|
||||
}
|
||||
|
||||
let cancel_wait = async move {
|
||||
match session_cancel {
|
||||
Some(token) => token.cancelled().await,
|
||||
None => pending::<()>().await,
|
||||
}
|
||||
};
|
||||
tokio::pin!(cancel_wait);
|
||||
|
||||
let relay_outcome = tokio::select! {
|
||||
result = copy_bidirectional_with_sizes(
|
||||
&mut client,
|
||||
&mut server,
|
||||
c2s_buf_size.max(1),
|
||||
s2c_buf_size.max(1),
|
||||
) => Some(result),
|
||||
_ = watchdog => None, // Activity timeout — cancel relay
|
||||
) => RelayOutcome::Copy(result),
|
||||
_ = watchdog => RelayOutcome::ActivityTimeout,
|
||||
_ = &mut cancel_wait => RelayOutcome::UserDisabled,
|
||||
};
|
||||
|
||||
// ── Clean shutdown ──────────────────────────────────────────────
|
||||
@@ -308,8 +403,8 @@ where
|
||||
let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed);
|
||||
let duration = epoch.elapsed();
|
||||
|
||||
match copy_result {
|
||||
Some(Ok((c2s, s2c))) => {
|
||||
match relay_outcome {
|
||||
RelayOutcome::Copy(Ok((c2s, s2c))) => {
|
||||
// Normal completion — one side closed the connection
|
||||
debug!(
|
||||
user = %user_owned,
|
||||
@@ -322,7 +417,7 @@ where
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
Some(Err(e)) if is_quota_io_error(&e) => {
|
||||
RelayOutcome::Copy(Err(e)) if is_quota_io_error(&e) => {
|
||||
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||
warn!(
|
||||
@@ -338,7 +433,7 @@ where
|
||||
user: user_owned.clone(),
|
||||
})
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
RelayOutcome::Copy(Err(e)) => {
|
||||
// I/O error in one of the directions
|
||||
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||
@@ -354,7 +449,7 @@ where
|
||||
);
|
||||
Err(e.into())
|
||||
}
|
||||
None => {
|
||||
RelayOutcome::ActivityTimeout => {
|
||||
// Activity timeout (watchdog fired)
|
||||
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||
@@ -369,6 +464,22 @@ where
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
RelayOutcome::UserDisabled => {
|
||||
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||
debug!(
|
||||
user = %user_owned,
|
||||
c2s_bytes = c2s,
|
||||
s2c_bytes = s2c,
|
||||
c2s_msgs = c2s_ops,
|
||||
s2c_msgs = s2c_ops,
|
||||
duration_secs = duration.as_secs(),
|
||||
"Relay finished (user disabled)"
|
||||
);
|
||||
Err(ProxyError::UserDisabled {
|
||||
user: user_owned.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+143
-1
@@ -1,5 +1,5 @@
|
||||
use std::collections::HashSet;
|
||||
use std::collections::hash_map::RandomState;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
@@ -7,6 +7,7 @@ use std::time::Instant;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
|
||||
use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateRegistry};
|
||||
@@ -67,10 +68,35 @@ pub(crate) struct ProxySharedState {
|
||||
pub(crate) handshake: HandshakeSharedState,
|
||||
pub(crate) middle_relay: MiddleRelaySharedState,
|
||||
pub(crate) traffic_limiter: Arc<TrafficLimiter>,
|
||||
disabled_users: DashMap<String, ()>,
|
||||
active_user_sessions: DashMap<(String, u64), CancellationToken>,
|
||||
pub(crate) conntrack_pressure_active: AtomicBool,
|
||||
pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>,
|
||||
}
|
||||
|
||||
#[must_use = "registered user sessions must be kept alive until relay completion"]
|
||||
pub(crate) struct UserSessionRegistration {
|
||||
token: CancellationToken,
|
||||
_guard: UserSessionGuard,
|
||||
}
|
||||
|
||||
impl UserSessionRegistration {
|
||||
pub(crate) fn token(&self) -> CancellationToken {
|
||||
self.token.clone()
|
||||
}
|
||||
}
|
||||
|
||||
struct UserSessionGuard {
|
||||
shared: Arc<ProxySharedState>,
|
||||
key: (String, u64),
|
||||
}
|
||||
|
||||
impl Drop for UserSessionGuard {
|
||||
fn drop(&mut self) {
|
||||
self.shared.active_user_sessions.remove(&self.key);
|
||||
}
|
||||
}
|
||||
|
||||
impl ProxySharedState {
|
||||
pub(crate) fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
@@ -101,11 +127,82 @@ impl ProxySharedState {
|
||||
relay_idle_mark_seq: AtomicU64::new(0),
|
||||
},
|
||||
traffic_limiter: TrafficLimiter::new(),
|
||||
disabled_users: DashMap::new(),
|
||||
active_user_sessions: DashMap::new(),
|
||||
conntrack_pressure_active: AtomicBool::new(false),
|
||||
conntrack_close_tx: Mutex::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn is_user_enabled(&self, user: &str) -> bool {
|
||||
!self.disabled_users.contains_key(user)
|
||||
}
|
||||
|
||||
pub(crate) fn set_user_enabled(&self, user: &str, enabled: bool) -> bool {
|
||||
if enabled {
|
||||
self.disabled_users.remove(user);
|
||||
false
|
||||
} else {
|
||||
self.disabled_users.insert(user.to_string(), ()).is_none()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply_user_enabled_config(
|
||||
&self,
|
||||
user_enabled: &HashMap<String, bool>,
|
||||
) -> Vec<String> {
|
||||
let desired_disabled = user_enabled
|
||||
.iter()
|
||||
.filter_map(|(user, enabled)| (!*enabled).then_some(user.clone()))
|
||||
.collect::<HashSet<_>>();
|
||||
let current_disabled = self
|
||||
.disabled_users
|
||||
.iter()
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
for user in current_disabled.difference(&desired_disabled) {
|
||||
self.disabled_users.remove(user);
|
||||
}
|
||||
let newly_disabled = desired_disabled
|
||||
.difference(¤t_disabled)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
for user in desired_disabled {
|
||||
self.disabled_users.insert(user, ());
|
||||
}
|
||||
newly_disabled
|
||||
}
|
||||
|
||||
pub(crate) fn register_user_session(
|
||||
self: &Arc<Self>,
|
||||
user: &str,
|
||||
session_id: u64,
|
||||
) -> UserSessionRegistration {
|
||||
let token = CancellationToken::new();
|
||||
let key = (user.to_string(), session_id);
|
||||
self.active_user_sessions.insert(key.clone(), token.clone());
|
||||
UserSessionRegistration {
|
||||
token,
|
||||
_guard: UserSessionGuard {
|
||||
shared: Arc::clone(self),
|
||||
key,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn cancel_user_sessions(&self, user: &str) -> usize {
|
||||
let tokens = self
|
||||
.active_user_sessions
|
||||
.iter()
|
||||
.filter_map(|entry| (entry.key().0 == user).then(|| entry.value().clone()))
|
||||
.collect::<Vec<_>>();
|
||||
for token in &tokens {
|
||||
token.cancel();
|
||||
}
|
||||
tokens.len()
|
||||
}
|
||||
|
||||
pub(crate) fn set_conntrack_close_sender(&self, tx: mpsc::Sender<ConntrackCloseEvent>) {
|
||||
match self.conntrack_close_tx.lock() {
|
||||
Ok(mut guard) => {
|
||||
@@ -166,3 +263,48 @@ impl ProxySharedState {
|
||||
self.conntrack_pressure_active.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn user_enabled_config_sync_tracks_disabled_overrides() {
|
||||
let shared = ProxySharedState::new();
|
||||
assert!(shared.is_user_enabled("alice"));
|
||||
|
||||
let mut user_enabled = HashMap::new();
|
||||
user_enabled.insert("alice".to_string(), false);
|
||||
user_enabled.insert("bob".to_string(), true);
|
||||
|
||||
let mut newly_disabled = shared.apply_user_enabled_config(&user_enabled);
|
||||
newly_disabled.sort();
|
||||
assert_eq!(newly_disabled, vec!["alice".to_string()]);
|
||||
assert!(!shared.is_user_enabled("alice"));
|
||||
assert!(shared.is_user_enabled("bob"));
|
||||
|
||||
assert!(shared.apply_user_enabled_config(&user_enabled).is_empty());
|
||||
|
||||
user_enabled.clear();
|
||||
assert!(shared.apply_user_enabled_config(&user_enabled).is_empty());
|
||||
assert!(shared.is_user_enabled("alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cancel_user_sessions_cancels_only_registered_matching_user() {
|
||||
let shared = ProxySharedState::new();
|
||||
let alice_1 = shared.register_user_session("alice", 1);
|
||||
let alice_2 = shared.register_user_session("alice", 2);
|
||||
let bob = shared.register_user_session("bob", 1);
|
||||
let alice_1_token = alice_1.token();
|
||||
let alice_2_token = alice_2.token();
|
||||
let bob_token = bob.token();
|
||||
|
||||
drop(alice_1);
|
||||
|
||||
assert_eq!(shared.cancel_user_sessions("alice"), 1);
|
||||
assert!(!alice_1_token.is_cancelled());
|
||||
assert!(alice_2_token.is_cancelled());
|
||||
assert!(!bob_token.is_cancelled());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user