mirror of https://github.com/telemt/telemt.git
270 lines
8.4 KiB
Rust
270 lines
8.4 KiB
Rust
use std::collections::BTreeMap;
|
|
use std::io::Write;
|
|
use std::path::{Path, PathBuf};
|
|
|
|
use chrono::{DateTime, Utc};
|
|
use hyper::header::IF_MATCH;
|
|
use serde::Serialize;
|
|
use sha2::{Digest, Sha256};
|
|
|
|
use crate::config::ProxyConfig;
|
|
|
|
use super::model::ApiFailure;
|
|
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
|
pub(super) enum AccessSection {
|
|
Users,
|
|
UserAdTags,
|
|
UserMaxTcpConns,
|
|
UserExpirations,
|
|
UserDataQuota,
|
|
UserMaxUniqueIps,
|
|
}
|
|
|
|
impl AccessSection {
|
|
fn table_name(self) -> &'static str {
|
|
match self {
|
|
Self::Users => "access.users",
|
|
Self::UserAdTags => "access.user_ad_tags",
|
|
Self::UserMaxTcpConns => "access.user_max_tcp_conns",
|
|
Self::UserExpirations => "access.user_expirations",
|
|
Self::UserDataQuota => "access.user_data_quota",
|
|
Self::UserMaxUniqueIps => "access.user_max_unique_ips",
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(super) fn parse_if_match(headers: &hyper::HeaderMap) -> Option<String> {
|
|
headers
|
|
.get(IF_MATCH)
|
|
.and_then(|value| value.to_str().ok())
|
|
.map(str::trim)
|
|
.filter(|value| !value.is_empty())
|
|
.map(|value| value.trim_matches('"').to_string())
|
|
}
|
|
|
|
pub(super) async fn ensure_expected_revision(
|
|
config_path: &Path,
|
|
expected_revision: Option<&str>,
|
|
) -> Result<(), ApiFailure> {
|
|
let Some(expected) = expected_revision else {
|
|
return Ok(());
|
|
};
|
|
let current = current_revision(config_path).await?;
|
|
if current != expected {
|
|
return Err(ApiFailure::new(
|
|
hyper::StatusCode::CONFLICT,
|
|
"revision_conflict",
|
|
"Config revision mismatch",
|
|
));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub(super) async fn current_revision(config_path: &Path) -> Result<String, ApiFailure> {
|
|
let content = tokio::fs::read_to_string(config_path)
|
|
.await
|
|
.map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?;
|
|
Ok(compute_revision(&content))
|
|
}
|
|
|
|
pub(super) fn compute_revision(content: &str) -> String {
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(content.as_bytes());
|
|
hex::encode(hasher.finalize())
|
|
}
|
|
|
|
pub(super) async fn load_config_from_disk(config_path: &Path) -> Result<ProxyConfig, ApiFailure> {
|
|
let config_path = config_path.to_path_buf();
|
|
tokio::task::spawn_blocking(move || ProxyConfig::load(config_path))
|
|
.await
|
|
.map_err(|e| ApiFailure::internal(format!("failed to join config loader: {}", e)))?
|
|
.map_err(|e| ApiFailure::internal(format!("failed to load config: {}", e)))
|
|
}
|
|
|
|
pub(super) async fn save_config_to_disk(
|
|
config_path: &Path,
|
|
cfg: &ProxyConfig,
|
|
) -> Result<String, ApiFailure> {
|
|
let serialized = toml::to_string_pretty(cfg)
|
|
.map_err(|e| ApiFailure::internal(format!("failed to serialize config: {}", e)))?;
|
|
write_atomic(config_path.to_path_buf(), serialized.clone()).await?;
|
|
Ok(compute_revision(&serialized))
|
|
}
|
|
|
|
pub(super) async fn save_access_sections_to_disk(
|
|
config_path: &Path,
|
|
cfg: &ProxyConfig,
|
|
sections: &[AccessSection],
|
|
) -> Result<String, ApiFailure> {
|
|
let mut content = tokio::fs::read_to_string(config_path)
|
|
.await
|
|
.map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?;
|
|
|
|
let mut applied = Vec::new();
|
|
for section in sections {
|
|
if applied.contains(section) {
|
|
continue;
|
|
}
|
|
let rendered = render_access_section(cfg, *section)?;
|
|
content = upsert_toml_table(&content, section.table_name(), &rendered);
|
|
applied.push(*section);
|
|
}
|
|
|
|
write_atomic(config_path.to_path_buf(), content.clone()).await?;
|
|
Ok(compute_revision(&content))
|
|
}
|
|
|
|
fn render_access_section(cfg: &ProxyConfig, section: AccessSection) -> Result<String, ApiFailure> {
|
|
let body = match section {
|
|
AccessSection::Users => {
|
|
let rows: BTreeMap<String, String> = cfg
|
|
.access
|
|
.users
|
|
.iter()
|
|
.map(|(key, value)| (key.clone(), value.clone()))
|
|
.collect();
|
|
serialize_table_body(&rows)?
|
|
}
|
|
AccessSection::UserAdTags => {
|
|
let rows: BTreeMap<String, String> = cfg
|
|
.access
|
|
.user_ad_tags
|
|
.iter()
|
|
.map(|(key, value)| (key.clone(), value.clone()))
|
|
.collect();
|
|
serialize_table_body(&rows)?
|
|
}
|
|
AccessSection::UserMaxTcpConns => {
|
|
let rows: BTreeMap<String, usize> = cfg
|
|
.access
|
|
.user_max_tcp_conns
|
|
.iter()
|
|
.map(|(key, value)| (key.clone(), *value))
|
|
.collect();
|
|
serialize_table_body(&rows)?
|
|
}
|
|
AccessSection::UserExpirations => {
|
|
let rows: BTreeMap<String, DateTime<Utc>> = cfg
|
|
.access
|
|
.user_expirations
|
|
.iter()
|
|
.map(|(key, value)| (key.clone(), *value))
|
|
.collect();
|
|
serialize_table_body(&rows)?
|
|
}
|
|
AccessSection::UserDataQuota => {
|
|
let rows: BTreeMap<String, u64> = cfg
|
|
.access
|
|
.user_data_quota
|
|
.iter()
|
|
.map(|(key, value)| (key.clone(), *value))
|
|
.collect();
|
|
serialize_table_body(&rows)?
|
|
}
|
|
AccessSection::UserMaxUniqueIps => {
|
|
let rows: BTreeMap<String, usize> = cfg
|
|
.access
|
|
.user_max_unique_ips
|
|
.iter()
|
|
.map(|(key, value)| (key.clone(), *value))
|
|
.collect();
|
|
serialize_table_body(&rows)?
|
|
}
|
|
};
|
|
|
|
let mut out = format!("[{}]\n", section.table_name());
|
|
if !body.is_empty() {
|
|
out.push_str(&body);
|
|
}
|
|
if !out.ends_with('\n') {
|
|
out.push('\n');
|
|
}
|
|
Ok(out)
|
|
}
|
|
|
|
fn serialize_table_body<T: Serialize>(value: &T) -> Result<String, ApiFailure> {
|
|
toml::to_string(value)
|
|
.map_err(|e| ApiFailure::internal(format!("failed to serialize access section: {}", e)))
|
|
}
|
|
|
|
fn upsert_toml_table(source: &str, table_name: &str, replacement: &str) -> String {
|
|
if let Some((start, end)) = find_toml_table_bounds(source, table_name) {
|
|
let mut out = String::with_capacity(source.len() + replacement.len());
|
|
out.push_str(&source[..start]);
|
|
out.push_str(replacement);
|
|
out.push_str(&source[end..]);
|
|
return out;
|
|
}
|
|
|
|
let mut out = source.to_string();
|
|
if !out.is_empty() && !out.ends_with('\n') {
|
|
out.push('\n');
|
|
}
|
|
if !out.is_empty() {
|
|
out.push('\n');
|
|
}
|
|
out.push_str(replacement);
|
|
out
|
|
}
|
|
|
|
/// Locates the byte range `[start, end)` occupied by the TOML table
/// `table_name` inside `source`.
///
/// `start` is the offset of the line carrying the table header. `end` is the
/// offset of the next line that begins (after leading whitespace) with `[`,
/// or `source.len()` if there is none. Returns `None` when no header for
/// `table_name` is found.
///
/// The header is recognized in every form TOML allows for a bare dotted key:
/// optional whitespace inside the brackets and around the dots, and an
/// optional trailing `# comment`. Missing such a header would make the caller
/// append a second copy of the table, which TOML rejects as a duplicate.
///
/// This is a line-based scan, not a TOML parser: a value continuation line
/// that starts with `[` (e.g. inside a multi-line array) is treated as the
/// end of the table.
fn find_toml_table_bounds(source: &str, table_name: &str) -> Option<(usize, usize)> {
    let mut offset = 0usize;
    let mut start = None;

    for line in source.split_inclusive('\n') {
        let trimmed = line.trim();
        match start {
            // Inside the target table: any following header closes it.
            Some(start_offset) if trimmed.starts_with('[') => {
                return Some((start_offset, offset));
            }
            None if is_table_header_for(trimmed, table_name) => start = Some(offset),
            _ => {}
        }
        offset += line.len();
    }

    start.map(|start_offset| (start_offset, source.len()))
}

/// Reports whether the whitespace-trimmed `line` is a standard table header
/// (`[name]`, not an array-of-tables `[[name]]`) naming exactly `table_name`.
fn is_table_header_for(line: &str, table_name: &str) -> bool {
    let Some(rest) = line.strip_prefix('[') else {
        return false;
    };
    let Some((inner, tail)) = rest.split_once(']') else {
        return false;
    };

    // Only whitespace or a comment may follow the closing bracket.
    let tail = tail.trim_start();
    if !tail.is_empty() && !tail.starts_with('#') {
        return false;
    }

    // Compare key segments so whitespace around the dots is ignored.
    inner.split('.').map(str::trim).eq(table_name.split('.'))
}
|
|
|
|
async fn write_atomic(path: PathBuf, contents: String) -> Result<(), ApiFailure> {
|
|
tokio::task::spawn_blocking(move || write_atomic_sync(&path, &contents))
|
|
.await
|
|
.map_err(|e| ApiFailure::internal(format!("failed to join writer: {}", e)))?
|
|
.map_err(|e| ApiFailure::internal(format!("failed to write config: {}", e)))
|
|
}
|
|
|
|
fn write_atomic_sync(path: &Path, contents: &str) -> std::io::Result<()> {
|
|
let parent = path.parent().unwrap_or_else(|| Path::new("."));
|
|
std::fs::create_dir_all(parent)?;
|
|
|
|
let tmp_name = format!(
|
|
".{}.tmp-{}",
|
|
path.file_name()
|
|
.and_then(|s| s.to_str())
|
|
.unwrap_or("config.toml"),
|
|
rand::random::<u64>()
|
|
);
|
|
let tmp_path = parent.join(tmp_name);
|
|
|
|
let write_result = (|| {
|
|
let mut file = std::fs::OpenOptions::new()
|
|
.create_new(true)
|
|
.write(true)
|
|
.open(&tmp_path)?;
|
|
file.write_all(contents.as_bytes())?;
|
|
file.sync_all()?;
|
|
std::fs::rename(&tmp_path, path)?;
|
|
if let Ok(dir) = std::fs::File::open(parent) {
|
|
let _ = dir.sync_all();
|
|
}
|
|
Ok(())
|
|
})();
|
|
|
|
if write_result.is_err() {
|
|
let _ = std::fs::remove_file(&tmp_path);
|
|
}
|
|
write_result
|
|
}
|