telemt/src/api/config_store.rs

270 lines
8.4 KiB
Rust

use std::collections::BTreeMap;
use std::io::Write;
use std::path::{Path, PathBuf};
use chrono::{DateTime, Utc};
use hyper::header::IF_MATCH;
use serde::Serialize;
use sha2::{Digest, Sha256};
use crate::config::ProxyConfig;
use super::model::ApiFailure;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) enum AccessSection {
Users,
UserAdTags,
UserMaxTcpConns,
UserExpirations,
UserDataQuota,
UserMaxUniqueIps,
}
impl AccessSection {
fn table_name(self) -> &'static str {
match self {
Self::Users => "access.users",
Self::UserAdTags => "access.user_ad_tags",
Self::UserMaxTcpConns => "access.user_max_tcp_conns",
Self::UserExpirations => "access.user_expirations",
Self::UserDataQuota => "access.user_data_quota",
Self::UserMaxUniqueIps => "access.user_max_unique_ips",
}
}
}
pub(super) fn parse_if_match(headers: &hyper::HeaderMap) -> Option<String> {
headers
.get(IF_MATCH)
.and_then(|value| value.to_str().ok())
.map(str::trim)
.filter(|value| !value.is_empty())
.map(|value| value.trim_matches('"').to_string())
}
pub(super) async fn ensure_expected_revision(
config_path: &Path,
expected_revision: Option<&str>,
) -> Result<(), ApiFailure> {
let Some(expected) = expected_revision else {
return Ok(());
};
let current = current_revision(config_path).await?;
if current != expected {
return Err(ApiFailure::new(
hyper::StatusCode::CONFLICT,
"revision_conflict",
"Config revision mismatch",
));
}
Ok(())
}
pub(super) async fn current_revision(config_path: &Path) -> Result<String, ApiFailure> {
let content = tokio::fs::read_to_string(config_path)
.await
.map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?;
Ok(compute_revision(&content))
}
pub(super) fn compute_revision(content: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(content.as_bytes());
hex::encode(hasher.finalize())
}
pub(super) async fn load_config_from_disk(config_path: &Path) -> Result<ProxyConfig, ApiFailure> {
let config_path = config_path.to_path_buf();
tokio::task::spawn_blocking(move || ProxyConfig::load(config_path))
.await
.map_err(|e| ApiFailure::internal(format!("failed to join config loader: {}", e)))?
.map_err(|e| ApiFailure::internal(format!("failed to load config: {}", e)))
}
pub(super) async fn save_config_to_disk(
config_path: &Path,
cfg: &ProxyConfig,
) -> Result<String, ApiFailure> {
let serialized = toml::to_string_pretty(cfg)
.map_err(|e| ApiFailure::internal(format!("failed to serialize config: {}", e)))?;
write_atomic(config_path.to_path_buf(), serialized.clone()).await?;
Ok(compute_revision(&serialized))
}
pub(super) async fn save_access_sections_to_disk(
config_path: &Path,
cfg: &ProxyConfig,
sections: &[AccessSection],
) -> Result<String, ApiFailure> {
let mut content = tokio::fs::read_to_string(config_path)
.await
.map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?;
let mut applied = Vec::new();
for section in sections {
if applied.contains(section) {
continue;
}
let rendered = render_access_section(cfg, *section)?;
content = upsert_toml_table(&content, section.table_name(), &rendered);
applied.push(*section);
}
write_atomic(config_path.to_path_buf(), content.clone()).await?;
Ok(compute_revision(&content))
}
fn render_access_section(cfg: &ProxyConfig, section: AccessSection) -> Result<String, ApiFailure> {
let body = match section {
AccessSection::Users => {
let rows: BTreeMap<String, String> = cfg
.access
.users
.iter()
.map(|(key, value)| (key.clone(), value.clone()))
.collect();
serialize_table_body(&rows)?
}
AccessSection::UserAdTags => {
let rows: BTreeMap<String, String> = cfg
.access
.user_ad_tags
.iter()
.map(|(key, value)| (key.clone(), value.clone()))
.collect();
serialize_table_body(&rows)?
}
AccessSection::UserMaxTcpConns => {
let rows: BTreeMap<String, usize> = cfg
.access
.user_max_tcp_conns
.iter()
.map(|(key, value)| (key.clone(), *value))
.collect();
serialize_table_body(&rows)?
}
AccessSection::UserExpirations => {
let rows: BTreeMap<String, DateTime<Utc>> = cfg
.access
.user_expirations
.iter()
.map(|(key, value)| (key.clone(), *value))
.collect();
serialize_table_body(&rows)?
}
AccessSection::UserDataQuota => {
let rows: BTreeMap<String, u64> = cfg
.access
.user_data_quota
.iter()
.map(|(key, value)| (key.clone(), *value))
.collect();
serialize_table_body(&rows)?
}
AccessSection::UserMaxUniqueIps => {
let rows: BTreeMap<String, usize> = cfg
.access
.user_max_unique_ips
.iter()
.map(|(key, value)| (key.clone(), *value))
.collect();
serialize_table_body(&rows)?
}
};
let mut out = format!("[{}]\n", section.table_name());
if !body.is_empty() {
out.push_str(&body);
}
if !out.ends_with('\n') {
out.push('\n');
}
Ok(out)
}
fn serialize_table_body<T: Serialize>(value: &T) -> Result<String, ApiFailure> {
toml::to_string(value)
.map_err(|e| ApiFailure::internal(format!("failed to serialize access section: {}", e)))
}
fn upsert_toml_table(source: &str, table_name: &str, replacement: &str) -> String {
if let Some((start, end)) = find_toml_table_bounds(source, table_name) {
let mut out = String::with_capacity(source.len() + replacement.len());
out.push_str(&source[..start]);
out.push_str(replacement);
out.push_str(&source[end..]);
return out;
}
let mut out = source.to_string();
if !out.is_empty() && !out.ends_with('\n') {
out.push('\n');
}
if !out.is_empty() {
out.push('\n');
}
out.push_str(replacement);
out
}
fn find_toml_table_bounds(source: &str, table_name: &str) -> Option<(usize, usize)> {
let target = format!("[{}]", table_name);
let mut offset = 0usize;
let mut start = None;
for line in source.split_inclusive('\n') {
let trimmed = line.trim();
if let Some(start_offset) = start {
if trimmed.starts_with('[') {
return Some((start_offset, offset));
}
} else if trimmed == target {
start = Some(offset);
}
offset = offset.saturating_add(line.len());
}
start.map(|start_offset| (start_offset, source.len()))
}
async fn write_atomic(path: PathBuf, contents: String) -> Result<(), ApiFailure> {
tokio::task::spawn_blocking(move || write_atomic_sync(&path, &contents))
.await
.map_err(|e| ApiFailure::internal(format!("failed to join writer: {}", e)))?
.map_err(|e| ApiFailure::internal(format!("failed to write config: {}", e)))
}
fn write_atomic_sync(path: &Path, contents: &str) -> std::io::Result<()> {
let parent = path.parent().unwrap_or_else(|| Path::new("."));
std::fs::create_dir_all(parent)?;
let tmp_name = format!(
".{}.tmp-{}",
path.file_name()
.and_then(|s| s.to_str())
.unwrap_or("config.toml"),
rand::random::<u64>()
);
let tmp_path = parent.join(tmp_name);
let write_result = (|| {
let mut file = std::fs::OpenOptions::new()
.create_new(true)
.write(true)
.open(&tmp_path)?;
file.write_all(contents.as_bytes())?;
file.sync_all()?;
std::fs::rename(&tmp_path, path)?;
if let Ok(dir) = std::fs::File::open(parent) {
let _ = dir.sync_all();
}
Ok(())
})();
if write_result.is_err() {
let _ = std::fs::remove_file(&tmp_path);
}
write_result
}