use std::collections::BTreeMap; use std::io::Write; use std::path::{Path, PathBuf}; use chrono::{DateTime, Utc}; use hyper::header::IF_MATCH; use serde::Serialize; use sha2::{Digest, Sha256}; use crate::config::ProxyConfig; use super::model::ApiFailure; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub(super) enum AccessSection { Users, UserAdTags, UserMaxTcpConns, UserExpirations, UserDataQuota, UserMaxUniqueIps, } impl AccessSection { fn table_name(self) -> &'static str { match self { Self::Users => "access.users", Self::UserAdTags => "access.user_ad_tags", Self::UserMaxTcpConns => "access.user_max_tcp_conns", Self::UserExpirations => "access.user_expirations", Self::UserDataQuota => "access.user_data_quota", Self::UserMaxUniqueIps => "access.user_max_unique_ips", } } } pub(super) fn parse_if_match(headers: &hyper::HeaderMap) -> Option { headers .get(IF_MATCH) .and_then(|value| value.to_str().ok()) .map(str::trim) .filter(|value| !value.is_empty()) .map(|value| value.trim_matches('"').to_string()) } pub(super) async fn ensure_expected_revision( config_path: &Path, expected_revision: Option<&str>, ) -> Result<(), ApiFailure> { let Some(expected) = expected_revision else { return Ok(()); }; let current = current_revision(config_path).await?; if current != expected { return Err(ApiFailure::new( hyper::StatusCode::CONFLICT, "revision_conflict", "Config revision mismatch", )); } Ok(()) } pub(super) async fn current_revision(config_path: &Path) -> Result { let content = tokio::fs::read_to_string(config_path) .await .map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?; Ok(compute_revision(&content)) } pub(super) fn compute_revision(content: &str) -> String { let mut hasher = Sha256::new(); hasher.update(content.as_bytes()); hex::encode(hasher.finalize()) } pub(super) async fn load_config_from_disk(config_path: &Path) -> Result { let config_path = config_path.to_path_buf(); tokio::task::spawn_blocking(move || ProxyConfig::load(config_path)) .await .map_err(|e| ApiFailure::internal(format!("failed to join config loader: {}", e)))? .map_err(|e| ApiFailure::internal(format!("failed to load config: {}", e))) } pub(super) async fn save_config_to_disk( config_path: &Path, cfg: &ProxyConfig, ) -> Result { let serialized = toml::to_string_pretty(cfg) .map_err(|e| ApiFailure::internal(format!("failed to serialize config: {}", e)))?; write_atomic(config_path.to_path_buf(), serialized.clone()).await?; Ok(compute_revision(&serialized)) } pub(super) async fn save_access_sections_to_disk( config_path: &Path, cfg: &ProxyConfig, sections: &[AccessSection], ) -> Result { let mut content = tokio::fs::read_to_string(config_path) .await .map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?; let mut applied = Vec::new(); for section in sections { if applied.contains(section) { continue; } let rendered = render_access_section(cfg, *section)?; content = upsert_toml_table(&content, section.table_name(), &rendered); applied.push(*section); } write_atomic(config_path.to_path_buf(), content.clone()).await?; Ok(compute_revision(&content)) } fn render_access_section(cfg: &ProxyConfig, section: AccessSection) -> Result { let body = match section { AccessSection::Users => { let rows: BTreeMap = cfg .access .users .iter() .map(|(key, value)| (key.clone(), value.clone())) .collect(); serialize_table_body(&rows)? } AccessSection::UserAdTags => { let rows: BTreeMap = cfg .access .user_ad_tags .iter() .map(|(key, value)| (key.clone(), value.clone())) .collect(); serialize_table_body(&rows)? } AccessSection::UserMaxTcpConns => { let rows: BTreeMap = cfg .access .user_max_tcp_conns .iter() .map(|(key, value)| (key.clone(), *value)) .collect(); serialize_table_body(&rows)? } AccessSection::UserExpirations => { let rows: BTreeMap> = cfg .access .user_expirations .iter() .map(|(key, value)| (key.clone(), *value)) .collect(); serialize_table_body(&rows)? } AccessSection::UserDataQuota => { let rows: BTreeMap = cfg .access .user_data_quota .iter() .map(|(key, value)| (key.clone(), *value)) .collect(); serialize_table_body(&rows)? } AccessSection::UserMaxUniqueIps => { let rows: BTreeMap = cfg .access .user_max_unique_ips .iter() .map(|(key, value)| (key.clone(), *value)) .collect(); serialize_table_body(&rows)? } }; let mut out = format!("[{}]\n", section.table_name()); if !body.is_empty() { out.push_str(&body); } if !out.ends_with('\n') { out.push('\n'); } Ok(out) } fn serialize_table_body(value: &T) -> Result { toml::to_string(value) .map_err(|e| ApiFailure::internal(format!("failed to serialize access section: {}", e))) } fn upsert_toml_table(source: &str, table_name: &str, replacement: &str) -> String { if let Some((start, end)) = find_toml_table_bounds(source, table_name) { let mut out = String::with_capacity(source.len() + replacement.len()); out.push_str(&source[..start]); out.push_str(replacement); out.push_str(&source[end..]); return out; } let mut out = source.to_string(); if !out.is_empty() && !out.ends_with('\n') { out.push('\n'); } if !out.is_empty() { out.push('\n'); } out.push_str(replacement); out } fn find_toml_table_bounds(source: &str, table_name: &str) -> Option<(usize, usize)> { let target = format!("[{}]", table_name); let mut offset = 0usize; let mut start = None; for line in source.split_inclusive('\n') { let trimmed = line.trim(); if let Some(start_offset) = start { if trimmed.starts_with('[') { return Some((start_offset, offset)); } } else if trimmed == target { start = Some(offset); } offset = offset.saturating_add(line.len()); } start.map(|start_offset| (start_offset, source.len())) } async fn write_atomic(path: PathBuf, contents: String) -> Result<(), ApiFailure> { tokio::task::spawn_blocking(move || write_atomic_sync(&path, &contents)) .await .map_err(|e| ApiFailure::internal(format!("failed to join writer: {}", e)))? .map_err(|e| ApiFailure::internal(format!("failed to write config: {}", e))) } fn write_atomic_sync(path: &Path, contents: &str) -> std::io::Result<()> { let parent = path.parent().unwrap_or_else(|| Path::new(".")); std::fs::create_dir_all(parent)?; let tmp_name = format!( ".{}.tmp-{}", path.file_name() .and_then(|s| s.to_str()) .unwrap_or("config.toml"), rand::random::() ); let tmp_path = parent.join(tmp_name); let write_result = (|| { let mut file = std::fs::OpenOptions::new() .create_new(true) .write(true) .open(&tmp_path)?; file.write_all(contents.as_bytes())?; file.sync_all()?; std::fs::rename(&tmp_path, path)?; if let Ok(dir) = std::fs::File::open(parent) { let _ = dir.sync_all(); } Ok(()) })(); if write_result.is_err() { let _ = std::fs::remove_file(&tmp_path); } write_result }