diff --git a/src/api/config_store.rs b/src/api/config_store.rs new file mode 100644 index 0000000..e7fbbca --- /dev/null +++ b/src/api/config_store.rs @@ -0,0 +1,107 @@ +use std::io::Write; +use std::path::{Path, PathBuf}; + +use hyper::header::IF_MATCH; +use sha2::{Digest, Sha256}; + +use crate::config::ProxyConfig; + +use super::model::ApiFailure; + +pub(super) fn parse_if_match(headers: &hyper::HeaderMap) -> Option { + headers + .get(IF_MATCH) + .and_then(|value| value.to_str().ok()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(|value| value.trim_matches('"').to_string()) +} + +pub(super) async fn ensure_expected_revision( + config_path: &Path, + expected_revision: Option<&str>, +) -> Result<(), ApiFailure> { + let Some(expected) = expected_revision else { + return Ok(()); + }; + let current = current_revision(config_path).await?; + if current != expected { + return Err(ApiFailure::new( + hyper::StatusCode::CONFLICT, + "revision_conflict", + "Config revision mismatch", + )); + } + Ok(()) +} + +pub(super) async fn current_revision(config_path: &Path) -> Result { + let content = tokio::fs::read_to_string(config_path) + .await + .map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?; + Ok(compute_revision(&content)) +} + +pub(super) fn compute_revision(content: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(content.as_bytes()); + hex::encode(hasher.finalize()) +} + +pub(super) async fn load_config_from_disk(config_path: &Path) -> Result { + let config_path = config_path.to_path_buf(); + tokio::task::spawn_blocking(move || ProxyConfig::load(config_path)) + .await + .map_err(|e| ApiFailure::internal(format!("failed to join config loader: {}", e)))? + .map_err(|e| ApiFailure::internal(format!("failed to load config: {}", e))) +} + +pub(super) async fn save_config_to_disk( + config_path: &Path, + cfg: &ProxyConfig, +) -> Result { + let serialized = toml::to_string_pretty(cfg) + .map_err(|e| ApiFailure::internal(format!("failed to serialize config: {}", e)))?; + write_atomic(config_path.to_path_buf(), serialized.clone()).await?; + Ok(compute_revision(&serialized)) +} + +async fn write_atomic(path: PathBuf, contents: String) -> Result<(), ApiFailure> { + tokio::task::spawn_blocking(move || write_atomic_sync(&path, &contents)) + .await + .map_err(|e| ApiFailure::internal(format!("failed to join writer: {}", e)))? + .map_err(|e| ApiFailure::internal(format!("failed to write config: {}", e))) +} + +fn write_atomic_sync(path: &Path, contents: &str) -> std::io::Result<()> { + let parent = path.parent().unwrap_or_else(|| Path::new(".")); + std::fs::create_dir_all(parent)?; + + let tmp_name = format!( + ".{}.tmp-{}", + path.file_name() + .and_then(|s| s.to_str()) + .unwrap_or("config.toml"), + rand::random::() + ); + let tmp_path = parent.join(tmp_name); + + let write_result = (|| { + let mut file = std::fs::OpenOptions::new() + .create_new(true) + .write(true) + .open(&tmp_path)?; + file.write_all(contents.as_bytes())?; + file.sync_all()?; + std::fs::rename(&tmp_path, path)?; + if let Ok(dir) = std::fs::File::open(parent) { + let _ = dir.sync_all(); + } + Ok(()) + })(); + + if write_result.is_err() { + let _ = std::fs::remove_file(&tmp_path); + } + write_result +} diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..c13828e --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,383 @@ +use std::convert::Infallible; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use http_body_util::{BodyExt, Full}; +use hyper::body::{Bytes, Incoming}; +use hyper::header::AUTHORIZATION; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::{Method, Request, Response, StatusCode}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use tokio::net::TcpListener; +use tokio::sync::{Mutex, watch}; +use tracing::{debug, info, warn}; + +use crate::config::ProxyConfig; +use crate::ip_tracker::UserIpTracker; +use crate::stats::Stats; + +mod config_store; +mod model; +mod users; + +use config_store::{current_revision, parse_if_match}; +use model::{ + ApiFailure, CreateUserRequest, ErrorBody, ErrorResponse, HealthData, PatchUserRequest, + RotateSecretRequest, SuccessResponse, SummaryData, +}; +use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config}; + +#[derive(Clone)] +pub(super) struct ApiShared { + pub(super) stats: Arc, + pub(super) ip_tracker: Arc, + pub(super) config_path: PathBuf, + pub(super) mutation_lock: Arc>, + pub(super) request_id: Arc, +} + +impl ApiShared { + fn next_request_id(&self) -> u64 { + self.request_id.fetch_add(1, Ordering::Relaxed) + } +} + +pub async fn serve( + listen: SocketAddr, + stats: Arc, + ip_tracker: Arc, + config_rx: watch::Receiver>, + config_path: PathBuf, +) { + let listener = match TcpListener::bind(listen).await { + Ok(listener) => listener, + Err(error) => { + warn!( + error = %error, + listen = %listen, + "Failed to bind API listener" + ); + return; + } + }; + + info!("API endpoint: http://{}/v1/*", listen); + + let shared = Arc::new(ApiShared { + stats, + ip_tracker, + config_path, + mutation_lock: Arc::new(Mutex::new(())), + request_id: Arc::new(AtomicU64::new(1)), + }); + + loop { + let (stream, peer) = match listener.accept().await { + Ok(v) => v, + Err(error) => { + warn!(error = %error, "API accept error"); + continue; + } + }; + + let shared_conn = shared.clone(); + let config_rx_conn = config_rx.clone(); + tokio::spawn(async move { + let svc = service_fn(move |req: Request| { + let shared_req = shared_conn.clone(); + let config_rx_req = config_rx_conn.clone(); + async move { handle(req, peer, shared_req, config_rx_req).await } + }); + if let Err(error) = http1::Builder::new() + .serve_connection(hyper_util::rt::TokioIo::new(stream), svc) + .await + { + debug!(error = %error, "API connection error"); + } + }); + } +} + +async fn handle( + req: Request, + peer: SocketAddr, + shared: Arc, + config_rx: watch::Receiver>, +) -> Result>, Infallible> { + let request_id = shared.next_request_id(); + let cfg = config_rx.borrow().clone(); + let api_cfg = &cfg.server.api; + + if !api_cfg.enabled { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::SERVICE_UNAVAILABLE, + "api_disabled", + "API is disabled", + ), + )); + } + + if !api_cfg.whitelist.is_empty() + && !api_cfg + .whitelist + .iter() + .any(|net| net.contains(peer.ip())) + { + return Ok(error_response( + request_id, + ApiFailure::new(StatusCode::FORBIDDEN, "forbidden", "Source IP is not allowed"), + )); + } + + if !api_cfg.auth_header.is_empty() { + let auth_ok = req + .headers() + .get(AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .map(|v| v == api_cfg.auth_header) + .unwrap_or(false); + if !auth_ok { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::UNAUTHORIZED, + "unauthorized", + "Missing or invalid Authorization header", + ), + )); + } + } + + let method = req.method().clone(); + let path = req.uri().path().to_string(); + let body_limit = api_cfg.request_body_limit_bytes; + + let result: Result>, ApiFailure> = async { + match (method.as_str(), path.as_str()) { + ("GET", "/v1/health") => { + let revision = current_revision(&shared.config_path).await?; + let data = HealthData { + status: "ok", + read_only: api_cfg.read_only, + }; + Ok(success_response(StatusCode::OK, data, revision)) + } + ("GET", "/v1/stats/summary") => { + let revision = current_revision(&shared.config_path).await?; + let data = SummaryData { + uptime_seconds: shared.stats.uptime_secs(), + connections_total: shared.stats.get_connects_all(), + connections_bad_total: shared.stats.get_connects_bad(), + handshake_timeouts_total: shared.stats.get_handshake_timeouts(), + configured_users: cfg.access.users.len(), + }; + Ok(success_response(StatusCode::OK, data, revision)) + } + ("GET", "/v1/stats/users") | ("GET", "/v1/users") => { + let revision = current_revision(&shared.config_path).await?; + let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + Ok(success_response(StatusCode::OK, users, revision)) + } + ("POST", "/v1/users") => { + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let body = read_json::(req.into_body(), body_limit).await?; + let (data, revision) = create_user(body, expected_revision, &shared).await?; + Ok(success_response(StatusCode::CREATED, data, revision)) + } + _ => { + if let Some(user) = path.strip_prefix("/v1/users/") + && !user.is_empty() + && !user.contains('/') + { + if method == Method::GET { + let revision = current_revision(&shared.config_path).await?; + let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + if let Some(user_info) = users.into_iter().find(|entry| entry.username == user) + { + return Ok(success_response(StatusCode::OK, user_info, revision)); + } + return Ok(error_response( + request_id, + ApiFailure::new(StatusCode::NOT_FOUND, "not_found", "User not found"), + )); + } + if method == Method::PATCH { + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let body = read_json::(req.into_body(), body_limit).await?; + let (data, revision) = + patch_user(user, body, expected_revision, &shared).await?; + return Ok(success_response(StatusCode::OK, data, revision)); + } + if method == Method::DELETE { + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let (deleted_user, revision) = + delete_user(user, expected_revision, &shared).await?; + return Ok(success_response(StatusCode::OK, deleted_user, revision)); + } + if method == Method::POST + && let Some(base_user) = user.strip_suffix("/rotate-secret") + && !base_user.is_empty() + && !base_user.contains('/') + { + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let body = + read_optional_json::(req.into_body(), body_limit) + .await?; + let (data, revision) = + rotate_secret(base_user, body.unwrap_or_default(), expected_revision, &shared) + .await?; + return Ok(success_response(StatusCode::OK, data, revision)); + } + if method == Method::POST { + return Ok(error_response( + request_id, + ApiFailure::new(StatusCode::NOT_FOUND, "not_found", "Route not found"), + )); + } + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::METHOD_NOT_ALLOWED, + "method_not_allowed", + "Unsupported HTTP method for this route", + ), + )); + } + Ok(error_response( + request_id, + ApiFailure::new(StatusCode::NOT_FOUND, "not_found", "Route not found"), + )) + } + } + } + .await; + + match result { + Ok(resp) => Ok(resp), + Err(error) => Ok(error_response(request_id, error)), + } +} + +fn success_response( + status: StatusCode, + data: T, + revision: String, +) -> Response> { + let payload = SuccessResponse { + ok: true, + data, + revision, + }; + let body = serde_json::to_vec(&payload).unwrap_or_else(|_| b"{\"ok\":false}".to_vec()); + Response::builder() + .status(status) + .header("content-type", "application/json; charset=utf-8") + .body(Full::new(Bytes::from(body))) + .unwrap() +} + +fn error_response(request_id: u64, failure: ApiFailure) -> Response> { + let payload = ErrorResponse { + ok: false, + error: ErrorBody { + code: failure.code, + message: failure.message, + }, + request_id, + }; + let body = serde_json::to_vec(&payload).unwrap_or_else(|_| { + format!( + "{{\"ok\":false,\"error\":{{\"code\":\"internal_error\",\"message\":\"serialization failed\"}},\"request_id\":{}}}", + request_id + ) + .into_bytes() + }); + Response::builder() + .status(failure.status) + .header("content-type", "application/json; charset=utf-8") + .body(Full::new(Bytes::from(body))) + .unwrap() +} + +async fn read_json(body: Incoming, limit: usize) -> Result { + let bytes = read_body_with_limit(body, limit).await?; + serde_json::from_slice(&bytes).map_err(|_| ApiFailure::bad_request("Invalid JSON body")) +} + +async fn read_optional_json( + body: Incoming, + limit: usize, +) -> Result, ApiFailure> { + let bytes = read_body_with_limit(body, limit).await?; + if bytes.is_empty() { + return Ok(None); + } + serde_json::from_slice(&bytes) + .map(Some) + .map_err(|_| ApiFailure::bad_request("Invalid JSON body")) +} + +async fn read_body_with_limit(body: Incoming, limit: usize) -> Result, ApiFailure> { + let mut collected = Vec::new(); + let mut body = body; + while let Some(frame_result) = body.frame().await { + let frame = frame_result.map_err(|_| ApiFailure::bad_request("Invalid request body"))?; + if let Some(chunk) = frame.data_ref() { + if collected.len().saturating_add(chunk.len()) > limit { + return Err(ApiFailure::new( + StatusCode::PAYLOAD_TOO_LARGE, + "payload_too_large", + format!("Body exceeds {} bytes", limit), + )); + } + collected.extend_from_slice(chunk); + } + } + Ok(collected) +} diff --git a/src/api/model.rs b/src/api/model.rs new file mode 100644 index 0000000..bea2301 --- /dev/null +++ b/src/api/model.rs @@ -0,0 +1,144 @@ +use chrono::{DateTime, Utc}; +use hyper::StatusCode; +use rand::Rng; +use serde::{Deserialize, Serialize}; + +const MAX_USERNAME_LEN: usize = 64; + +#[derive(Debug)] +pub(super) struct ApiFailure { + pub(super) status: StatusCode, + pub(super) code: &'static str, + pub(super) message: String, +} + +impl ApiFailure { + pub(super) fn new(status: StatusCode, code: &'static str, message: impl Into) -> Self { + Self { + status, + code, + message: message.into(), + } + } + + pub(super) fn internal(message: impl Into) -> Self { + Self::new(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", message) + } + + pub(super) fn bad_request(message: impl Into) -> Self { + Self::new(StatusCode::BAD_REQUEST, "bad_request", message) + } +} + +#[derive(Serialize)] +pub(super) struct ErrorBody { + pub(super) code: &'static str, + pub(super) message: String, +} + +#[derive(Serialize)] +pub(super) struct ErrorResponse { + pub(super) ok: bool, + pub(super) error: ErrorBody, + pub(super) request_id: u64, +} + +#[derive(Serialize)] +pub(super) struct SuccessResponse { + pub(super) ok: bool, + pub(super) data: T, + pub(super) revision: String, +} + +#[derive(Serialize)] +pub(super) struct HealthData { + pub(super) status: &'static str, + pub(super) read_only: bool, +} + +#[derive(Serialize)] +pub(super) struct SummaryData { + pub(super) uptime_seconds: f64, + pub(super) connections_total: u64, + pub(super) connections_bad_total: u64, + pub(super) handshake_timeouts_total: u64, + pub(super) configured_users: usize, +} + +#[derive(Serialize)] +pub(super) struct UserInfo { + pub(super) username: String, + pub(super) user_ad_tag: Option, + pub(super) max_tcp_conns: Option, + pub(super) expiration_rfc3339: Option, + pub(super) data_quota_bytes: Option, + pub(super) max_unique_ips: Option, + pub(super) current_connections: u64, + pub(super) active_unique_ips: usize, + pub(super) total_octets: u64, +} + +#[derive(Serialize)] +pub(super) struct CreateUserResponse { + pub(super) user: UserInfo, + pub(super) secret: String, +} + +#[derive(Deserialize)] +pub(super) struct CreateUserRequest { + pub(super) username: String, + pub(super) secret: Option, + pub(super) user_ad_tag: Option, + pub(super) max_tcp_conns: Option, + pub(super) expiration_rfc3339: Option, + pub(super) data_quota_bytes: Option, + pub(super) max_unique_ips: Option, +} + +#[derive(Deserialize)] +pub(super) struct PatchUserRequest { + pub(super) secret: Option, + pub(super) user_ad_tag: Option, + pub(super) max_tcp_conns: Option, + pub(super) expiration_rfc3339: Option, + pub(super) data_quota_bytes: Option, + pub(super) max_unique_ips: Option, +} + +#[derive(Default, Deserialize)] +pub(super) struct RotateSecretRequest { + pub(super) secret: Option, +} + +pub(super) fn parse_optional_expiration( + value: Option<&str>, +) -> Result>, ApiFailure> { + let Some(raw) = value else { + return Ok(None); + }; + let parsed = DateTime::parse_from_rfc3339(raw) + .map_err(|_| ApiFailure::bad_request("expiration_rfc3339 must be valid RFC3339"))?; + Ok(Some(parsed.with_timezone(&Utc))) +} + +pub(super) fn is_valid_user_secret(secret: &str) -> bool { + secret.len() == 32 && secret.chars().all(|c| c.is_ascii_hexdigit()) +} + +pub(super) fn is_valid_ad_tag(tag: &str) -> bool { + tag.len() == 32 && tag.chars().all(|c| c.is_ascii_hexdigit()) +} + +pub(super) fn is_valid_username(user: &str) -> bool { + !user.is_empty() + && user.len() <= MAX_USERNAME_LEN + && user + .chars() + .all(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | '.')) +} + +pub(super) fn random_user_secret() -> String { + let mut bytes = [0u8; 16]; + rand::rng().fill(&mut bytes); + hex::encode(bytes) +} diff --git a/src/api/users.rs b/src/api/users.rs new file mode 100644 index 0000000..75d659f --- /dev/null +++ b/src/api/users.rs @@ -0,0 +1,301 @@ +use std::collections::HashMap; + +use hyper::StatusCode; + +use crate::config::ProxyConfig; +use crate::ip_tracker::UserIpTracker; +use crate::stats::Stats; + +use super::ApiShared; +use super::config_store::{ + ensure_expected_revision, load_config_from_disk, save_config_to_disk, +}; +use super::model::{ + ApiFailure, CreateUserRequest, CreateUserResponse, PatchUserRequest, RotateSecretRequest, + UserInfo, is_valid_ad_tag, is_valid_user_secret, is_valid_username, parse_optional_expiration, + random_user_secret, +}; + +pub(super) async fn create_user( + body: CreateUserRequest, + expected_revision: Option, + shared: &ApiShared, +) -> Result<(CreateUserResponse, String), ApiFailure> { + if !is_valid_username(&body.username) { + return Err(ApiFailure::bad_request( + "username must match [A-Za-z0-9_.-] and be 1..64 chars", + )); + } + + let secret = match body.secret { + Some(secret) => { + if !is_valid_user_secret(&secret) { + return Err(ApiFailure::bad_request( + "secret must be exactly 32 hex characters", + )); + } + secret + } + None => random_user_secret(), + }; + + if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) { + return Err(ApiFailure::bad_request( + "user_ad_tag must be exactly 32 hex characters", + )); + } + + let expiration = parse_optional_expiration(body.expiration_rfc3339.as_deref())?; + let _guard = shared.mutation_lock.lock().await; + let mut cfg = load_config_from_disk(&shared.config_path).await?; + ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; + + if cfg.access.users.contains_key(&body.username) { + return Err(ApiFailure::new( + StatusCode::CONFLICT, + "user_exists", + "User already exists", + )); + } + + cfg.access.users.insert(body.username.clone(), secret.clone()); + if let Some(ad_tag) = body.user_ad_tag { + cfg.access.user_ad_tags.insert(body.username.clone(), ad_tag); + } + if let Some(limit) = body.max_tcp_conns { + cfg.access.user_max_tcp_conns.insert(body.username.clone(), limit); + } + if let Some(expiration) = expiration { + cfg.access + .user_expirations + .insert(body.username.clone(), expiration); + } + if let Some(quota) = body.data_quota_bytes { + cfg.access.user_data_quota.insert(body.username.clone(), quota); + } + + let updated_limit = body.max_unique_ips; + if let Some(limit) = updated_limit { + cfg.access + .user_max_unique_ips + .insert(body.username.clone(), limit); + } + + cfg.validate() + .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; + + let revision = save_config_to_disk(&shared.config_path, &cfg).await?; + drop(_guard); + + if let Some(limit) = updated_limit { + shared.ip_tracker.set_user_limit(&body.username, limit).await; + } + + let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let user = users + .into_iter() + .find(|entry| entry.username == body.username) + .unwrap_or(UserInfo { + username: body.username.clone(), + user_ad_tag: None, + max_tcp_conns: None, + expiration_rfc3339: None, + data_quota_bytes: None, + max_unique_ips: updated_limit, + current_connections: 0, + active_unique_ips: 0, + total_octets: 0, + }); + + Ok((CreateUserResponse { user, secret }, revision)) +} + +pub(super) async fn patch_user( + user: &str, + body: PatchUserRequest, + expected_revision: Option, + shared: &ApiShared, +) -> Result<(UserInfo, String), ApiFailure> { + if let Some(secret) = body.secret.as_ref() && !is_valid_user_secret(secret) { + return Err(ApiFailure::bad_request( + "secret must be exactly 32 hex characters", + )); + } + if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) { + return Err(ApiFailure::bad_request( + "user_ad_tag must be exactly 32 hex characters", + )); + } + let expiration = parse_optional_expiration(body.expiration_rfc3339.as_deref())?; + let _guard = shared.mutation_lock.lock().await; + let mut cfg = load_config_from_disk(&shared.config_path).await?; + ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; + + if !cfg.access.users.contains_key(user) { + return Err(ApiFailure::new( + StatusCode::NOT_FOUND, + "not_found", + "User not found", + )); + } + + if let Some(secret) = body.secret { + cfg.access.users.insert(user.to_string(), secret); + } + if let Some(ad_tag) = body.user_ad_tag { + cfg.access.user_ad_tags.insert(user.to_string(), ad_tag); + } + if let Some(limit) = body.max_tcp_conns { + cfg.access.user_max_tcp_conns.insert(user.to_string(), limit); + } + if let Some(expiration) = expiration { + cfg.access.user_expirations.insert(user.to_string(), expiration); + } + if let Some(quota) = body.data_quota_bytes { + cfg.access.user_data_quota.insert(user.to_string(), quota); + } + + let mut updated_limit = None; + if let Some(limit) = body.max_unique_ips { + cfg.access.user_max_unique_ips.insert(user.to_string(), limit); + updated_limit = Some(limit); + } + + cfg.validate() + .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; + + let revision = save_config_to_disk(&shared.config_path, &cfg).await?; + drop(_guard); + if let Some(limit) = updated_limit { + shared.ip_tracker.set_user_limit(user, limit).await; + } + let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let user_info = users + .into_iter() + .find(|entry| entry.username == user) + .ok_or_else(|| ApiFailure::internal("failed to build updated user view"))?; + + Ok((user_info, revision)) +} + +pub(super) async fn rotate_secret( + user: &str, + body: RotateSecretRequest, + expected_revision: Option, + shared: &ApiShared, +) -> Result<(CreateUserResponse, String), ApiFailure> { + let secret = body.secret.unwrap_or_else(random_user_secret); + if !is_valid_user_secret(&secret) { + return Err(ApiFailure::bad_request( + "secret must be exactly 32 hex characters", + )); + } + + let _guard = shared.mutation_lock.lock().await; + let mut cfg = load_config_from_disk(&shared.config_path).await?; + ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; + + if !cfg.access.users.contains_key(user) { + return Err(ApiFailure::new( + StatusCode::NOT_FOUND, + "not_found", + "User not found", + )); + } + + cfg.access.users.insert(user.to_string(), secret.clone()); + cfg.validate() + .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; + let revision = save_config_to_disk(&shared.config_path, &cfg).await?; + drop(_guard); + + let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let user_info = users + .into_iter() + .find(|entry| entry.username == user) + .ok_or_else(|| ApiFailure::internal("failed to build updated user view"))?; + + Ok(( + CreateUserResponse { + user: user_info, + secret, + }, + revision, + )) +} + +pub(super) async fn delete_user( + user: &str, + expected_revision: Option, + shared: &ApiShared, +) -> Result<(String, String), ApiFailure> { + let _guard = shared.mutation_lock.lock().await; + let mut cfg = load_config_from_disk(&shared.config_path).await?; + ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; + + if !cfg.access.users.contains_key(user) { + return Err(ApiFailure::new( + StatusCode::NOT_FOUND, + "not_found", + "User not found", + )); + } + if cfg.access.users.len() <= 1 { + return Err(ApiFailure::new( + StatusCode::CONFLICT, + "last_user_forbidden", + "Cannot delete the last configured user", + )); + } + + cfg.access.users.remove(user); + cfg.access.user_ad_tags.remove(user); + cfg.access.user_max_tcp_conns.remove(user); + cfg.access.user_expirations.remove(user); + cfg.access.user_data_quota.remove(user); + cfg.access.user_max_unique_ips.remove(user); + + cfg.validate() + .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; + let revision = save_config_to_disk(&shared.config_path, &cfg).await?; + drop(_guard); + shared.ip_tracker.clear_user_ips(user).await; + + Ok((user.to_string(), revision)) +} + +pub(super) async fn users_from_config( + cfg: &ProxyConfig, + stats: &Stats, + ip_tracker: &UserIpTracker, +) -> Vec { + let ip_counts = ip_tracker + .get_stats() + .await + .into_iter() + .map(|(user, count, _)| (user, count)) + .collect::>(); + + let mut names = cfg.access.users.keys().cloned().collect::>(); + names.sort(); + + let mut users = Vec::with_capacity(names.len()); + for username in names { + users.push(UserInfo { + user_ad_tag: cfg.access.user_ad_tags.get(&username).cloned(), + max_tcp_conns: cfg.access.user_max_tcp_conns.get(&username).copied(), + expiration_rfc3339: cfg + .access + .user_expirations + .get(&username) + .map(chrono::DateTime::::to_rfc3339), + data_quota_bytes: cfg.access.user_data_quota.get(&username).copied(), + max_unique_ips: cfg.access.user_max_unique_ips.get(&username).copied(), + current_connections: stats.get_user_curr_connects(&username), + active_unique_ips: ip_counts.get(&username).copied().unwrap_or(0), + total_octets: stats.get_user_total_octets(&username), + username, + }); + } + users +}