mirror of
https://github.com/telemt/telemt.git
synced 2026-06-09 20:41:44 +03:00
@@ -14,6 +14,7 @@ use super::model::ApiFailure;
|
|||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
pub(super) enum AccessSection {
|
pub(super) enum AccessSection {
|
||||||
Users,
|
Users,
|
||||||
|
UserEnabled,
|
||||||
UserAdTags,
|
UserAdTags,
|
||||||
UserMaxTcpConns,
|
UserMaxTcpConns,
|
||||||
UserExpirations,
|
UserExpirations,
|
||||||
@@ -26,6 +27,7 @@ impl AccessSection {
|
|||||||
fn table_name(self) -> &'static str {
|
fn table_name(self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::Users => "access.users",
|
Self::Users => "access.users",
|
||||||
|
Self::UserEnabled => "access.user_enabled",
|
||||||
Self::UserAdTags => "access.user_ad_tags",
|
Self::UserAdTags => "access.user_ad_tags",
|
||||||
Self::UserMaxTcpConns => "access.user_max_tcp_conns",
|
Self::UserMaxTcpConns => "access.user_max_tcp_conns",
|
||||||
Self::UserExpirations => "access.user_expirations",
|
Self::UserExpirations => "access.user_expirations",
|
||||||
@@ -135,6 +137,15 @@ fn render_access_section(cfg: &ProxyConfig, section: AccessSection) -> Result<St
|
|||||||
.collect();
|
.collect();
|
||||||
serialize_table_body(&rows)?
|
serialize_table_body(&rows)?
|
||||||
}
|
}
|
||||||
|
AccessSection::UserEnabled => {
|
||||||
|
let rows: BTreeMap<String, bool> = cfg
|
||||||
|
.access
|
||||||
|
.user_enabled
|
||||||
|
.iter()
|
||||||
|
.map(|(key, value)| (key.clone(), *value))
|
||||||
|
.collect();
|
||||||
|
serialize_table_body(&rows)?
|
||||||
|
}
|
||||||
AccessSection::UserAdTags => {
|
AccessSection::UserAdTags => {
|
||||||
let rows: BTreeMap<String, String> = cfg
|
let rows: BTreeMap<String, String> = cfg
|
||||||
.access
|
.access
|
||||||
@@ -204,6 +215,7 @@ fn render_access_section(cfg: &ProxyConfig, section: AccessSection) -> Result<St
|
|||||||
fn access_section_is_empty(cfg: &ProxyConfig, section: AccessSection) -> bool {
|
fn access_section_is_empty(cfg: &ProxyConfig, section: AccessSection) -> bool {
|
||||||
match section {
|
match section {
|
||||||
AccessSection::Users => cfg.access.users.is_empty(),
|
AccessSection::Users => cfg.access.users.is_empty(),
|
||||||
|
AccessSection::UserEnabled => cfg.access.user_enabled.is_empty(),
|
||||||
AccessSection::UserAdTags => cfg.access.user_ad_tags.is_empty(),
|
AccessSection::UserAdTags => cfg.access.user_ad_tags.is_empty(),
|
||||||
AccessSection::UserMaxTcpConns => cfg.access.user_max_tcp_conns.is_empty(),
|
AccessSection::UserMaxTcpConns => cfg.access.user_max_tcp_conns.is_empty(),
|
||||||
AccessSection::UserExpirations => cfg.access.user_expirations.is_empty(),
|
AccessSection::UserExpirations => cfg.access.user_expirations.is_empty(),
|
||||||
|
|||||||
153
src/api/mod.rs
153
src/api/mod.rs
@@ -22,6 +22,7 @@ use tracing::{debug, info, warn};
|
|||||||
use crate::config::{ApiGrayAction, ProxyConfig};
|
use crate::config::{ApiGrayAction, ProxyConfig};
|
||||||
use crate::ip_tracker::UserIpTracker;
|
use crate::ip_tracker::UserIpTracker;
|
||||||
use crate::proxy::route_mode::RouteRuntimeController;
|
use crate::proxy::route_mode::RouteRuntimeController;
|
||||||
|
use crate::proxy::shared_state::ProxySharedState;
|
||||||
use crate::startup::StartupTracker;
|
use crate::startup::StartupTracker;
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::transport::UpstreamManager;
|
use crate::transport::UpstreamManager;
|
||||||
@@ -51,6 +52,7 @@ use model::{
|
|||||||
PatchUserRequest, ResetUserQuotaResponse, RotateSecretRequest, SummaryData, UserActiveIps,
|
PatchUserRequest, ResetUserQuotaResponse, RotateSecretRequest, SummaryData, UserActiveIps,
|
||||||
is_valid_username,
|
is_valid_username,
|
||||||
};
|
};
|
||||||
|
use patch::Patch;
|
||||||
use runtime_edge::{
|
use runtime_edge::{
|
||||||
EdgeConnectionsCacheEntry, build_runtime_connections_summary_data,
|
EdgeConnectionsCacheEntry, build_runtime_connections_summary_data,
|
||||||
build_runtime_events_recent_data,
|
build_runtime_events_recent_data,
|
||||||
@@ -71,7 +73,8 @@ use runtime_zero::{
|
|||||||
build_system_info_data,
|
build_system_info_data,
|
||||||
};
|
};
|
||||||
use users::{
|
use users::{
|
||||||
build_user_quota_list, create_user, delete_user, patch_user, rotate_secret, users_from_config,
|
build_user_quota_list, create_user, delete_user, patch_user, rotate_secret, set_user_enabled,
|
||||||
|
users_from_config,
|
||||||
};
|
};
|
||||||
|
|
||||||
const API_MAX_CONTROL_CONNECTIONS: usize = 1024;
|
const API_MAX_CONTROL_CONNECTIONS: usize = 1024;
|
||||||
@@ -107,6 +110,7 @@ pub(super) struct ApiShared {
|
|||||||
pub(super) runtime_state: Arc<ApiRuntimeState>,
|
pub(super) runtime_state: Arc<ApiRuntimeState>,
|
||||||
pub(super) startup_tracker: Arc<StartupTracker>,
|
pub(super) startup_tracker: Arc<StartupTracker>,
|
||||||
pub(super) route_runtime: Arc<RouteRuntimeController>,
|
pub(super) route_runtime: Arc<RouteRuntimeController>,
|
||||||
|
pub(super) proxy_shared: Arc<ProxySharedState>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ApiShared {
|
impl ApiShared {
|
||||||
@@ -171,6 +175,8 @@ fn allowed_methods_for_path(path: &str) -> Option<&'static str> {
|
|||||||
"/v1/users" => Some(ALLOW_GET_POST),
|
"/v1/users" => Some(ALLOW_GET_POST),
|
||||||
_ if user_action_route_matches(path, "/reset-quota") => Some(ALLOW_POST),
|
_ if user_action_route_matches(path, "/reset-quota") => Some(ALLOW_POST),
|
||||||
_ if user_action_route_matches(path, "/rotate-secret") => Some(ALLOW_POST),
|
_ if user_action_route_matches(path, "/rotate-secret") => Some(ALLOW_POST),
|
||||||
|
_ if user_action_route_matches(path, "/enable") => Some(ALLOW_POST),
|
||||||
|
_ if user_action_route_matches(path, "/disable") => Some(ALLOW_POST),
|
||||||
_ if path
|
_ if path
|
||||||
.strip_prefix("/v1/users/")
|
.strip_prefix("/v1/users/")
|
||||||
.map(|user| !user.is_empty() && !user.contains('/'))
|
.map(|user| !user.is_empty() && !user.contains('/'))
|
||||||
@@ -188,6 +194,7 @@ pub async fn serve(
|
|||||||
ip_tracker: Arc<UserIpTracker>,
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
me_pool: Arc<RwLock<Option<Arc<MePool>>>>,
|
me_pool: Arc<RwLock<Option<Arc<MePool>>>>,
|
||||||
route_runtime: Arc<RouteRuntimeController>,
|
route_runtime: Arc<RouteRuntimeController>,
|
||||||
|
proxy_shared: Arc<ProxySharedState>,
|
||||||
upstream_manager: Arc<UpstreamManager>,
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
config_rx: watch::Receiver<Arc<ProxyConfig>>,
|
config_rx: watch::Receiver<Arc<ProxyConfig>>,
|
||||||
admission_rx: watch::Receiver<bool>,
|
admission_rx: watch::Receiver<bool>,
|
||||||
@@ -237,6 +244,7 @@ pub async fn serve(
|
|||||||
runtime_state: runtime_state.clone(),
|
runtime_state: runtime_state.clone(),
|
||||||
startup_tracker,
|
startup_tracker,
|
||||||
route_runtime,
|
route_runtime,
|
||||||
|
proxy_shared,
|
||||||
});
|
});
|
||||||
|
|
||||||
spawn_runtime_watchers(
|
spawn_runtime_watchers(
|
||||||
@@ -582,6 +590,7 @@ async fn handle(
|
|||||||
}
|
}
|
||||||
let expected_revision = parse_if_match(req.headers());
|
let expected_revision = parse_if_match(req.headers());
|
||||||
let body = read_json::<CreateUserRequest>(req.into_body(), body_limit).await?;
|
let body = read_json::<CreateUserRequest>(req.into_body(), body_limit).await?;
|
||||||
|
let requested_enabled = body.enabled;
|
||||||
let result = create_user(body, expected_revision, &shared).await;
|
let result = create_user(body, expected_revision, &shared).await;
|
||||||
let (mut data, revision) = match result {
|
let (mut data, revision) = match result {
|
||||||
Ok(ok) => ok,
|
Ok(ok) => ok,
|
||||||
@@ -594,6 +603,25 @@ async fn handle(
|
|||||||
};
|
};
|
||||||
let runtime_cfg = config_rx.borrow().clone();
|
let runtime_cfg = config_rx.borrow().clone();
|
||||||
data.user.in_runtime = runtime_cfg.access.users.contains_key(&data.user.username);
|
data.user.in_runtime = runtime_cfg.access.users.contains_key(&data.user.username);
|
||||||
|
if let Some(enabled) = requested_enabled {
|
||||||
|
shared
|
||||||
|
.proxy_shared
|
||||||
|
.set_user_enabled(&data.user.username, enabled);
|
||||||
|
if !enabled {
|
||||||
|
let cancelled = shared
|
||||||
|
.proxy_shared
|
||||||
|
.cancel_user_sessions(&data.user.username);
|
||||||
|
if cancelled > 0 {
|
||||||
|
shared.runtime_events.record(
|
||||||
|
"api.user.disable.runtime",
|
||||||
|
format!(
|
||||||
|
"username={} cancelled_sessions={}",
|
||||||
|
data.user.username, cancelled
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
shared.runtime_events.record(
|
shared.runtime_events.record(
|
||||||
"api.user.create.ok",
|
"api.user.create.ok",
|
||||||
format!("username={}", data.user.username),
|
format!("username={}", data.user.username),
|
||||||
@@ -606,6 +634,99 @@ async fn handle(
|
|||||||
Ok(success_response(status, data, revision))
|
Ok(success_response(status, data, revision))
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
if method == Method::POST
|
||||||
|
&& let Some(base_user) = normalized_path
|
||||||
|
.strip_prefix("/v1/users/")
|
||||||
|
.and_then(|path| path.strip_suffix("/enable"))
|
||||||
|
&& !base_user.is_empty()
|
||||||
|
&& !base_user.contains('/')
|
||||||
|
{
|
||||||
|
let base_user = parse_route_username(base_user)?;
|
||||||
|
if api_cfg.read_only {
|
||||||
|
return Ok(error_response(
|
||||||
|
request_id,
|
||||||
|
ApiFailure::new(
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
"read_only",
|
||||||
|
"API runs in read-only mode",
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let expected_revision = parse_if_match(req.headers());
|
||||||
|
let result =
|
||||||
|
set_user_enabled(base_user, true, expected_revision, &shared).await;
|
||||||
|
let (mut data, revision) = match result {
|
||||||
|
Ok(ok) => ok,
|
||||||
|
Err(error) => {
|
||||||
|
shared.runtime_events.record(
|
||||||
|
"api.user.enable.failed",
|
||||||
|
format!("username={} code={}", base_user, error.code),
|
||||||
|
);
|
||||||
|
return Err(error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let runtime_cfg = config_rx.borrow().clone();
|
||||||
|
data.in_runtime = runtime_cfg.access.users.contains_key(&data.username);
|
||||||
|
shared.proxy_shared.set_user_enabled(base_user, true);
|
||||||
|
shared
|
||||||
|
.runtime_events
|
||||||
|
.record("api.user.enable.ok", format!("username={}", base_user));
|
||||||
|
let status = if data.in_runtime {
|
||||||
|
StatusCode::OK
|
||||||
|
} else {
|
||||||
|
StatusCode::ACCEPTED
|
||||||
|
};
|
||||||
|
return Ok(success_response(status, data, revision));
|
||||||
|
}
|
||||||
|
if method == Method::POST
|
||||||
|
&& let Some(base_user) = normalized_path
|
||||||
|
.strip_prefix("/v1/users/")
|
||||||
|
.and_then(|path| path.strip_suffix("/disable"))
|
||||||
|
&& !base_user.is_empty()
|
||||||
|
&& !base_user.contains('/')
|
||||||
|
{
|
||||||
|
let base_user = parse_route_username(base_user)?;
|
||||||
|
if api_cfg.read_only {
|
||||||
|
return Ok(error_response(
|
||||||
|
request_id,
|
||||||
|
ApiFailure::new(
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
"read_only",
|
||||||
|
"API runs in read-only mode",
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let expected_revision = parse_if_match(req.headers());
|
||||||
|
let result =
|
||||||
|
set_user_enabled(base_user, false, expected_revision, &shared).await;
|
||||||
|
let (mut data, revision) = match result {
|
||||||
|
Ok(ok) => ok,
|
||||||
|
Err(error) => {
|
||||||
|
shared.runtime_events.record(
|
||||||
|
"api.user.disable.failed",
|
||||||
|
format!("username={} code={}", base_user, error.code),
|
||||||
|
);
|
||||||
|
return Err(error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let runtime_cfg = config_rx.borrow().clone();
|
||||||
|
data.in_runtime = runtime_cfg.access.users.contains_key(&data.username);
|
||||||
|
let newly_disabled = shared.proxy_shared.set_user_enabled(base_user, false);
|
||||||
|
let cancelled = shared.proxy_shared.cancel_user_sessions(base_user);
|
||||||
|
shared.runtime_events.record(
|
||||||
|
"api.user.disable.ok",
|
||||||
|
format!(
|
||||||
|
"username={} newly_disabled={} cancelled_sessions={}",
|
||||||
|
base_user, newly_disabled, cancelled
|
||||||
|
),
|
||||||
|
);
|
||||||
|
let status = if data.in_runtime {
|
||||||
|
StatusCode::OK
|
||||||
|
} else {
|
||||||
|
StatusCode::ACCEPTED
|
||||||
|
};
|
||||||
|
return Ok(success_response(status, data, revision));
|
||||||
|
}
|
||||||
if method == Method::POST
|
if method == Method::POST
|
||||||
&& let Some(user) = normalized_path
|
&& let Some(user) = normalized_path
|
||||||
.strip_prefix("/v1/users/")
|
.strip_prefix("/v1/users/")
|
||||||
@@ -763,6 +884,11 @@ async fn handle(
|
|||||||
let expected_revision = parse_if_match(req.headers());
|
let expected_revision = parse_if_match(req.headers());
|
||||||
let body =
|
let body =
|
||||||
read_json::<PatchUserRequest>(req.into_body(), body_limit).await?;
|
read_json::<PatchUserRequest>(req.into_body(), body_limit).await?;
|
||||||
|
let enabled_update = match &body.enabled {
|
||||||
|
Patch::Unchanged => None,
|
||||||
|
Patch::Remove => Some(true),
|
||||||
|
Patch::Set(enabled) => Some(*enabled),
|
||||||
|
};
|
||||||
let result = patch_user(user, body, expected_revision, &shared).await;
|
let result = patch_user(user, body, expected_revision, &shared).await;
|
||||||
let (mut data, revision) = match result {
|
let (mut data, revision) = match result {
|
||||||
Ok(ok) => ok,
|
Ok(ok) => ok,
|
||||||
@@ -776,6 +902,22 @@ async fn handle(
|
|||||||
};
|
};
|
||||||
let runtime_cfg = config_rx.borrow().clone();
|
let runtime_cfg = config_rx.borrow().clone();
|
||||||
data.in_runtime = runtime_cfg.access.users.contains_key(&data.username);
|
data.in_runtime = runtime_cfg.access.users.contains_key(&data.username);
|
||||||
|
if let Some(enabled) = enabled_update {
|
||||||
|
shared
|
||||||
|
.proxy_shared
|
||||||
|
.set_user_enabled(&data.username, enabled);
|
||||||
|
if !enabled {
|
||||||
|
let cancelled =
|
||||||
|
shared.proxy_shared.cancel_user_sessions(&data.username);
|
||||||
|
shared.runtime_events.record(
|
||||||
|
"api.user.disable.runtime",
|
||||||
|
format!(
|
||||||
|
"username={} cancelled_sessions={}",
|
||||||
|
data.username, cancelled
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
shared
|
shared
|
||||||
.runtime_events
|
.runtime_events
|
||||||
.record("api.user.patch.ok", format!("username={}", data.username));
|
.record("api.user.patch.ok", format!("username={}", data.username));
|
||||||
@@ -809,9 +951,12 @@ async fn handle(
|
|||||||
return Err(error);
|
return Err(error);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
shared
|
shared.proxy_shared.set_user_enabled(&deleted_user, true);
|
||||||
.runtime_events
|
let cancelled = shared.proxy_shared.cancel_user_sessions(&deleted_user);
|
||||||
.record("api.user.delete.ok", format!("username={}", deleted_user));
|
shared.runtime_events.record(
|
||||||
|
"api.user.delete.ok",
|
||||||
|
format!("username={} cancelled_sessions={}", deleted_user, cancelled),
|
||||||
|
);
|
||||||
let runtime_cfg = config_rx.borrow().clone();
|
let runtime_cfg = config_rx.borrow().clone();
|
||||||
let in_runtime = runtime_cfg.access.users.contains_key(&deleted_user);
|
let in_runtime = runtime_cfg.access.users.contains_key(&deleted_user);
|
||||||
let response = DeleteUserResponse {
|
let response = DeleteUserResponse {
|
||||||
|
|||||||
@@ -479,6 +479,7 @@ pub(super) struct TlsDomainLink {
|
|||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
pub(super) struct UserInfo {
|
pub(super) struct UserInfo {
|
||||||
pub(super) username: String,
|
pub(super) username: String,
|
||||||
|
pub(super) enabled: bool,
|
||||||
pub(super) in_runtime: bool,
|
pub(super) in_runtime: bool,
|
||||||
pub(super) user_ad_tag: Option<String>,
|
pub(super) user_ad_tag: Option<String>,
|
||||||
pub(super) max_tcp_conns: Option<usize>,
|
pub(super) max_tcp_conns: Option<usize>,
|
||||||
@@ -545,6 +546,7 @@ pub(super) struct CreateUserRequest {
|
|||||||
pub(super) rate_limit_up_bps: Option<u64>,
|
pub(super) rate_limit_up_bps: Option<u64>,
|
||||||
pub(super) rate_limit_down_bps: Option<u64>,
|
pub(super) rate_limit_down_bps: Option<u64>,
|
||||||
pub(super) max_unique_ips: Option<usize>,
|
pub(super) max_unique_ips: Option<usize>,
|
||||||
|
pub(super) enabled: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@@ -564,6 +566,8 @@ pub(super) struct PatchUserRequest {
|
|||||||
pub(super) rate_limit_down_bps: Patch<u64>,
|
pub(super) rate_limit_down_bps: Patch<u64>,
|
||||||
#[serde(default, deserialize_with = "patch_field")]
|
#[serde(default, deserialize_with = "patch_field")]
|
||||||
pub(super) max_unique_ips: Patch<usize>,
|
pub(super) max_unique_ips: Patch<usize>,
|
||||||
|
#[serde(default, deserialize_with = "patch_field")]
|
||||||
|
pub(super) enabled: Patch<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Deserialize)]
|
#[derive(Default, Deserialize)]
|
||||||
|
|||||||
111
src/api/users.rs
111
src/api/users.rs
@@ -32,6 +32,7 @@ pub(super) async fn create_user(
|
|||||||
let touches_user_rate_limits =
|
let touches_user_rate_limits =
|
||||||
body.rate_limit_up_bps.is_some() || body.rate_limit_down_bps.is_some();
|
body.rate_limit_up_bps.is_some() || body.rate_limit_down_bps.is_some();
|
||||||
let touches_user_max_unique_ips = body.max_unique_ips.is_some();
|
let touches_user_max_unique_ips = body.max_unique_ips.is_some();
|
||||||
|
let touches_user_enabled = matches!(body.enabled, Some(false));
|
||||||
|
|
||||||
if !is_valid_username(&body.username) {
|
if !is_valid_username(&body.username) {
|
||||||
return Err(ApiFailure::bad_request(
|
return Err(ApiFailure::bad_request(
|
||||||
@@ -111,6 +112,9 @@ pub(super) async fn create_user(
|
|||||||
.user_max_unique_ips
|
.user_max_unique_ips
|
||||||
.insert(body.username.clone(), limit);
|
.insert(body.username.clone(), limit);
|
||||||
}
|
}
|
||||||
|
if matches!(body.enabled, Some(false)) {
|
||||||
|
cfg.access.user_enabled.insert(body.username.clone(), false);
|
||||||
|
}
|
||||||
|
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
|
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
|
||||||
@@ -134,6 +138,9 @@ pub(super) async fn create_user(
|
|||||||
if touches_user_max_unique_ips {
|
if touches_user_max_unique_ips {
|
||||||
touched_sections.push(AccessSection::UserMaxUniqueIps);
|
touched_sections.push(AccessSection::UserMaxUniqueIps);
|
||||||
}
|
}
|
||||||
|
if touches_user_enabled {
|
||||||
|
touched_sections.push(AccessSection::UserEnabled);
|
||||||
|
}
|
||||||
|
|
||||||
let revision =
|
let revision =
|
||||||
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
|
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
|
||||||
@@ -161,6 +168,7 @@ pub(super) async fn create_user(
|
|||||||
.find(|entry| entry.username == body.username)
|
.find(|entry| entry.username == body.username)
|
||||||
.unwrap_or(UserInfo {
|
.unwrap_or(UserInfo {
|
||||||
username: body.username.clone(),
|
username: body.username.clone(),
|
||||||
|
enabled: cfg.access.is_user_enabled(&body.username),
|
||||||
in_runtime: false,
|
in_runtime: false,
|
||||||
user_ad_tag: None,
|
user_ad_tag: None,
|
||||||
max_tcp_conns: cfg
|
max_tcp_conns: cfg
|
||||||
@@ -202,6 +210,7 @@ pub(super) async fn patch_user(
|
|||||||
let touches_user_rate_limits = !matches!(&body.rate_limit_up_bps, Patch::Unchanged)
|
let touches_user_rate_limits = !matches!(&body.rate_limit_up_bps, Patch::Unchanged)
|
||||||
|| !matches!(&body.rate_limit_down_bps, Patch::Unchanged);
|
|| !matches!(&body.rate_limit_down_bps, Patch::Unchanged);
|
||||||
let touches_user_max_unique_ips = !matches!(&body.max_unique_ips, Patch::Unchanged);
|
let touches_user_max_unique_ips = !matches!(&body.max_unique_ips, Patch::Unchanged);
|
||||||
|
let touches_user_enabled = !matches!(&body.enabled, Patch::Unchanged);
|
||||||
|
|
||||||
if let Some(secret) = body.secret.as_ref()
|
if let Some(secret) = body.secret.as_ref()
|
||||||
&& !is_valid_user_secret(secret)
|
&& !is_valid_user_secret(secret)
|
||||||
@@ -313,6 +322,15 @@ pub(super) async fn patch_user(
|
|||||||
Some(Some(limit))
|
Some(Some(limit))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
match body.enabled {
|
||||||
|
Patch::Unchanged => {}
|
||||||
|
Patch::Remove | Patch::Set(true) => {
|
||||||
|
cfg.access.user_enabled.remove(user);
|
||||||
|
}
|
||||||
|
Patch::Set(false) => {
|
||||||
|
cfg.access.user_enabled.insert(user.to_string(), false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
|
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
|
||||||
@@ -339,6 +357,9 @@ pub(super) async fn patch_user(
|
|||||||
if touches_user_max_unique_ips {
|
if touches_user_max_unique_ips {
|
||||||
touched_sections.push(AccessSection::UserMaxUniqueIps);
|
touched_sections.push(AccessSection::UserMaxUniqueIps);
|
||||||
}
|
}
|
||||||
|
if touches_user_enabled {
|
||||||
|
touched_sections.push(AccessSection::UserEnabled);
|
||||||
|
}
|
||||||
|
|
||||||
let revision = if touched_sections.is_empty() {
|
let revision = if touched_sections.is_empty() {
|
||||||
current_revision(&shared.config_path).await?
|
current_revision(&shared.config_path).await?
|
||||||
@@ -399,6 +420,7 @@ pub(super) async fn rotate_secret(
|
|||||||
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
|
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
|
||||||
let touched_sections = [
|
let touched_sections = [
|
||||||
AccessSection::Users,
|
AccessSection::Users,
|
||||||
|
AccessSection::UserEnabled,
|
||||||
AccessSection::UserAdTags,
|
AccessSection::UserAdTags,
|
||||||
AccessSection::UserMaxTcpConns,
|
AccessSection::UserMaxTcpConns,
|
||||||
AccessSection::UserExpirations,
|
AccessSection::UserExpirations,
|
||||||
@@ -434,6 +456,55 @@ pub(super) async fn rotate_secret(
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) async fn set_user_enabled(
|
||||||
|
user: &str,
|
||||||
|
enabled: bool,
|
||||||
|
expected_revision: Option<String>,
|
||||||
|
shared: &ApiShared,
|
||||||
|
) -> Result<(UserInfo, String), ApiFailure> {
|
||||||
|
let _guard = shared.mutation_lock.lock().await;
|
||||||
|
let mut cfg = load_config_from_disk(&shared.config_path).await?;
|
||||||
|
ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?;
|
||||||
|
|
||||||
|
if !cfg.access.users.contains_key(user) {
|
||||||
|
return Err(ApiFailure::new(
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
"not_found",
|
||||||
|
"User not found",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if enabled {
|
||||||
|
cfg.access.user_enabled.remove(user);
|
||||||
|
} else {
|
||||||
|
cfg.access.user_enabled.insert(user.to_string(), false);
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.validate()
|
||||||
|
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
|
||||||
|
let revision =
|
||||||
|
save_access_sections_to_disk(&shared.config_path, &cfg, &[AccessSection::UserEnabled])
|
||||||
|
.await?;
|
||||||
|
drop(_guard);
|
||||||
|
|
||||||
|
let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips();
|
||||||
|
let users = users_from_config(
|
||||||
|
&cfg,
|
||||||
|
&shared.stats,
|
||||||
|
&shared.ip_tracker,
|
||||||
|
detected_ip_v4,
|
||||||
|
detected_ip_v6,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let user_info = users
|
||||||
|
.into_iter()
|
||||||
|
.find(|entry| entry.username == user)
|
||||||
|
.ok_or_else(|| ApiFailure::internal("failed to build updated user view"))?;
|
||||||
|
|
||||||
|
Ok((user_info, revision))
|
||||||
|
}
|
||||||
|
|
||||||
pub(super) async fn delete_user(
|
pub(super) async fn delete_user(
|
||||||
user: &str,
|
user: &str,
|
||||||
expected_revision: Option<String>,
|
expected_revision: Option<String>,
|
||||||
@@ -459,6 +530,7 @@ pub(super) async fn delete_user(
|
|||||||
}
|
}
|
||||||
|
|
||||||
cfg.access.users.remove(user);
|
cfg.access.users.remove(user);
|
||||||
|
cfg.access.user_enabled.remove(user);
|
||||||
cfg.access.user_ad_tags.remove(user);
|
cfg.access.user_ad_tags.remove(user);
|
||||||
cfg.access.user_max_tcp_conns.remove(user);
|
cfg.access.user_max_tcp_conns.remove(user);
|
||||||
cfg.access.user_expirations.remove(user);
|
cfg.access.user_expirations.remove(user);
|
||||||
@@ -470,6 +542,7 @@ pub(super) async fn delete_user(
|
|||||||
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
|
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
|
||||||
let touched_sections = [
|
let touched_sections = [
|
||||||
AccessSection::Users,
|
AccessSection::Users,
|
||||||
|
AccessSection::UserEnabled,
|
||||||
AccessSection::UserAdTags,
|
AccessSection::UserAdTags,
|
||||||
AccessSection::UserMaxTcpConns,
|
AccessSection::UserMaxTcpConns,
|
||||||
AccessSection::UserExpirations,
|
AccessSection::UserExpirations,
|
||||||
@@ -518,6 +591,7 @@ pub(super) async fn users_from_config(
|
|||||||
})
|
})
|
||||||
.unwrap_or_else(empty_user_links);
|
.unwrap_or_else(empty_user_links);
|
||||||
users.push(UserInfo {
|
users.push(UserInfo {
|
||||||
|
enabled: cfg.access.is_user_enabled(&username),
|
||||||
in_runtime: runtime_cfg
|
in_runtime: runtime_cfg
|
||||||
.map(|runtime| runtime.access.users.contains_key(&username))
|
.map(|runtime| runtime.access.users.contains_key(&username))
|
||||||
.unwrap_or(false),
|
.unwrap_or(false),
|
||||||
@@ -876,6 +950,43 @@ mod tests {
|
|||||||
assert_eq!(alice.rate_limit_down_bps, None);
|
assert_eq!(alice.rate_limit_down_bps, None);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn users_from_config_reports_user_enabled_default_and_override() {
|
||||||
|
let mut cfg = ProxyConfig::default();
|
||||||
|
cfg.access.users.insert(
|
||||||
|
"alice".to_string(),
|
||||||
|
"0123456789abcdef0123456789abcdef".to_string(),
|
||||||
|
);
|
||||||
|
cfg.access.users.insert(
|
||||||
|
"bob".to_string(),
|
||||||
|
"fedcba9876543210fedcba9876543210".to_string(),
|
||||||
|
);
|
||||||
|
cfg.access.user_enabled.insert("bob".to_string(), false);
|
||||||
|
|
||||||
|
let stats = Stats::new();
|
||||||
|
let tracker = UserIpTracker::new();
|
||||||
|
let users = users_from_config(&cfg, &stats, &tracker, None, None, None).await;
|
||||||
|
let alice = users
|
||||||
|
.iter()
|
||||||
|
.find(|entry| entry.username == "alice")
|
||||||
|
.expect("alice must be present");
|
||||||
|
let bob = users
|
||||||
|
.iter()
|
||||||
|
.find(|entry| entry.username == "bob")
|
||||||
|
.expect("bob must be present");
|
||||||
|
|
||||||
|
assert!(alice.enabled);
|
||||||
|
assert!(!bob.enabled);
|
||||||
|
|
||||||
|
cfg.access.user_enabled.insert("bob".to_string(), true);
|
||||||
|
let users = users_from_config(&cfg, &stats, &tracker, None, None, None).await;
|
||||||
|
let bob = users
|
||||||
|
.iter()
|
||||||
|
.find(|entry| entry.username == "bob")
|
||||||
|
.expect("bob must be present");
|
||||||
|
assert!(bob.enabled);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn users_from_config_marks_runtime_membership_when_snapshot_is_provided() {
|
async fn users_from_config_marks_runtime_membership_when_snapshot_is_provided() {
|
||||||
let mut disk_cfg = ProxyConfig::default();
|
let mut disk_cfg = ProxyConfig::default();
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ pub struct HotFields {
|
|||||||
pub me_admission_poll_ms: u64,
|
pub me_admission_poll_ms: u64,
|
||||||
pub me_warn_rate_limit_ms: u64,
|
pub me_warn_rate_limit_ms: u64,
|
||||||
pub users: std::collections::HashMap<String, String>,
|
pub users: std::collections::HashMap<String, String>,
|
||||||
|
pub user_enabled: std::collections::HashMap<String, bool>,
|
||||||
pub user_ad_tags: std::collections::HashMap<String, String>,
|
pub user_ad_tags: std::collections::HashMap<String, String>,
|
||||||
pub user_max_tcp_conns: std::collections::HashMap<String, usize>,
|
pub user_max_tcp_conns: std::collections::HashMap<String, usize>,
|
||||||
pub user_max_tcp_conns_global_each: usize,
|
pub user_max_tcp_conns_global_each: usize,
|
||||||
@@ -247,6 +248,7 @@ impl HotFields {
|
|||||||
me_admission_poll_ms: cfg.general.me_admission_poll_ms,
|
me_admission_poll_ms: cfg.general.me_admission_poll_ms,
|
||||||
me_warn_rate_limit_ms: cfg.general.me_warn_rate_limit_ms,
|
me_warn_rate_limit_ms: cfg.general.me_warn_rate_limit_ms,
|
||||||
users: cfg.access.users.clone(),
|
users: cfg.access.users.clone(),
|
||||||
|
user_enabled: cfg.access.user_enabled.clone(),
|
||||||
user_ad_tags: cfg.access.user_ad_tags.clone(),
|
user_ad_tags: cfg.access.user_ad_tags.clone(),
|
||||||
user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(),
|
user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(),
|
||||||
user_max_tcp_conns_global_each: cfg.access.user_max_tcp_conns_global_each,
|
user_max_tcp_conns_global_each: cfg.access.user_max_tcp_conns_global_each,
|
||||||
@@ -551,6 +553,7 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
|
|||||||
cfg.general.me_warn_rate_limit_ms = new.general.me_warn_rate_limit_ms;
|
cfg.general.me_warn_rate_limit_ms = new.general.me_warn_rate_limit_ms;
|
||||||
|
|
||||||
cfg.access.users = new.access.users.clone();
|
cfg.access.users = new.access.users.clone();
|
||||||
|
cfg.access.user_enabled = new.access.user_enabled.clone();
|
||||||
cfg.access.user_ad_tags = new.access.user_ad_tags.clone();
|
cfg.access.user_ad_tags = new.access.user_ad_tags.clone();
|
||||||
cfg.access.user_max_tcp_conns = new.access.user_max_tcp_conns.clone();
|
cfg.access.user_max_tcp_conns = new.access.user_max_tcp_conns.clone();
|
||||||
cfg.access.user_max_tcp_conns_global_each = new.access.user_max_tcp_conns_global_each;
|
cfg.access.user_max_tcp_conns_global_each = new.access.user_max_tcp_conns_global_each;
|
||||||
@@ -1178,6 +1181,16 @@ fn log_changes(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if old_hot.user_enabled != new_hot.user_enabled {
|
||||||
|
info!(
|
||||||
|
"config reload: user_enabled updated ({} disabled overrides)",
|
||||||
|
new_hot
|
||||||
|
.user_enabled
|
||||||
|
.values()
|
||||||
|
.filter(|enabled| !**enabled)
|
||||||
|
.count()
|
||||||
|
);
|
||||||
|
}
|
||||||
if old_hot.user_max_tcp_conns != new_hot.user_max_tcp_conns {
|
if old_hot.user_max_tcp_conns != new_hot.user_max_tcp_conns {
|
||||||
info!(
|
info!(
|
||||||
"config reload: user_max_tcp_conns updated ({} entries)",
|
"config reload: user_max_tcp_conns updated ({} entries)",
|
||||||
|
|||||||
@@ -411,6 +411,7 @@ const TLS_FETCH_CONFIG_KEYS: &[&str] = &[
|
|||||||
|
|
||||||
const ACCESS_CONFIG_KEYS: &[&str] = &[
|
const ACCESS_CONFIG_KEYS: &[&str] = &[
|
||||||
"users",
|
"users",
|
||||||
|
"user_enabled",
|
||||||
"user_ad_tags",
|
"user_ad_tags",
|
||||||
"user_max_tcp_conns",
|
"user_max_tcp_conns",
|
||||||
"user_max_tcp_conns_global_each",
|
"user_max_tcp_conns_global_each",
|
||||||
|
|||||||
@@ -1892,6 +1892,9 @@ pub struct AccessConfig {
|
|||||||
#[serde(default = "default_access_users")]
|
#[serde(default = "default_access_users")]
|
||||||
pub users: HashMap<String, String>,
|
pub users: HashMap<String, String>,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub user_enabled: HashMap<String, bool>,
|
||||||
|
|
||||||
/// Per-user ad_tag (32 hex chars from @MTProxybot).
|
/// Per-user ad_tag (32 hex chars from @MTProxybot).
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub user_ad_tags: HashMap<String, String>,
|
pub user_ad_tags: HashMap<String, String>,
|
||||||
@@ -1963,6 +1966,7 @@ impl Default for AccessConfig {
|
|||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
users: default_access_users(),
|
users: default_access_users(),
|
||||||
|
user_enabled: HashMap::new(),
|
||||||
user_ad_tags: HashMap::new(),
|
user_ad_tags: HashMap::new(),
|
||||||
user_max_tcp_conns: HashMap::new(),
|
user_max_tcp_conns: HashMap::new(),
|
||||||
user_max_tcp_conns_global_each: default_user_max_tcp_conns_global_each(),
|
user_max_tcp_conns_global_each: default_user_max_tcp_conns_global_each(),
|
||||||
@@ -1983,6 +1987,10 @@ impl Default for AccessConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AccessConfig {
|
impl AccessConfig {
|
||||||
|
pub fn is_user_enabled(&self, username: &str) -> bool {
|
||||||
|
self.user_enabled.get(username).copied().unwrap_or(true)
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns true if `ip` is contained in any CIDR listed for `username` under `user_source_deny`.
|
/// Returns true if `ip` is contained in any CIDR listed for `username` under `user_source_deny`.
|
||||||
pub fn is_user_source_ip_denied(&self, username: &str, ip: IpAddr) -> bool {
|
pub fn is_user_source_ip_denied(&self, username: &str, ip: IpAddr) -> bool {
|
||||||
self.user_source_deny
|
self.user_source_deny
|
||||||
|
|||||||
@@ -245,6 +245,9 @@ pub enum ProxyError {
|
|||||||
InvalidSecret { user: String, reason: String },
|
InvalidSecret { user: String, reason: String },
|
||||||
|
|
||||||
// ============= User Errors =============
|
// ============= User Errors =============
|
||||||
|
#[error("User {user} disabled")]
|
||||||
|
UserDisabled { user: String },
|
||||||
|
|
||||||
#[error("User {user} expired")]
|
#[error("User {user} expired")]
|
||||||
UserExpired { user: String },
|
UserExpired { user: String },
|
||||||
|
|
||||||
|
|||||||
@@ -464,6 +464,12 @@ async fn run_telemt_core(
|
|||||||
config.network.dns_overrides.len()
|
config.network.dns_overrides.len()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
let shared_state = ProxySharedState::new();
|
||||||
|
shared_state.apply_user_enabled_config(&config.access.user_enabled);
|
||||||
|
shared_state.traffic_limiter.apply_policy(
|
||||||
|
config.access.user_rate_limits.clone(),
|
||||||
|
config.access.cidr_rate_limits.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
let (api_config_tx, api_config_rx) = watch::channel(Arc::new(config.clone()));
|
let (api_config_tx, api_config_rx) = watch::channel(Arc::new(config.clone()));
|
||||||
let (detected_ips_tx, detected_ips_rx) = watch::channel((None::<IpAddr>, None::<IpAddr>));
|
let (detected_ips_tx, detected_ips_rx) = watch::channel((None::<IpAddr>, None::<IpAddr>));
|
||||||
@@ -502,6 +508,7 @@ async fn run_telemt_core(
|
|||||||
let me_pool_api = api_me_pool.clone();
|
let me_pool_api = api_me_pool.clone();
|
||||||
let upstream_manager_api = upstream_manager.clone();
|
let upstream_manager_api = upstream_manager.clone();
|
||||||
let route_runtime_api = route_runtime.clone();
|
let route_runtime_api = route_runtime.clone();
|
||||||
|
let proxy_shared_api = shared_state.clone();
|
||||||
let config_rx_api = api_config_rx.clone();
|
let config_rx_api = api_config_rx.clone();
|
||||||
let admission_rx_api = admission_rx.clone();
|
let admission_rx_api = admission_rx.clone();
|
||||||
let config_path_api = config_path.clone();
|
let config_path_api = config_path.clone();
|
||||||
@@ -515,6 +522,7 @@ async fn run_telemt_core(
|
|||||||
ip_tracker_api,
|
ip_tracker_api,
|
||||||
me_pool_api,
|
me_pool_api,
|
||||||
route_runtime_api,
|
route_runtime_api,
|
||||||
|
proxy_shared_api,
|
||||||
upstream_manager_api,
|
upstream_manager_api,
|
||||||
config_rx_api,
|
config_rx_api,
|
||||||
admission_rx_api,
|
admission_rx_api,
|
||||||
@@ -732,11 +740,6 @@ async fn run_telemt_core(
|
|||||||
));
|
));
|
||||||
|
|
||||||
let buffer_pool = Arc::new(BufferPool::with_config(64 * 1024, 4096));
|
let buffer_pool = Arc::new(BufferPool::with_config(64 * 1024, 4096));
|
||||||
let shared_state = ProxySharedState::new();
|
|
||||||
shared_state.traffic_limiter.apply_policy(
|
|
||||||
config.access.user_rate_limits.clone(),
|
|
||||||
config.access.cidr_rate_limits.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
if direct_first_startup {
|
if direct_first_startup {
|
||||||
startup_tracker
|
startup_tracker
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use std::path::Path;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use tokio::sync::{mpsc, watch};
|
use tokio::sync::{mpsc, watch};
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, info, warn};
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
use tracing_subscriber::reload;
|
use tracing_subscriber::reload;
|
||||||
|
|
||||||
@@ -234,6 +234,27 @@ pub(crate) async fn spawn_runtime_tasks(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let shared_user_enabled = shared_state.clone();
|
||||||
|
let mut config_rx_user_enabled = config_rx.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
if config_rx_user_enabled.changed().await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let cfg = config_rx_user_enabled.borrow_and_update().clone();
|
||||||
|
for user in shared_user_enabled.apply_user_enabled_config(&cfg.access.user_enabled) {
|
||||||
|
let cancelled = shared_user_enabled.cancel_user_sessions(&user);
|
||||||
|
if cancelled > 0 {
|
||||||
|
info!(
|
||||||
|
user = %user,
|
||||||
|
cancelled,
|
||||||
|
"Disabled user sessions cancelled after config reload"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let beobachten_writer = beobachten.clone();
|
let beobachten_writer = beobachten.clone();
|
||||||
let config_rx_beobachten = config_rx.clone();
|
let config_rx_beobachten = config_rx.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
|||||||
@@ -55,8 +55,10 @@ pub async fn serve(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let is_ipv6 = addr.is_ipv6();
|
// Match `server.api.listen`: `[::]:port` is a dual-stack wildcard
|
||||||
match bind_metrics_listener(addr, is_ipv6, listen_backlog) {
|
// on Linux when `net.ipv6.bindv6only=0`.
|
||||||
|
let ipv6_only = addr.is_ipv6() && !addr.ip().is_unspecified();
|
||||||
|
match bind_metrics_listener(addr, ipv6_only, listen_backlog) {
|
||||||
Ok(listener) => {
|
Ok(listener) => {
|
||||||
info!("Metrics endpoint: http://{}/metrics and /beobachten", addr);
|
info!("Metrics endpoint: http://{}/metrics and /beobachten", addr);
|
||||||
serve_listener(
|
serve_listener(
|
||||||
|
|||||||
@@ -1558,6 +1558,11 @@ impl RunningClientHandler {
|
|||||||
{
|
{
|
||||||
let user = success.user.clone();
|
let user = success.user.clone();
|
||||||
|
|
||||||
|
if !shared.is_user_enabled(&user) {
|
||||||
|
warn!(user = %user, "Disabled user rejected");
|
||||||
|
return Err(ProxyError::UserDisabled { user });
|
||||||
|
}
|
||||||
|
|
||||||
let user_limit_reservation = match Self::acquire_user_connection_reservation_static(
|
let user_limit_reservation = match Self::acquire_user_connection_reservation_static(
|
||||||
&user,
|
&user,
|
||||||
&config,
|
&config,
|
||||||
@@ -1576,6 +1581,8 @@ impl RunningClientHandler {
|
|||||||
|
|
||||||
let route_snapshot = route_runtime.snapshot();
|
let route_snapshot = route_runtime.snapshot();
|
||||||
let session_id = rng.u64();
|
let session_id = rng.u64();
|
||||||
|
let _user_session = shared.register_user_session(&user, session_id);
|
||||||
|
let session_cancel = _user_session.token();
|
||||||
let selected_me_pool = if config.general.use_middle_proxy
|
let selected_me_pool = if config.general.use_middle_proxy
|
||||||
&& matches!(route_snapshot.mode, RelayRouteMode::Middle)
|
&& matches!(route_snapshot.mode, RelayRouteMode::Middle)
|
||||||
{
|
{
|
||||||
@@ -1607,6 +1614,7 @@ impl RunningClientHandler {
|
|||||||
route_runtime.subscribe(),
|
route_runtime.subscribe(),
|
||||||
route_snapshot,
|
route_snapshot,
|
||||||
session_id,
|
session_id,
|
||||||
|
session_cancel.clone(),
|
||||||
shared.clone(),
|
shared.clone(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -1625,6 +1633,7 @@ impl RunningClientHandler {
|
|||||||
route_snapshot,
|
route_snapshot,
|
||||||
session_id,
|
session_id,
|
||||||
local_addr,
|
local_addr,
|
||||||
|
session_cancel.clone(),
|
||||||
shared.clone(),
|
shared.clone(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -1644,6 +1653,7 @@ impl RunningClientHandler {
|
|||||||
route_snapshot,
|
route_snapshot,
|
||||||
session_id,
|
session_id,
|
||||||
local_addr,
|
local_addr,
|
||||||
|
session_cancel,
|
||||||
shared.clone(),
|
shared.clone(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use std::time::Duration;
|
|||||||
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split};
|
||||||
use tokio::sync::watch;
|
use tokio::sync::watch;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{debug, info, warn};
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
@@ -258,6 +259,7 @@ where
|
|||||||
route_snapshot,
|
route_snapshot,
|
||||||
session_id,
|
session_id,
|
||||||
SocketAddr::from(([0, 0, 0, 0], config.server.port)),
|
SocketAddr::from(([0, 0, 0, 0], config.server.port)),
|
||||||
|
CancellationToken::new(),
|
||||||
ProxySharedState::new(),
|
ProxySharedState::new(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -276,6 +278,7 @@ pub(crate) async fn handle_via_direct_with_shared<R, W>(
|
|||||||
route_snapshot: RouteCutoverState,
|
route_snapshot: RouteCutoverState,
|
||||||
session_id: u64,
|
session_id: u64,
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
|
session_cancel: CancellationToken,
|
||||||
shared: Arc<ProxySharedState>,
|
shared: Arc<ProxySharedState>,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
@@ -302,14 +305,25 @@ where
|
|||||||
"Ignoring invalid scope hint and falling back to default upstream selection"
|
"Ignoring invalid scope hint and falling back to default upstream selection"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let tg_stream = upstream_manager
|
let tg_stream = tokio::select! {
|
||||||
.connect(dc_addr, Some(success.dc_idx), scope_hint)
|
result = upstream_manager.connect(dc_addr, Some(success.dc_idx), scope_hint) => result?,
|
||||||
.await?;
|
_ = session_cancel.cancelled() => {
|
||||||
|
return Err(ProxyError::UserDisabled {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
|
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
|
||||||
|
|
||||||
let (tg_reader, tg_writer) =
|
let (tg_reader, tg_writer) = tokio::select! {
|
||||||
do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()).await?;
|
result = do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()) => result?,
|
||||||
|
_ = session_cancel.cancelled() => {
|
||||||
|
return Err(ProxyError::UserDisabled {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
||||||
|
|
||||||
@@ -331,20 +345,22 @@ where
|
|||||||
} else {
|
} else {
|
||||||
Duration::from_secs(1800)
|
Duration::from_secs(1800)
|
||||||
};
|
};
|
||||||
let relay_result = crate::proxy::relay::relay_bidirectional_with_activity_timeout_and_lease(
|
let relay_result =
|
||||||
client_reader,
|
crate::proxy::relay::relay_bidirectional_with_activity_timeout_lease_and_cancel(
|
||||||
client_writer,
|
client_reader,
|
||||||
tg_reader,
|
client_writer,
|
||||||
tg_writer,
|
tg_reader,
|
||||||
config.general.direct_relay_copy_buf_c2s_bytes,
|
tg_writer,
|
||||||
config.general.direct_relay_copy_buf_s2c_bytes,
|
config.general.direct_relay_copy_buf_c2s_bytes,
|
||||||
user,
|
config.general.direct_relay_copy_buf_s2c_bytes,
|
||||||
Arc::clone(&stats),
|
user,
|
||||||
config.access.user_data_quota.get(user).copied(),
|
Arc::clone(&stats),
|
||||||
buffer_pool,
|
config.access.user_data_quota.get(user).copied(),
|
||||||
traffic_lease,
|
buffer_pool,
|
||||||
relay_activity_timeout,
|
traffic_lease,
|
||||||
);
|
relay_activity_timeout,
|
||||||
|
session_cancel.clone(),
|
||||||
|
);
|
||||||
tokio::pin!(relay_result);
|
tokio::pin!(relay_result);
|
||||||
let relay_result = loop {
|
let relay_result = loop {
|
||||||
if let Some(cutover) =
|
if let Some(cutover) =
|
||||||
@@ -371,6 +387,11 @@ where
|
|||||||
break relay_result.await;
|
break relay_result.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ = session_cancel.cancelled() => {
|
||||||
|
break Err(ProxyError::UserDisabled {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ pub(crate) async fn handle_via_middle_proxy<R, W>(
|
|||||||
mut route_rx: watch::Receiver<RouteCutoverState>,
|
mut route_rx: watch::Receiver<RouteCutoverState>,
|
||||||
route_snapshot: RouteCutoverState,
|
route_snapshot: RouteCutoverState,
|
||||||
session_id: u64,
|
session_id: u64,
|
||||||
|
session_cancel: CancellationToken,
|
||||||
shared: Arc<ProxySharedState>,
|
shared: Arc<ProxySharedState>,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
@@ -20,6 +21,10 @@ where
|
|||||||
W: AsyncWrite + Unpin + Send + 'static,
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
let user = success.user.clone();
|
let user = success.user.clone();
|
||||||
|
if session_cancel.is_cancelled() {
|
||||||
|
return Err(ProxyError::UserDisabled { user });
|
||||||
|
}
|
||||||
|
|
||||||
let quota_limit = config.access.user_data_quota.get(&user).copied();
|
let quota_limit = config.access.user_data_quota.get(&user).copied();
|
||||||
let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
|
let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
|
||||||
let peer = success.peer;
|
let peer = success.peer;
|
||||||
@@ -590,6 +595,25 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
|
_ = session_cancel.cancelled() => {
|
||||||
|
warn!(
|
||||||
|
user = %user,
|
||||||
|
conn_id,
|
||||||
|
"Disabled user middle session cancelled"
|
||||||
|
);
|
||||||
|
let _ = enqueue_c2me_command_in(
|
||||||
|
shared.as_ref(),
|
||||||
|
&c2me_tx,
|
||||||
|
C2MeCommand::Close,
|
||||||
|
c2me_send_timeout,
|
||||||
|
stats.as_ref(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
main_result = Err(ProxyError::UserDisabled {
|
||||||
|
user: user.clone(),
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
changed = route_rx.changed(), if route_watch_open => {
|
changed = route_rx.changed(), if route_watch_open => {
|
||||||
if changed.is_err() {
|
if changed.is_err() {
|
||||||
route_watch_open = false;
|
route_watch_open = false;
|
||||||
|
|||||||
@@ -55,11 +55,13 @@ use crate::error::{ProxyError, Result};
|
|||||||
use crate::proxy::traffic_limiter::TrafficLease;
|
use crate::proxy::traffic_limiter::TrafficLease;
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::stream::BufferPool;
|
use crate::stream::BufferPool;
|
||||||
|
use std::future::pending;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, copy_bidirectional_with_sizes};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, copy_bidirectional_with_sizes};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
// ============= Constants =============
|
// ============= Constants =============
|
||||||
@@ -191,6 +193,84 @@ pub async fn relay_bidirectional_with_activity_timeout_and_lease<CR, CW, SR, SW>
|
|||||||
traffic_lease: Option<Arc<TrafficLease>>,
|
traffic_lease: Option<Arc<TrafficLease>>,
|
||||||
activity_timeout: Duration,
|
activity_timeout: Duration,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
CR: AsyncRead + Unpin + Send + 'static,
|
||||||
|
CW: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
SR: AsyncRead + Unpin + Send + 'static,
|
||||||
|
SW: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
relay_bidirectional_with_activity_timeout_lease_cancel_inner(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
c2s_buf_size,
|
||||||
|
s2c_buf_size,
|
||||||
|
user,
|
||||||
|
stats,
|
||||||
|
quota_limit,
|
||||||
|
_buffer_pool,
|
||||||
|
traffic_lease,
|
||||||
|
activity_timeout,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn relay_bidirectional_with_activity_timeout_lease_and_cancel<CR, CW, SR, SW>(
|
||||||
|
client_reader: CR,
|
||||||
|
client_writer: CW,
|
||||||
|
server_reader: SR,
|
||||||
|
server_writer: SW,
|
||||||
|
c2s_buf_size: usize,
|
||||||
|
s2c_buf_size: usize,
|
||||||
|
user: &str,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
quota_limit: Option<u64>,
|
||||||
|
_buffer_pool: Arc<BufferPool>,
|
||||||
|
traffic_lease: Option<Arc<TrafficLease>>,
|
||||||
|
activity_timeout: Duration,
|
||||||
|
session_cancel: CancellationToken,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
CR: AsyncRead + Unpin + Send + 'static,
|
||||||
|
CW: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
SR: AsyncRead + Unpin + Send + 'static,
|
||||||
|
SW: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
relay_bidirectional_with_activity_timeout_lease_cancel_inner(
|
||||||
|
client_reader,
|
||||||
|
client_writer,
|
||||||
|
server_reader,
|
||||||
|
server_writer,
|
||||||
|
c2s_buf_size,
|
||||||
|
s2c_buf_size,
|
||||||
|
user,
|
||||||
|
stats,
|
||||||
|
quota_limit,
|
||||||
|
_buffer_pool,
|
||||||
|
traffic_lease,
|
||||||
|
activity_timeout,
|
||||||
|
Some(session_cancel),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn relay_bidirectional_with_activity_timeout_lease_cancel_inner<CR, CW, SR, SW>(
|
||||||
|
client_reader: CR,
|
||||||
|
client_writer: CW,
|
||||||
|
server_reader: SR,
|
||||||
|
server_writer: SW,
|
||||||
|
c2s_buf_size: usize,
|
||||||
|
s2c_buf_size: usize,
|
||||||
|
user: &str,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
quota_limit: Option<u64>,
|
||||||
|
_buffer_pool: Arc<BufferPool>,
|
||||||
|
traffic_lease: Option<Arc<TrafficLease>>,
|
||||||
|
activity_timeout: Duration,
|
||||||
|
session_cancel: Option<CancellationToken>,
|
||||||
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
CR: AsyncRead + Unpin + Send + 'static,
|
CR: AsyncRead + Unpin + Send + 'static,
|
||||||
CW: AsyncWrite + Unpin + Send + 'static,
|
CW: AsyncWrite + Unpin + Send + 'static,
|
||||||
@@ -287,14 +367,29 @@ where
|
|||||||
//
|
//
|
||||||
// When the watchdog fires, select! drops the copy future,
|
// When the watchdog fires, select! drops the copy future,
|
||||||
// releasing the &mut borrows on client and server.
|
// releasing the &mut borrows on client and server.
|
||||||
let copy_result = tokio::select! {
|
enum RelayOutcome {
|
||||||
|
Copy(std::io::Result<(u64, u64)>),
|
||||||
|
ActivityTimeout,
|
||||||
|
UserDisabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
let cancel_wait = async move {
|
||||||
|
match session_cancel {
|
||||||
|
Some(token) => token.cancelled().await,
|
||||||
|
None => pending::<()>().await,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tokio::pin!(cancel_wait);
|
||||||
|
|
||||||
|
let relay_outcome = tokio::select! {
|
||||||
result = copy_bidirectional_with_sizes(
|
result = copy_bidirectional_with_sizes(
|
||||||
&mut client,
|
&mut client,
|
||||||
&mut server,
|
&mut server,
|
||||||
c2s_buf_size.max(1),
|
c2s_buf_size.max(1),
|
||||||
s2c_buf_size.max(1),
|
s2c_buf_size.max(1),
|
||||||
) => Some(result),
|
) => RelayOutcome::Copy(result),
|
||||||
_ = watchdog => None, // Activity timeout — cancel relay
|
_ = watchdog => RelayOutcome::ActivityTimeout,
|
||||||
|
_ = &mut cancel_wait => RelayOutcome::UserDisabled,
|
||||||
};
|
};
|
||||||
|
|
||||||
// ── Clean shutdown ──────────────────────────────────────────────
|
// ── Clean shutdown ──────────────────────────────────────────────
|
||||||
@@ -308,8 +403,8 @@ where
|
|||||||
let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed);
|
let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed);
|
||||||
let duration = epoch.elapsed();
|
let duration = epoch.elapsed();
|
||||||
|
|
||||||
match copy_result {
|
match relay_outcome {
|
||||||
Some(Ok((c2s, s2c))) => {
|
RelayOutcome::Copy(Ok((c2s, s2c))) => {
|
||||||
// Normal completion — one side closed the connection
|
// Normal completion — one side closed the connection
|
||||||
debug!(
|
debug!(
|
||||||
user = %user_owned,
|
user = %user_owned,
|
||||||
@@ -322,7 +417,7 @@ where
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
Some(Err(e)) if is_quota_io_error(&e) => {
|
RelayOutcome::Copy(Err(e)) if is_quota_io_error(&e) => {
|
||||||
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
warn!(
|
warn!(
|
||||||
@@ -338,7 +433,7 @@ where
|
|||||||
user: user_owned.clone(),
|
user: user_owned.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Some(Err(e)) => {
|
RelayOutcome::Copy(Err(e)) => {
|
||||||
// I/O error in one of the directions
|
// I/O error in one of the directions
|
||||||
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
@@ -354,7 +449,7 @@ where
|
|||||||
);
|
);
|
||||||
Err(e.into())
|
Err(e.into())
|
||||||
}
|
}
|
||||||
None => {
|
RelayOutcome::ActivityTimeout => {
|
||||||
// Activity timeout (watchdog fired)
|
// Activity timeout (watchdog fired)
|
||||||
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
@@ -369,6 +464,22 @@ where
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
RelayOutcome::UserDisabled => {
|
||||||
|
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
|
||||||
|
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
|
||||||
|
debug!(
|
||||||
|
user = %user_owned,
|
||||||
|
c2s_bytes = c2s,
|
||||||
|
s2c_bytes = s2c,
|
||||||
|
c2s_msgs = c2s_ops,
|
||||||
|
s2c_msgs = s2c_ops,
|
||||||
|
duration_secs = duration.as_secs(),
|
||||||
|
"Relay finished (user disabled)"
|
||||||
|
);
|
||||||
|
Err(ProxyError::UserDisabled {
|
||||||
|
user: user_owned.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use std::collections::HashSet;
|
|
||||||
use std::collections::hash_map::RandomState;
|
use std::collections::hash_map::RandomState;
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
@@ -7,6 +7,7 @@ use std::time::Instant;
|
|||||||
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
|
use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
|
||||||
use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateRegistry};
|
use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateRegistry};
|
||||||
@@ -67,10 +68,35 @@ pub(crate) struct ProxySharedState {
|
|||||||
pub(crate) handshake: HandshakeSharedState,
|
pub(crate) handshake: HandshakeSharedState,
|
||||||
pub(crate) middle_relay: MiddleRelaySharedState,
|
pub(crate) middle_relay: MiddleRelaySharedState,
|
||||||
pub(crate) traffic_limiter: Arc<TrafficLimiter>,
|
pub(crate) traffic_limiter: Arc<TrafficLimiter>,
|
||||||
|
disabled_users: DashMap<String, ()>,
|
||||||
|
active_user_sessions: DashMap<(String, u64), CancellationToken>,
|
||||||
pub(crate) conntrack_pressure_active: AtomicBool,
|
pub(crate) conntrack_pressure_active: AtomicBool,
|
||||||
pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>,
|
pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use = "registered user sessions must be kept alive until relay completion"]
|
||||||
|
pub(crate) struct UserSessionRegistration {
|
||||||
|
token: CancellationToken,
|
||||||
|
_guard: UserSessionGuard,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserSessionRegistration {
|
||||||
|
pub(crate) fn token(&self) -> CancellationToken {
|
||||||
|
self.token.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct UserSessionGuard {
|
||||||
|
shared: Arc<ProxySharedState>,
|
||||||
|
key: (String, u64),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for UserSessionGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.shared.active_user_sessions.remove(&self.key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ProxySharedState {
|
impl ProxySharedState {
|
||||||
pub(crate) fn new() -> Arc<Self> {
|
pub(crate) fn new() -> Arc<Self> {
|
||||||
Arc::new(Self {
|
Arc::new(Self {
|
||||||
@@ -101,11 +127,82 @@ impl ProxySharedState {
|
|||||||
relay_idle_mark_seq: AtomicU64::new(0),
|
relay_idle_mark_seq: AtomicU64::new(0),
|
||||||
},
|
},
|
||||||
traffic_limiter: TrafficLimiter::new(),
|
traffic_limiter: TrafficLimiter::new(),
|
||||||
|
disabled_users: DashMap::new(),
|
||||||
|
active_user_sessions: DashMap::new(),
|
||||||
conntrack_pressure_active: AtomicBool::new(false),
|
conntrack_pressure_active: AtomicBool::new(false),
|
||||||
conntrack_close_tx: Mutex::new(None),
|
conntrack_close_tx: Mutex::new(None),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn is_user_enabled(&self, user: &str) -> bool {
|
||||||
|
!self.disabled_users.contains_key(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn set_user_enabled(&self, user: &str, enabled: bool) -> bool {
|
||||||
|
if enabled {
|
||||||
|
self.disabled_users.remove(user);
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
self.disabled_users.insert(user.to_string(), ()).is_none()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn apply_user_enabled_config(
|
||||||
|
&self,
|
||||||
|
user_enabled: &HashMap<String, bool>,
|
||||||
|
) -> Vec<String> {
|
||||||
|
let desired_disabled = user_enabled
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(user, enabled)| (!*enabled).then_some(user.clone()))
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
let current_disabled = self
|
||||||
|
.disabled_users
|
||||||
|
.iter()
|
||||||
|
.map(|entry| entry.key().clone())
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
|
for user in current_disabled.difference(&desired_disabled) {
|
||||||
|
self.disabled_users.remove(user);
|
||||||
|
}
|
||||||
|
let newly_disabled = desired_disabled
|
||||||
|
.difference(¤t_disabled)
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
for user in desired_disabled {
|
||||||
|
self.disabled_users.insert(user, ());
|
||||||
|
}
|
||||||
|
newly_disabled
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn register_user_session(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
user: &str,
|
||||||
|
session_id: u64,
|
||||||
|
) -> UserSessionRegistration {
|
||||||
|
let token = CancellationToken::new();
|
||||||
|
let key = (user.to_string(), session_id);
|
||||||
|
self.active_user_sessions.insert(key.clone(), token.clone());
|
||||||
|
UserSessionRegistration {
|
||||||
|
token,
|
||||||
|
_guard: UserSessionGuard {
|
||||||
|
shared: Arc::clone(self),
|
||||||
|
key,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn cancel_user_sessions(&self, user: &str) -> usize {
|
||||||
|
let tokens = self
|
||||||
|
.active_user_sessions
|
||||||
|
.iter()
|
||||||
|
.filter_map(|entry| (entry.key().0 == user).then(|| entry.value().clone()))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
for token in &tokens {
|
||||||
|
token.cancel();
|
||||||
|
}
|
||||||
|
tokens.len()
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn set_conntrack_close_sender(&self, tx: mpsc::Sender<ConntrackCloseEvent>) {
|
pub(crate) fn set_conntrack_close_sender(&self, tx: mpsc::Sender<ConntrackCloseEvent>) {
|
||||||
match self.conntrack_close_tx.lock() {
|
match self.conntrack_close_tx.lock() {
|
||||||
Ok(mut guard) => {
|
Ok(mut guard) => {
|
||||||
@@ -166,3 +263,48 @@ impl ProxySharedState {
|
|||||||
self.conntrack_pressure_active.load(Ordering::Relaxed)
|
self.conntrack_pressure_active.load(Ordering::Relaxed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn user_enabled_config_sync_tracks_disabled_overrides() {
|
||||||
|
let shared = ProxySharedState::new();
|
||||||
|
assert!(shared.is_user_enabled("alice"));
|
||||||
|
|
||||||
|
let mut user_enabled = HashMap::new();
|
||||||
|
user_enabled.insert("alice".to_string(), false);
|
||||||
|
user_enabled.insert("bob".to_string(), true);
|
||||||
|
|
||||||
|
let mut newly_disabled = shared.apply_user_enabled_config(&user_enabled);
|
||||||
|
newly_disabled.sort();
|
||||||
|
assert_eq!(newly_disabled, vec!["alice".to_string()]);
|
||||||
|
assert!(!shared.is_user_enabled("alice"));
|
||||||
|
assert!(shared.is_user_enabled("bob"));
|
||||||
|
|
||||||
|
assert!(shared.apply_user_enabled_config(&user_enabled).is_empty());
|
||||||
|
|
||||||
|
user_enabled.clear();
|
||||||
|
assert!(shared.apply_user_enabled_config(&user_enabled).is_empty());
|
||||||
|
assert!(shared.is_user_enabled("alice"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cancel_user_sessions_cancels_only_registered_matching_user() {
|
||||||
|
let shared = ProxySharedState::new();
|
||||||
|
let alice_1 = shared.register_user_session("alice", 1);
|
||||||
|
let alice_2 = shared.register_user_session("alice", 2);
|
||||||
|
let bob = shared.register_user_session("bob", 1);
|
||||||
|
let alice_1_token = alice_1.token();
|
||||||
|
let alice_2_token = alice_2.token();
|
||||||
|
let bob_token = bob.token();
|
||||||
|
|
||||||
|
drop(alice_1);
|
||||||
|
|
||||||
|
assert_eq!(shared.cancel_user_sessions("alice"), 1);
|
||||||
|
assert!(!alice_1_token.is_cancelled());
|
||||||
|
assert!(alice_2_token.is_cancelled());
|
||||||
|
assert!(!bob_token.is_cancelled());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user