diff --git a/Cargo.lock b/Cargo.lock index aad4bbc..a7e52c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2834,6 +2834,7 @@ dependencies = [ "socket2", "static_assertions", "subtle", + "tempfile", "thiserror", "tokio", "tokio-rustls", diff --git a/Cargo.toml b/Cargo.toml index ed158f4..d33815d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -90,6 +90,7 @@ tokio-test = "0.4" criterion = "0.8" proptest = "1.4" futures = "0.3" +tempfile = "3.27.0" [[bench]] name = "crypto_bench" diff --git a/src/api/config_edit.rs b/src/api/config_edit.rs new file mode 100644 index 0000000..5fe1f57 --- /dev/null +++ b/src/api/config_edit.rs @@ -0,0 +1,308 @@ +//! Config-editing API: read managed sections and apply sparse field patches. +//! `access.*` is intentionally not editable here (owned by the users API). + +use serde_json::Value as Json; +use toml::Value as Toml; + +use super::ApiShared; +use super::config_store::{ + EDITABLE_SECTIONS, compute_revision, current_revision, save_sections_to_disk, +}; +use super::model::ApiFailure; +use crate::config::ProxyConfig; +use crate::config::hot_reload::classify_config_changes; +use serde::Serialize; +use std::path::Path; + +#[derive(Debug, Serialize)] +pub(super) struct PatchConfigResponse { + pub revision: String, + pub restart_required: bool, + pub changed: Vec, +} + +/// Shared-state wrapper around [`apply_patch_to_path`]: serializes config +/// mutations behind `mutation_lock`, then records a runtime event. The route +/// handler calls this; the core logic stays decoupled for unit tests. +pub(super) async fn patch_config( + patch_json: Json, + expected_revision: Option, + shared: &ApiShared, +) -> Result { + let _guard = shared.mutation_lock.lock().await; + let resp = apply_patch_to_path(&shared.config_path, &patch_json, expected_revision).await?; + drop(_guard); + shared + .runtime_events + .record("api.config.patch.ok", format!("changed={:?}", resp.changed)); + Ok(resp) +} + +/// Core patch logic, decoupled from hyper/shared-state so it is unit-testable +/// against a temp file. The route handler holds `mutation_lock` while calling this. +pub(super) async fn apply_patch_to_path( + config_path: &Path, + patch_json: &Json, + expected_revision: Option, +) -> Result { + // 1. optimistic concurrency + let current = current_revision(config_path).await?; + if expected_revision.is_some_and(|expected| expected != current) { + return Err(ApiFailure::new( + hyper::StatusCode::CONFLICT, + "revision_conflict", + "Config revision mismatch", + )); + } + + // 2. convert + reject access / unknown sections + let patch_toml = json_to_toml(patch_json) + .map_err(|e| ApiFailure::bad_request(format!("invalid patch: {}", e)))?; + let patch_table = patch_toml + .as_table() + .ok_or_else(|| ApiFailure::bad_request("patch must be a JSON object"))?; + if patch_table.contains_key("access") { + return Err(ApiFailure::new( + hyper::StatusCode::BAD_REQUEST, + "access_not_editable", + "access.* is managed via the users API, not editable here", + )); + } + for key in patch_table.keys() { + if !EDITABLE_SECTIONS.contains(&key.as_str()) { + return Err(ApiFailure::new( + hyper::StatusCode::BAD_REQUEST, + "section_not_editable", + format!("section not editable: {}", key), + )); + } + } + let touched: Vec<&str> = patch_table + .keys() + .map(|k| k.as_str()) + .filter(|k| EDITABLE_SECTIONS.contains(k)) + .collect(); + if touched.is_empty() { + return Err(ApiFailure::bad_request("empty patch: no editable sections")); + } + + // 3. Parse old + merged from the SAME deserialize path so the classifier + // sees only the delta this patch introduces. `ProxyConfig::load` applies + // include-expansion / legacy-compat / normalization that a bare + // `try_into` does not; mixing the two paths would make unrelated fields + // compare unequal and spuriously force `restart_required`. + let original = tokio::fs::read_to_string(config_path) + .await + .map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?; + let original_toml: Toml = toml::from_str(&original) + .map_err(|e| ApiFailure::internal(format!("failed to parse config: {}", e)))?; + let old_cfg: ProxyConfig = original_toml + .clone() + .try_into() + .map_err(|e| ApiFailure::internal(format!("config does not deserialize: {}", e)))?; + + let mut merged = original_toml; + deep_merge(&mut merged, &patch_toml); + + let new_cfg: ProxyConfig = merged + .clone() + .try_into() + .map_err(|e| ApiFailure::bad_request(format!("config does not deserialize: {}", e)))?; + new_cfg + .validate() + .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; + + // 4. classify changes (Telemt's own hot/restart rule) + let class = classify_config_changes(&old_cfg, &new_cfg); + + // 5. write only the touched top-level sections + let revision = save_sections_to_disk(config_path, &new_cfg, &touched).await?; + + Ok(PatchConfigResponse { + revision, + restart_required: class.restart_required, + changed: class.changed, + }) +} + +/// Return the editable config sections (no `access.*`) + current revision. +pub(super) async fn read_managed_config(config_path: &Path) -> Result<(Toml, String), ApiFailure> { + let original = tokio::fs::read_to_string(config_path) + .await + .map_err(|e| ApiFailure::internal(format!("failed to read config: {}", e)))?; + let parsed: Toml = toml::from_str(&original) + .map_err(|e| ApiFailure::internal(format!("failed to parse config: {}", e)))?; + + let mut table = parsed + .as_table() + .cloned() + .unwrap_or_else(toml::value::Table::new); + table.remove("access"); // never expose users/secrets via this endpoint + + let revision = compute_revision(&original); + Ok((Toml::Table(table), revision)) +} + +/// Convert a serde_json value to a toml value. `null` is dropped from objects +/// (a patch never sets a key to TOML-null). Numbers become integers when exact, +/// otherwise floats. +fn json_to_toml(j: &Json) -> Result { + Ok(match j { + Json::Null => return Err("null is not representable in TOML".into()), + Json::Bool(b) => Toml::Boolean(*b), + Json::Number(n) => { + if let Some(i) = n.as_i64() { + Toml::Integer(i) + } else if let Some(f) = n.as_f64() { + Toml::Float(f) + } else { + return Err(format!("unrepresentable number: {}", n)); + } + } + Json::String(s) => Toml::String(s.clone()), + Json::Array(items) => { + let mut out = Vec::with_capacity(items.len()); + for item in items { + out.push(json_to_toml(item)?); + } + Toml::Array(out) + } + Json::Object(map) => { + let mut table = toml::value::Table::new(); + for (k, v) in map { + if v.is_null() { + continue; // skip nulls instead of erroring at object level + } + table.insert(k.clone(), json_to_toml(v)?); + } + Toml::Table(table) + } + }) +} + +/// Recursively overlay `patch` onto `base`. Tables merge key-by-key; every +/// other value type (scalars, arrays) replaces wholesale. +fn deep_merge(base: &mut Toml, patch: &Toml) { + match (base, patch) { + (Toml::Table(b), Toml::Table(p)) => { + for (k, pv) in p { + match b.get_mut(k) { + Some(bv) => deep_merge(bv, pv), + None => { + b.insert(k.clone(), pv.clone()); + } + } + } + } + (b, p) => *b = p.clone(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn json_object_converts_to_toml_table() { + let j: Json = serde_json::json!({"censorship": {"tls_domain": "a.com"}, "default_dc": 2}); + let t = json_to_toml(&j).expect("convertible"); + let table = t.as_table().unwrap(); + assert_eq!(table["censorship"]["tls_domain"].as_str(), Some("a.com")); + assert_eq!(table["default_dc"].as_integer(), Some(2)); + } + + #[test] + fn deep_merge_overlays_tables_and_replaces_scalars() { + let mut base: Toml = + toml::from_str("[censorship]\ntls_domain = \"old\"\nfake_cert_len = 100\n").unwrap(); + let patch: Toml = toml::from_str("[censorship]\ntls_domain = \"new\"\n").unwrap(); + + deep_merge(&mut base, &patch); + + let cens = base["censorship"].as_table().unwrap(); + assert_eq!(cens["tls_domain"].as_str(), Some("new")); // overlaid + assert_eq!(cens["fake_cert_len"].as_integer(), Some(100)); // preserved + } + + use std::path::PathBuf; + + fn temp_config(body: &str) -> (PathBuf, tempfile::TempDir) { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("config.toml"); + std::fs::write(&path, body).unwrap(); + (path, dir) + } + + #[tokio::test] + async fn patch_rejects_access_section() { + let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n"); + let patch: Json = serde_json::json!({"access": {"users": {"x": "y"}}}); + let err = apply_patch_to_path(&path, &patch, None).await.unwrap_err(); + assert_eq!(err.code, "access_not_editable"); + } + + #[tokio::test] + async fn patch_revision_conflict() { + let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n"); + let patch: Json = serde_json::json!({"censorship": {"tls_domain": "b"}}); + let err = apply_patch_to_path(&path, &patch, Some("deadbeef".into())) + .await + .unwrap_err(); + assert_eq!(err.code, "revision_conflict"); + } + + #[tokio::test] + async fn patch_sni_reports_restart_required() { + let (path, _d) = + temp_config("[censorship]\ntls_domain = \"a.com\"\n[server]\nport = 443\n"); + let patch: Json = serde_json::json!({"censorship": {"tls_domain": "b.com"}}); + let resp = apply_patch_to_path(&path, &patch, None).await.unwrap(); + assert!(resp.restart_required); + assert!(resp.changed.iter().any(|c| c == "censorship")); + let written = std::fs::read_to_string(&path).unwrap(); + assert!(written.contains("tls_domain = \"b.com\"")); + assert_eq!( + resp.revision, + crate::api::config_store::compute_revision(&written) + ); + } + + #[tokio::test] + async fn read_managed_config_strips_access() { + let (path, _d) = temp_config( + "[censorship]\ntls_domain = \"a.com\"\n[access.users]\nbob = \"deadbeef\"\n", + ); + let (value, revision) = read_managed_config(&path).await.unwrap(); + let table = value.as_table().unwrap(); + assert!(table.contains_key("censorship")); + assert!(!table.contains_key("access")); // secrets never leave the box here + assert_eq!(revision, current_revision(&path).await.unwrap()); + } + + #[tokio::test] + async fn patch_rejects_server_section() { + let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n"); + let patch: Json = serde_json::json!({"server": {"port": 1}}); + let err = apply_patch_to_path(&path, &patch, None).await.unwrap_err(); + assert_eq!(err.code, "section_not_editable"); + } + + #[tokio::test] + async fn patch_empty_is_rejected() { + let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n"); + let patch: Json = serde_json::json!({}); + assert!(apply_patch_to_path(&path, &patch, None).await.is_err()); + } + + #[tokio::test] + async fn patch_log_level_is_hot() { + // general.log_level is hot-reloadable -> a patch changing only it must + // report restart_required = false (exercises the full apply path, not + // just the classifier). Default LogLevel is Normal; patch to "debug". + let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n"); + let patch: Json = serde_json::json!({"general": {"log_level": "debug"}}); + let resp = apply_patch_to_path(&path, &patch, None).await.unwrap(); + assert!(!resp.restart_required); + assert!(resp.changed.iter().any(|c| c == "general")); + } +} diff --git a/src/api/mod.rs b/src/api/mod.rs index 4239b59..8724416 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -28,6 +28,7 @@ use crate::stats::Stats; use crate::transport::UpstreamManager; use crate::transport::middle_proxy::MePool; +mod config_edit; mod config_store; mod events; mod http_utils; @@ -84,6 +85,7 @@ const ALLOW_GET: &str = "GET"; const ALLOW_POST: &str = "POST"; const ALLOW_GET_POST: &str = "GET, POST"; const ALLOW_GET_PATCH_DELETE: &str = "GET, PATCH, DELETE"; +const ALLOW_GET_PATCH: &str = "GET, PATCH"; pub(super) struct ApiRuntimeState { pub(super) process_started_at_epoch_secs: u64, @@ -174,6 +176,7 @@ fn allowed_methods_for_path(path: &str) -> Option<&'static str> { | "/v1/stats/users/quota" | "/v1/stats/users" => Some(ALLOW_GET), "/v1/users" => Some(ALLOW_GET_POST), + "/v1/config" => Some(ALLOW_GET_PATCH), _ if user_action_route_matches(path, "/reset-quota") => Some(ALLOW_POST), _ if user_action_route_matches(path, "/rotate-secret") => Some(ALLOW_POST), _ if user_action_route_matches(path, "/enable") => Some(ALLOW_POST), @@ -643,6 +646,37 @@ async fn handle( }; Ok(success_response(status, data, revision)) } + ("GET", "/v1/config") => { + let (value, revision) = + config_edit::read_managed_config(&shared.config_path).await?; + Ok(success_response(StatusCode::OK, value, revision)) + } + ("PATCH", "/v1/config") => { + if api_cfg.read_only { + return Ok(error_response( + request_id, + ApiFailure::new( + StatusCode::FORBIDDEN, + "read_only", + "API runs in read-only mode", + ), + )); + } + let expected_revision = parse_if_match(req.headers()); + let body = read_json::(req.into_body(), body_limit).await?; + match config_edit::patch_config(body, expected_revision, &shared).await { + Ok(resp) => { + let revision = resp.revision.clone(); + Ok(success_response(StatusCode::OK, resp, revision)) + } + Err(error) => { + shared + .runtime_events + .record("api.config.patch.failed", error.code); + Err(error) + } + } + } _ => { if method == Method::POST && let Some(base_user) = normalized_path