From 22649809267ef643dfad0f02f0903cec227cf267 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sun, 31 May 2026 11:17:18 +0300 Subject: [PATCH] User Disabler in API by #814 + Consistent Listeners in API by #800 --- src/api/config_store.rs | 12 +++ src/api/mod.rs | 153 +++++++++++++++++++++++++++++- src/api/model.rs | 4 + src/api/users.rs | 111 ++++++++++++++++++++++ src/config/hot_reload.rs | 13 +++ src/config/load.rs | 1 + src/config/types.rs | 8 ++ src/error.rs | 3 + src/maestro/mod.rs | 13 ++- src/maestro/runtime_tasks.rs | 23 ++++- src/metrics.rs | 6 +- src/proxy/client.rs | 10 ++ src/proxy/direct_relay.rs | 59 ++++++++---- src/proxy/middle_relay/session.rs | 24 +++++ src/proxy/relay.rs | 127 +++++++++++++++++++++++-- src/proxy/shared_state.rs | 144 +++++++++++++++++++++++++++- 16 files changed, 671 insertions(+), 40 deletions(-) diff --git a/src/api/config_store.rs b/src/api/config_store.rs index 1416667..6be4040 100644 --- a/src/api/config_store.rs +++ b/src/api/config_store.rs @@ -14,6 +14,7 @@ use super::model::ApiFailure; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub(super) enum AccessSection { Users, + UserEnabled, UserAdTags, UserMaxTcpConns, UserExpirations, @@ -26,6 +27,7 @@ impl AccessSection { fn table_name(self) -> &'static str { match self { Self::Users => "access.users", + Self::UserEnabled => "access.user_enabled", Self::UserAdTags => "access.user_ad_tags", Self::UserMaxTcpConns => "access.user_max_tcp_conns", Self::UserExpirations => "access.user_expirations", @@ -135,6 +137,15 @@ fn render_access_section(cfg: &ProxyConfig, section: AccessSection) -> Result { + let rows: BTreeMap = cfg + .access + .user_enabled + .iter() + .map(|(key, value)| (key.clone(), *value)) + .collect(); + serialize_table_body(&rows)? + } AccessSection::UserAdTags => { let rows: BTreeMap = cfg .access @@ -204,6 +215,7 @@ fn render_access_section(cfg: &ProxyConfig, section: AccessSection) -> Result bool { match section { AccessSection::Users => cfg.access.users.is_empty(), + AccessSection::UserEnabled => cfg.access.user_enabled.is_empty(), AccessSection::UserAdTags => cfg.access.user_ad_tags.is_empty(), AccessSection::UserMaxTcpConns => cfg.access.user_max_tcp_conns.is_empty(), AccessSection::UserExpirations => cfg.access.user_expirations.is_empty(), diff --git a/src/api/mod.rs b/src/api/mod.rs index 2e2ef6f..7a03346 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -22,6 +22,7 @@ use tracing::{debug, info, warn}; use crate::config::{ApiGrayAction, ProxyConfig}; use crate::ip_tracker::UserIpTracker; use crate::proxy::route_mode::RouteRuntimeController; +use crate::proxy::shared_state::ProxySharedState; use crate::startup::StartupTracker; use crate::stats::Stats; use crate::transport::UpstreamManager; @@ -51,6 +52,7 @@ use model::{ PatchUserRequest, ResetUserQuotaResponse, RotateSecretRequest, SummaryData, UserActiveIps, is_valid_username, }; +use patch::Patch; use runtime_edge::{ EdgeConnectionsCacheEntry, build_runtime_connections_summary_data, build_runtime_events_recent_data, @@ -71,7 +73,8 @@ use runtime_zero::{ build_system_info_data, }; use users::{ - build_user_quota_list, create_user, delete_user, patch_user, rotate_secret, users_from_config, + build_user_quota_list, create_user, delete_user, patch_user, rotate_secret, set_user_enabled, + users_from_config, }; const API_MAX_CONTROL_CONNECTIONS: usize = 1024; @@ -107,6 +110,7 @@ pub(super) struct ApiShared { pub(super) runtime_state: Arc, pub(super) startup_tracker: Arc, pub(super) route_runtime: Arc, + pub(super) proxy_shared: Arc, } impl ApiShared { @@ -171,6 +175,8 @@ fn allowed_methods_for_path(path: &str) -> Option<&'static str> { "/v1/users" => Some(ALLOW_GET_POST), _ if user_action_route_matches(path, "/reset-quota") => Some(ALLOW_POST), _ if user_action_route_matches(path, "/rotate-secret") => Some(ALLOW_POST), + _ if user_action_route_matches(path, "/enable") => Some(ALLOW_POST), + _ if user_action_route_matches(path, "/disable") => Some(ALLOW_POST), _ if path .strip_prefix("/v1/users/") .map(|user| !user.is_empty() && !user.contains('/')) @@ -188,6 +194,7 @@ pub async fn serve( ip_tracker: Arc, me_pool: Arc>>>, route_runtime: Arc, + proxy_shared: Arc, upstream_manager: Arc, config_rx: watch::Receiver>, admission_rx: watch::Receiver, @@ -237,6 +244,7 @@ pub async fn serve( runtime_state: runtime_state.clone(), startup_tracker, route_runtime, + proxy_shared, }); spawn_runtime_watchers( @@ -582,6 +590,7 @@ async fn handle( } let expected_revision = parse_if_match(req.headers()); let body = read_json::(req.into_body(), body_limit).await?; + let requested_enabled = body.enabled; let result = create_user(body, expected_revision, &shared).await; let (mut data, revision) = match result { Ok(ok) => ok, @@ -594,6 +603,25 @@ async fn handle( }; let runtime_cfg = config_rx.borrow().clone(); data.user.in_runtime = runtime_cfg.access.users.contains_key(&data.user.username); + if let Some(enabled) = requested_enabled { + shared + .proxy_shared + .set_user_enabled(&data.user.username, enabled); + if !enabled { + let cancelled = shared + .proxy_shared + .cancel_user_sessions(&data.user.username); + if cancelled > 0 { + shared.runtime_events.record( + "api.user.disable.runtime", + format!( + "username={} cancelled_sessions={}", + data.user.username, cancelled + ), + ); + } + } + } shared.runtime_events.record( "api.user.create.ok", format!("username={}", data.user.username), @@ -606,6 +634,99 @@ async fn handle( Ok(success_response(status, data, revision)) } _ => { + if method == Method::POST + && let Some(base_user) = normalized_path + .strip_prefix("/v1/users/") + .and_then(|path| path.strip_suffix("/enable")) + && !base_user.is_empty() + && !base_user.contains('/') + { + let base_user = parse_route_username(base_user)?; + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let result = + set_user_enabled(base_user, true, expected_revision, &shared).await; + let (mut data, revision) = match result { + Ok(ok) => ok, + Err(error) => { + shared.runtime_events.record( + "api.user.enable.failed", + format!("username={} code={}", base_user, error.code), + ); + return Err(error); + } + }; + let runtime_cfg = config_rx.borrow().clone(); + data.in_runtime = runtime_cfg.access.users.contains_key(&data.username); + shared.proxy_shared.set_user_enabled(base_user, true); + shared + .runtime_events + .record("api.user.enable.ok", format!("username={}", base_user)); + let status = if data.in_runtime { + StatusCode::OK + } else { + StatusCode::ACCEPTED + }; + return Ok(success_response(status, data, revision)); + } + if method == Method::POST + && let Some(base_user) = normalized_path + .strip_prefix("/v1/users/") + .and_then(|path| path.strip_suffix("/disable")) + && !base_user.is_empty() + && !base_user.contains('/') + { + let base_user = parse_route_username(base_user)?; + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let result = + set_user_enabled(base_user, false, expected_revision, &shared).await; + let (mut data, revision) = match result { + Ok(ok) => ok, + Err(error) => { + shared.runtime_events.record( + "api.user.disable.failed", + format!("username={} code={}", base_user, error.code), + ); + return Err(error); + } + }; + let runtime_cfg = config_rx.borrow().clone(); + data.in_runtime = runtime_cfg.access.users.contains_key(&data.username); + let newly_disabled = shared.proxy_shared.set_user_enabled(base_user, false); + let cancelled = shared.proxy_shared.cancel_user_sessions(base_user); + shared.runtime_events.record( + "api.user.disable.ok", + format!( + "username={} newly_disabled={} cancelled_sessions={}", + base_user, newly_disabled, cancelled + ), + ); + let status = if data.in_runtime { + StatusCode::OK + } else { + StatusCode::ACCEPTED + }; + return Ok(success_response(status, data, revision)); + } if method == Method::POST && let Some(user) = normalized_path .strip_prefix("/v1/users/") @@ -763,6 +884,11 @@ async fn handle( let expected_revision = parse_if_match(req.headers()); let body = read_json::(req.into_body(), body_limit).await?; + let enabled_update = match &body.enabled { + Patch::Unchanged => None, + Patch::Remove => Some(true), + Patch::Set(enabled) => Some(*enabled), + }; let result = patch_user(user, body, expected_revision, &shared).await; let (mut data, revision) = match result { Ok(ok) => ok, @@ -776,6 +902,22 @@ async fn handle( }; let runtime_cfg = config_rx.borrow().clone(); data.in_runtime = runtime_cfg.access.users.contains_key(&data.username); + if let Some(enabled) = enabled_update { + shared + .proxy_shared + .set_user_enabled(&data.username, enabled); + if !enabled { + let cancelled = + shared.proxy_shared.cancel_user_sessions(&data.username); + shared.runtime_events.record( + "api.user.disable.runtime", + format!( + "username={} cancelled_sessions={}", + data.username, cancelled + ), + ); + } + } shared .runtime_events .record("api.user.patch.ok", format!("username={}", data.username)); @@ -809,9 +951,12 @@ async fn handle( return Err(error); } }; - shared - .runtime_events - .record("api.user.delete.ok", format!("username={}", deleted_user)); + shared.proxy_shared.set_user_enabled(&deleted_user, true); + let cancelled = shared.proxy_shared.cancel_user_sessions(&deleted_user); + shared.runtime_events.record( + "api.user.delete.ok", + format!("username={} cancelled_sessions={}", deleted_user, cancelled), + ); let runtime_cfg = config_rx.borrow().clone(); let in_runtime = runtime_cfg.access.users.contains_key(&deleted_user); let response = DeleteUserResponse { diff --git a/src/api/model.rs b/src/api/model.rs index 56e8fea..5a183d5 100644 --- a/src/api/model.rs +++ b/src/api/model.rs @@ -479,6 +479,7 @@ pub(super) struct TlsDomainLink { #[derive(Serialize)] pub(super) struct UserInfo { pub(super) username: String, + pub(super) enabled: bool, pub(super) in_runtime: bool, pub(super) user_ad_tag: Option, pub(super) max_tcp_conns: Option, @@ -545,6 +546,7 @@ pub(super) struct CreateUserRequest { pub(super) rate_limit_up_bps: Option, pub(super) rate_limit_down_bps: Option, pub(super) max_unique_ips: Option, + pub(super) enabled: Option, } #[derive(Deserialize)] @@ -564,6 +566,8 @@ pub(super) struct PatchUserRequest { pub(super) rate_limit_down_bps: Patch, #[serde(default, deserialize_with = "patch_field")] pub(super) max_unique_ips: Patch, + #[serde(default, deserialize_with = "patch_field")] + pub(super) enabled: Patch, } #[derive(Default, Deserialize)] diff --git a/src/api/users.rs b/src/api/users.rs index 24815fc..48acc25 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -32,6 +32,7 @@ pub(super) async fn create_user( let touches_user_rate_limits = body.rate_limit_up_bps.is_some() || body.rate_limit_down_bps.is_some(); let touches_user_max_unique_ips = body.max_unique_ips.is_some(); + let touches_user_enabled = matches!(body.enabled, Some(false)); if !is_valid_username(&body.username) { return Err(ApiFailure::bad_request( @@ -111,6 +112,9 @@ pub(super) async fn create_user( .user_max_unique_ips .insert(body.username.clone(), limit); } + if matches!(body.enabled, Some(false)) { + cfg.access.user_enabled.insert(body.username.clone(), false); + } cfg.validate() .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; @@ -134,6 +138,9 @@ pub(super) async fn create_user( if touches_user_max_unique_ips { touched_sections.push(AccessSection::UserMaxUniqueIps); } + if touches_user_enabled { + touched_sections.push(AccessSection::UserEnabled); + } let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; @@ -161,6 +168,7 @@ pub(super) async fn create_user( .find(|entry| entry.username == body.username) .unwrap_or(UserInfo { username: body.username.clone(), + enabled: cfg.access.is_user_enabled(&body.username), in_runtime: false, user_ad_tag: None, max_tcp_conns: cfg @@ -202,6 +210,7 @@ pub(super) async fn patch_user( let touches_user_rate_limits = !matches!(&body.rate_limit_up_bps, Patch::Unchanged) || !matches!(&body.rate_limit_down_bps, Patch::Unchanged); let touches_user_max_unique_ips = !matches!(&body.max_unique_ips, Patch::Unchanged); + let touches_user_enabled = !matches!(&body.enabled, Patch::Unchanged); if let Some(secret) = body.secret.as_ref() && !is_valid_user_secret(secret) @@ -313,6 +322,15 @@ pub(super) async fn patch_user( Some(Some(limit)) } }; + match body.enabled { + Patch::Unchanged => {} + Patch::Remove | Patch::Set(true) => { + cfg.access.user_enabled.remove(user); + } + Patch::Set(false) => { + cfg.access.user_enabled.insert(user.to_string(), false); + } + } cfg.validate() .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; @@ -339,6 +357,9 @@ pub(super) async fn patch_user( if touches_user_max_unique_ips { touched_sections.push(AccessSection::UserMaxUniqueIps); } + if touches_user_enabled { + touched_sections.push(AccessSection::UserEnabled); + } let revision = if touched_sections.is_empty() { current_revision(&shared.config_path).await? @@ -399,6 +420,7 @@ pub(super) async fn rotate_secret( .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; let touched_sections = [ AccessSection::Users, + AccessSection::UserEnabled, AccessSection::UserAdTags, AccessSection::UserMaxTcpConns, AccessSection::UserExpirations, @@ -434,6 +456,55 @@ pub(super) async fn rotate_secret( )) } +pub(super) async fn set_user_enabled( + user: &str, + enabled: bool, + expected_revision: Option, + shared: &ApiShared, +) -> Result<(UserInfo, String), ApiFailure> { + let _guard = shared.mutation_lock.lock().await; + let mut cfg = load_config_from_disk(&shared.config_path).await?; + ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; + + if !cfg.access.users.contains_key(user) { + return Err(ApiFailure::new( + StatusCode::NOT_FOUND, + "not_found", + "User not found", + )); + } + + if enabled { + cfg.access.user_enabled.remove(user); + } else { + cfg.access.user_enabled.insert(user.to_string(), false); + } + + cfg.validate() + .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; + let revision = + save_access_sections_to_disk(&shared.config_path, &cfg, &[AccessSection::UserEnabled]) + .await?; + drop(_guard); + + let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); + let users = users_from_config( + &cfg, + &shared.stats, + &shared.ip_tracker, + detected_ip_v4, + detected_ip_v6, + None, + ) + .await; + let user_info = users + .into_iter() + .find(|entry| entry.username == user) + .ok_or_else(|| ApiFailure::internal("failed to build updated user view"))?; + + Ok((user_info, revision)) +} + pub(super) async fn delete_user( user: &str, expected_revision: Option, @@ -459,6 +530,7 @@ pub(super) async fn delete_user( } cfg.access.users.remove(user); + cfg.access.user_enabled.remove(user); cfg.access.user_ad_tags.remove(user); cfg.access.user_max_tcp_conns.remove(user); cfg.access.user_expirations.remove(user); @@ -470,6 +542,7 @@ pub(super) async fn delete_user( .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; let touched_sections = [ AccessSection::Users, + AccessSection::UserEnabled, AccessSection::UserAdTags, AccessSection::UserMaxTcpConns, AccessSection::UserExpirations, @@ -518,6 +591,7 @@ pub(super) async fn users_from_config( }) .unwrap_or_else(empty_user_links); users.push(UserInfo { + enabled: cfg.access.is_user_enabled(&username), in_runtime: runtime_cfg .map(|runtime| runtime.access.users.contains_key(&username)) .unwrap_or(false), @@ -876,6 +950,43 @@ mod tests { assert_eq!(alice.rate_limit_down_bps, None); } + #[tokio::test] + async fn users_from_config_reports_user_enabled_default_and_override() { + let mut cfg = ProxyConfig::default(); + cfg.access.users.insert( + "alice".to_string(), + "0123456789abcdef0123456789abcdef".to_string(), + ); + cfg.access.users.insert( + "bob".to_string(), + "fedcba9876543210fedcba9876543210".to_string(), + ); + cfg.access.user_enabled.insert("bob".to_string(), false); + + let stats = Stats::new(); + let tracker = UserIpTracker::new(); + let users = users_from_config(&cfg, &stats, &tracker, None, None, None).await; + let alice = users + .iter() + .find(|entry| entry.username == "alice") + .expect("alice must be present"); + let bob = users + .iter() + .find(|entry| entry.username == "bob") + .expect("bob must be present"); + + assert!(alice.enabled); + assert!(!bob.enabled); + + cfg.access.user_enabled.insert("bob".to_string(), true); + let users = users_from_config(&cfg, &stats, &tracker, None, None, None).await; + let bob = users + .iter() + .find(|entry| entry.username == "bob") + .expect("bob must be present"); + assert!(bob.enabled); + } + #[tokio::test] async fn users_from_config_marks_runtime_membership_when_snapshot_is_provided() { let mut disk_cfg = ProxyConfig::default(); diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 8135d31..4faef9b 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -118,6 +118,7 @@ pub struct HotFields { pub me_admission_poll_ms: u64, pub me_warn_rate_limit_ms: u64, pub users: std::collections::HashMap, + pub user_enabled: std::collections::HashMap, pub user_ad_tags: std::collections::HashMap, pub user_max_tcp_conns: std::collections::HashMap, pub user_max_tcp_conns_global_each: usize, @@ -247,6 +248,7 @@ impl HotFields { me_admission_poll_ms: cfg.general.me_admission_poll_ms, me_warn_rate_limit_ms: cfg.general.me_warn_rate_limit_ms, users: cfg.access.users.clone(), + user_enabled: cfg.access.user_enabled.clone(), user_ad_tags: cfg.access.user_ad_tags.clone(), user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(), user_max_tcp_conns_global_each: cfg.access.user_max_tcp_conns_global_each, @@ -551,6 +553,7 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { cfg.general.me_warn_rate_limit_ms = new.general.me_warn_rate_limit_ms; cfg.access.users = new.access.users.clone(); + cfg.access.user_enabled = new.access.user_enabled.clone(); cfg.access.user_ad_tags = new.access.user_ad_tags.clone(); cfg.access.user_max_tcp_conns = new.access.user_max_tcp_conns.clone(); cfg.access.user_max_tcp_conns_global_each = new.access.user_max_tcp_conns_global_each; @@ -1178,6 +1181,16 @@ fn log_changes( } } + if old_hot.user_enabled != new_hot.user_enabled { + info!( + "config reload: user_enabled updated ({} disabled overrides)", + new_hot + .user_enabled + .values() + .filter(|enabled| !**enabled) + .count() + ); + } if old_hot.user_max_tcp_conns != new_hot.user_max_tcp_conns { info!( "config reload: user_max_tcp_conns updated ({} entries)", diff --git a/src/config/load.rs b/src/config/load.rs index d14510f..41bfb71 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -411,6 +411,7 @@ const TLS_FETCH_CONFIG_KEYS: &[&str] = &[ const ACCESS_CONFIG_KEYS: &[&str] = &[ "users", + "user_enabled", "user_ad_tags", "user_max_tcp_conns", "user_max_tcp_conns_global_each", diff --git a/src/config/types.rs b/src/config/types.rs index b707dff..e79c2d1 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1892,6 +1892,9 @@ pub struct AccessConfig { #[serde(default = "default_access_users")] pub users: HashMap, + #[serde(default)] + pub user_enabled: HashMap, + /// Per-user ad_tag (32 hex chars from @MTProxybot). #[serde(default)] pub user_ad_tags: HashMap, @@ -1963,6 +1966,7 @@ impl Default for AccessConfig { fn default() -> Self { Self { users: default_access_users(), + user_enabled: HashMap::new(), user_ad_tags: HashMap::new(), user_max_tcp_conns: HashMap::new(), user_max_tcp_conns_global_each: default_user_max_tcp_conns_global_each(), @@ -1983,6 +1987,10 @@ impl Default for AccessConfig { } impl AccessConfig { + pub fn is_user_enabled(&self, username: &str) -> bool { + self.user_enabled.get(username).copied().unwrap_or(true) + } + /// Returns true if `ip` is contained in any CIDR listed for `username` under `user_source_deny`. pub fn is_user_source_ip_denied(&self, username: &str, ip: IpAddr) -> bool { self.user_source_deny diff --git a/src/error.rs b/src/error.rs index ff58f4e..1cefe97 100644 --- a/src/error.rs +++ b/src/error.rs @@ -245,6 +245,9 @@ pub enum ProxyError { InvalidSecret { user: String, reason: String }, // ============= User Errors ============= + #[error("User {user} disabled")] + UserDisabled { user: String }, + #[error("User {user} expired")] UserExpired { user: String }, diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index e711ab4..2d4fb54 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -464,6 +464,12 @@ async fn run_telemt_core( config.network.dns_overrides.len() ); } + let shared_state = ProxySharedState::new(); + shared_state.apply_user_enabled_config(&config.access.user_enabled); + shared_state.traffic_limiter.apply_policy( + config.access.user_rate_limits.clone(), + config.access.cidr_rate_limits.clone(), + ); let (api_config_tx, api_config_rx) = watch::channel(Arc::new(config.clone())); let (detected_ips_tx, detected_ips_rx) = watch::channel((None::, None::)); @@ -502,6 +508,7 @@ async fn run_telemt_core( let me_pool_api = api_me_pool.clone(); let upstream_manager_api = upstream_manager.clone(); let route_runtime_api = route_runtime.clone(); + let proxy_shared_api = shared_state.clone(); let config_rx_api = api_config_rx.clone(); let admission_rx_api = admission_rx.clone(); let config_path_api = config_path.clone(); @@ -515,6 +522,7 @@ async fn run_telemt_core( ip_tracker_api, me_pool_api, route_runtime_api, + proxy_shared_api, upstream_manager_api, config_rx_api, admission_rx_api, @@ -732,11 +740,6 @@ async fn run_telemt_core( )); let buffer_pool = Arc::new(BufferPool::with_config(64 * 1024, 4096)); - let shared_state = ProxySharedState::new(); - shared_state.traffic_limiter.apply_policy( - config.access.user_rate_limits.clone(), - config.access.cidr_rate_limits.clone(), - ); if direct_first_startup { startup_tracker diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index 8b9a9aa..6099014 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -3,7 +3,7 @@ use std::path::Path; use std::sync::Arc; use tokio::sync::{mpsc, watch}; -use tracing::{debug, warn}; +use tracing::{debug, info, warn}; use tracing_subscriber::EnvFilter; use tracing_subscriber::reload; @@ -234,6 +234,27 @@ pub(crate) async fn spawn_runtime_tasks( } }); + let shared_user_enabled = shared_state.clone(); + let mut config_rx_user_enabled = config_rx.clone(); + tokio::spawn(async move { + loop { + if config_rx_user_enabled.changed().await.is_err() { + break; + } + let cfg = config_rx_user_enabled.borrow_and_update().clone(); + for user in shared_user_enabled.apply_user_enabled_config(&cfg.access.user_enabled) { + let cancelled = shared_user_enabled.cancel_user_sessions(&user); + if cancelled > 0 { + info!( + user = %user, + cancelled, + "Disabled user sessions cancelled after config reload" + ); + } + } + } + }); + let beobachten_writer = beobachten.clone(); let config_rx_beobachten = config_rx.clone(); tokio::spawn(async move { diff --git a/src/metrics.rs b/src/metrics.rs index 61a26c5..36981d3 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -55,8 +55,10 @@ pub async fn serve( return; } }; - let is_ipv6 = addr.is_ipv6(); - match bind_metrics_listener(addr, is_ipv6, listen_backlog) { + // Match `server.api.listen`: `[::]:port` is a dual-stack wildcard + // on Linux when `net.ipv6.bindv6only=0`. + let ipv6_only = addr.is_ipv6() && !addr.ip().is_unspecified(); + match bind_metrics_listener(addr, ipv6_only, listen_backlog) { Ok(listener) => { info!("Metrics endpoint: http://{}/metrics and /beobachten", addr); serve_listener( diff --git a/src/proxy/client.rs b/src/proxy/client.rs index a8357b7..5700d40 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1558,6 +1558,11 @@ impl RunningClientHandler { { let user = success.user.clone(); + if !shared.is_user_enabled(&user) { + warn!(user = %user, "Disabled user rejected"); + return Err(ProxyError::UserDisabled { user }); + } + let user_limit_reservation = match Self::acquire_user_connection_reservation_static( &user, &config, @@ -1576,6 +1581,8 @@ impl RunningClientHandler { let route_snapshot = route_runtime.snapshot(); let session_id = rng.u64(); + let _user_session = shared.register_user_session(&user, session_id); + let session_cancel = _user_session.token(); let selected_me_pool = if config.general.use_middle_proxy && matches!(route_snapshot.mode, RelayRouteMode::Middle) { @@ -1607,6 +1614,7 @@ impl RunningClientHandler { route_runtime.subscribe(), route_snapshot, session_id, + session_cancel.clone(), shared.clone(), ) .await @@ -1625,6 +1633,7 @@ impl RunningClientHandler { route_snapshot, session_id, local_addr, + session_cancel.clone(), shared.clone(), ) .await @@ -1644,6 +1653,7 @@ impl RunningClientHandler { route_snapshot, session_id, local_addr, + session_cancel, shared.clone(), ) .await diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 2fea54a..4f7cb20 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -10,6 +10,7 @@ use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split}; use tokio::sync::watch; +use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; use crate::config::ProxyConfig; @@ -258,6 +259,7 @@ where route_snapshot, session_id, SocketAddr::from(([0, 0, 0, 0], config.server.port)), + CancellationToken::new(), ProxySharedState::new(), ) .await @@ -276,6 +278,7 @@ pub(crate) async fn handle_via_direct_with_shared( route_snapshot: RouteCutoverState, session_id: u64, local_addr: SocketAddr, + session_cancel: CancellationToken, shared: Arc, ) -> Result<()> where @@ -302,14 +305,25 @@ where "Ignoring invalid scope hint and falling back to default upstream selection" ); } - let tg_stream = upstream_manager - .connect(dc_addr, Some(success.dc_idx), scope_hint) - .await?; + let tg_stream = tokio::select! { + result = upstream_manager.connect(dc_addr, Some(success.dc_idx), scope_hint) => result?, + _ = session_cancel.cancelled() => { + return Err(ProxyError::UserDisabled { + user: user.to_string(), + }); + } + }; debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); - let (tg_reader, tg_writer) = - do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()).await?; + let (tg_reader, tg_writer) = tokio::select! { + result = do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()) => result?, + _ = session_cancel.cancelled() => { + return Err(ProxyError::UserDisabled { + user: user.to_string(), + }); + } + }; debug!(peer = %success.peer, "TG handshake complete, starting relay"); @@ -331,20 +345,22 @@ where } else { Duration::from_secs(1800) }; - let relay_result = crate::proxy::relay::relay_bidirectional_with_activity_timeout_and_lease( - client_reader, - client_writer, - tg_reader, - tg_writer, - config.general.direct_relay_copy_buf_c2s_bytes, - config.general.direct_relay_copy_buf_s2c_bytes, - user, - Arc::clone(&stats), - config.access.user_data_quota.get(user).copied(), - buffer_pool, - traffic_lease, - relay_activity_timeout, - ); + let relay_result = + crate::proxy::relay::relay_bidirectional_with_activity_timeout_lease_and_cancel( + client_reader, + client_writer, + tg_reader, + tg_writer, + config.general.direct_relay_copy_buf_c2s_bytes, + config.general.direct_relay_copy_buf_s2c_bytes, + user, + Arc::clone(&stats), + config.access.user_data_quota.get(user).copied(), + buffer_pool, + traffic_lease, + relay_activity_timeout, + session_cancel.clone(), + ); tokio::pin!(relay_result); let relay_result = loop { if let Some(cutover) = @@ -371,6 +387,11 @@ where break relay_result.await; } } + _ = session_cancel.cancelled() => { + break Err(ProxyError::UserDisabled { + user: user.to_string(), + }); + } } }; diff --git a/src/proxy/middle_relay/session.rs b/src/proxy/middle_relay/session.rs index 81cf297..4865993 100644 --- a/src/proxy/middle_relay/session.rs +++ b/src/proxy/middle_relay/session.rs @@ -13,6 +13,7 @@ pub(crate) async fn handle_via_middle_proxy( mut route_rx: watch::Receiver, route_snapshot: RouteCutoverState, session_id: u64, + session_cancel: CancellationToken, shared: Arc, ) -> Result<()> where @@ -20,6 +21,10 @@ where W: AsyncWrite + Unpin + Send + 'static, { let user = success.user.clone(); + if session_cancel.is_cancelled() { + return Err(ProxyError::UserDisabled { user }); + } + let quota_limit = config.access.user_data_quota.get(&user).copied(); let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user)); let peer = success.peer; @@ -590,6 +595,25 @@ where } tokio::select! { + _ = session_cancel.cancelled() => { + warn!( + user = %user, + conn_id, + "Disabled user middle session cancelled" + ); + let _ = enqueue_c2me_command_in( + shared.as_ref(), + &c2me_tx, + C2MeCommand::Close, + c2me_send_timeout, + stats.as_ref(), + ) + .await; + main_result = Err(ProxyError::UserDisabled { + user: user.clone(), + }); + break; + } changed = route_rx.changed(), if route_watch_open => { if changed.is_err() { route_watch_open = false; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 5ea9e87..36af33a 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -55,11 +55,13 @@ use crate::error::{ProxyError, Result}; use crate::proxy::traffic_limiter::TrafficLease; use crate::stats::Stats; use crate::stream::BufferPool; +use std::future::pending; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, copy_bidirectional_with_sizes}; use tokio::time::Instant; +use tokio_util::sync::CancellationToken; use tracing::{debug, warn}; // ============= Constants ============= @@ -191,6 +193,84 @@ pub async fn relay_bidirectional_with_activity_timeout_and_lease traffic_lease: Option>, activity_timeout: Duration, ) -> Result<()> +where + CR: AsyncRead + Unpin + Send + 'static, + CW: AsyncWrite + Unpin + Send + 'static, + SR: AsyncRead + Unpin + Send + 'static, + SW: AsyncWrite + Unpin + Send + 'static, +{ + relay_bidirectional_with_activity_timeout_lease_cancel_inner( + client_reader, + client_writer, + server_reader, + server_writer, + c2s_buf_size, + s2c_buf_size, + user, + stats, + quota_limit, + _buffer_pool, + traffic_lease, + activity_timeout, + None, + ) + .await +} + +pub async fn relay_bidirectional_with_activity_timeout_lease_and_cancel( + client_reader: CR, + client_writer: CW, + server_reader: SR, + server_writer: SW, + c2s_buf_size: usize, + s2c_buf_size: usize, + user: &str, + stats: Arc, + quota_limit: Option, + _buffer_pool: Arc, + traffic_lease: Option>, + activity_timeout: Duration, + session_cancel: CancellationToken, +) -> Result<()> +where + CR: AsyncRead + Unpin + Send + 'static, + CW: AsyncWrite + Unpin + Send + 'static, + SR: AsyncRead + Unpin + Send + 'static, + SW: AsyncWrite + Unpin + Send + 'static, +{ + relay_bidirectional_with_activity_timeout_lease_cancel_inner( + client_reader, + client_writer, + server_reader, + server_writer, + c2s_buf_size, + s2c_buf_size, + user, + stats, + quota_limit, + _buffer_pool, + traffic_lease, + activity_timeout, + Some(session_cancel), + ) + .await +} + +async fn relay_bidirectional_with_activity_timeout_lease_cancel_inner( + client_reader: CR, + client_writer: CW, + server_reader: SR, + server_writer: SW, + c2s_buf_size: usize, + s2c_buf_size: usize, + user: &str, + stats: Arc, + quota_limit: Option, + _buffer_pool: Arc, + traffic_lease: Option>, + activity_timeout: Duration, + session_cancel: Option, +) -> Result<()> where CR: AsyncRead + Unpin + Send + 'static, CW: AsyncWrite + Unpin + Send + 'static, @@ -287,14 +367,29 @@ where // // When the watchdog fires, select! drops the copy future, // releasing the &mut borrows on client and server. - let copy_result = tokio::select! { + enum RelayOutcome { + Copy(std::io::Result<(u64, u64)>), + ActivityTimeout, + UserDisabled, + } + + let cancel_wait = async move { + match session_cancel { + Some(token) => token.cancelled().await, + None => pending::<()>().await, + } + }; + tokio::pin!(cancel_wait); + + let relay_outcome = tokio::select! { result = copy_bidirectional_with_sizes( &mut client, &mut server, c2s_buf_size.max(1), s2c_buf_size.max(1), - ) => Some(result), - _ = watchdog => None, // Activity timeout — cancel relay + ) => RelayOutcome::Copy(result), + _ = watchdog => RelayOutcome::ActivityTimeout, + _ = &mut cancel_wait => RelayOutcome::UserDisabled, }; // ── Clean shutdown ────────────────────────────────────────────── @@ -308,8 +403,8 @@ where let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed); let duration = epoch.elapsed(); - match copy_result { - Some(Ok((c2s, s2c))) => { + match relay_outcome { + RelayOutcome::Copy(Ok((c2s, s2c))) => { // Normal completion — one side closed the connection debug!( user = %user_owned, @@ -322,7 +417,7 @@ where ); Ok(()) } - Some(Err(e)) if is_quota_io_error(&e) => { + RelayOutcome::Copy(Err(e)) if is_quota_io_error(&e) => { let c2s = counters.c2s_bytes.load(Ordering::Relaxed); let s2c = counters.s2c_bytes.load(Ordering::Relaxed); warn!( @@ -338,7 +433,7 @@ where user: user_owned.clone(), }) } - Some(Err(e)) => { + RelayOutcome::Copy(Err(e)) => { // I/O error in one of the directions let c2s = counters.c2s_bytes.load(Ordering::Relaxed); let s2c = counters.s2c_bytes.load(Ordering::Relaxed); @@ -354,7 +449,7 @@ where ); Err(e.into()) } - None => { + RelayOutcome::ActivityTimeout => { // Activity timeout (watchdog fired) let c2s = counters.c2s_bytes.load(Ordering::Relaxed); let s2c = counters.s2c_bytes.load(Ordering::Relaxed); @@ -369,6 +464,22 @@ where ); Ok(()) } + RelayOutcome::UserDisabled => { + let c2s = counters.c2s_bytes.load(Ordering::Relaxed); + let s2c = counters.s2c_bytes.load(Ordering::Relaxed); + debug!( + user = %user_owned, + c2s_bytes = c2s, + s2c_bytes = s2c, + c2s_msgs = c2s_ops, + s2c_msgs = s2c_ops, + duration_secs = duration.as_secs(), + "Relay finished (user disabled)" + ); + Err(ProxyError::UserDisabled { + user: user_owned.clone(), + }) + } } } diff --git a/src/proxy/shared_state.rs b/src/proxy/shared_state.rs index 11e390e..9ed319b 100644 --- a/src/proxy/shared_state.rs +++ b/src/proxy/shared_state.rs @@ -1,5 +1,5 @@ -use std::collections::HashSet; use std::collections::hash_map::RandomState; +use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; @@ -7,6 +7,7 @@ use std::time::Instant; use dashmap::DashMap; use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState}; use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateRegistry}; @@ -67,10 +68,35 @@ pub(crate) struct ProxySharedState { pub(crate) handshake: HandshakeSharedState, pub(crate) middle_relay: MiddleRelaySharedState, pub(crate) traffic_limiter: Arc, + disabled_users: DashMap, + active_user_sessions: DashMap<(String, u64), CancellationToken>, pub(crate) conntrack_pressure_active: AtomicBool, pub(crate) conntrack_close_tx: Mutex>>, } +#[must_use = "registered user sessions must be kept alive until relay completion"] +pub(crate) struct UserSessionRegistration { + token: CancellationToken, + _guard: UserSessionGuard, +} + +impl UserSessionRegistration { + pub(crate) fn token(&self) -> CancellationToken { + self.token.clone() + } +} + +struct UserSessionGuard { + shared: Arc, + key: (String, u64), +} + +impl Drop for UserSessionGuard { + fn drop(&mut self) { + self.shared.active_user_sessions.remove(&self.key); + } +} + impl ProxySharedState { pub(crate) fn new() -> Arc { Arc::new(Self { @@ -101,11 +127,82 @@ impl ProxySharedState { relay_idle_mark_seq: AtomicU64::new(0), }, traffic_limiter: TrafficLimiter::new(), + disabled_users: DashMap::new(), + active_user_sessions: DashMap::new(), conntrack_pressure_active: AtomicBool::new(false), conntrack_close_tx: Mutex::new(None), }) } + pub(crate) fn is_user_enabled(&self, user: &str) -> bool { + !self.disabled_users.contains_key(user) + } + + pub(crate) fn set_user_enabled(&self, user: &str, enabled: bool) -> bool { + if enabled { + self.disabled_users.remove(user); + false + } else { + self.disabled_users.insert(user.to_string(), ()).is_none() + } + } + + pub(crate) fn apply_user_enabled_config( + &self, + user_enabled: &HashMap, + ) -> Vec { + let desired_disabled = user_enabled + .iter() + .filter_map(|(user, enabled)| (!*enabled).then_some(user.clone())) + .collect::>(); + let current_disabled = self + .disabled_users + .iter() + .map(|entry| entry.key().clone()) + .collect::>(); + + for user in current_disabled.difference(&desired_disabled) { + self.disabled_users.remove(user); + } + let newly_disabled = desired_disabled + .difference(¤t_disabled) + .cloned() + .collect::>(); + for user in desired_disabled { + self.disabled_users.insert(user, ()); + } + newly_disabled + } + + pub(crate) fn register_user_session( + self: &Arc, + user: &str, + session_id: u64, + ) -> UserSessionRegistration { + let token = CancellationToken::new(); + let key = (user.to_string(), session_id); + self.active_user_sessions.insert(key.clone(), token.clone()); + UserSessionRegistration { + token, + _guard: UserSessionGuard { + shared: Arc::clone(self), + key, + }, + } + } + + pub(crate) fn cancel_user_sessions(&self, user: &str) -> usize { + let tokens = self + .active_user_sessions + .iter() + .filter_map(|entry| (entry.key().0 == user).then(|| entry.value().clone())) + .collect::>(); + for token in &tokens { + token.cancel(); + } + tokens.len() + } + pub(crate) fn set_conntrack_close_sender(&self, tx: mpsc::Sender) { match self.conntrack_close_tx.lock() { Ok(mut guard) => { @@ -166,3 +263,48 @@ impl ProxySharedState { self.conntrack_pressure_active.load(Ordering::Relaxed) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn user_enabled_config_sync_tracks_disabled_overrides() { + let shared = ProxySharedState::new(); + assert!(shared.is_user_enabled("alice")); + + let mut user_enabled = HashMap::new(); + user_enabled.insert("alice".to_string(), false); + user_enabled.insert("bob".to_string(), true); + + let mut newly_disabled = shared.apply_user_enabled_config(&user_enabled); + newly_disabled.sort(); + assert_eq!(newly_disabled, vec!["alice".to_string()]); + assert!(!shared.is_user_enabled("alice")); + assert!(shared.is_user_enabled("bob")); + + assert!(shared.apply_user_enabled_config(&user_enabled).is_empty()); + + user_enabled.clear(); + assert!(shared.apply_user_enabled_config(&user_enabled).is_empty()); + assert!(shared.is_user_enabled("alice")); + } + + #[test] + fn cancel_user_sessions_cancels_only_registered_matching_user() { + let shared = ProxySharedState::new(); + let alice_1 = shared.register_user_session("alice", 1); + let alice_2 = shared.register_user_session("alice", 2); + let bob = shared.register_user_session("bob", 1); + let alice_1_token = alice_1.token(); + let alice_2_token = alice_2.token(); + let bob_token = bob.token(); + + drop(alice_1); + + assert_eq!(shared.cancel_user_sessions("alice"), 1); + assert!(!alice_1_token.is_cancelled()); + assert!(alice_2_token.is_cancelled()); + assert!(!bob_token.is_cancelled()); + } +}