diff --git a/src/api/model.rs b/src/api/model.rs index fa1f063..1ca9f33 100644 --- a/src/api/model.rs +++ b/src/api/model.rs @@ -5,6 +5,7 @@ use chrono::{DateTime, Utc}; use hyper::StatusCode; use serde::{Deserialize, Serialize}; +use super::patch::{Patch, patch_field}; use crate::crypto::SecureRandom; const MAX_USERNAME_LEN: usize = 64; @@ -507,11 +508,16 @@ pub(super) struct CreateUserRequest { #[derive(Deserialize)] pub(super) struct PatchUserRequest { pub(super) secret: Option, - pub(super) user_ad_tag: Option, - pub(super) max_tcp_conns: Option, - pub(super) expiration_rfc3339: Option, - pub(super) data_quota_bytes: Option, - pub(super) max_unique_ips: Option, + #[serde(default, deserialize_with = "patch_field")] + pub(super) user_ad_tag: Patch, + #[serde(default, deserialize_with = "patch_field")] + pub(super) max_tcp_conns: Patch, + #[serde(default, deserialize_with = "patch_field")] + pub(super) expiration_rfc3339: Patch, + #[serde(default, deserialize_with = "patch_field")] + pub(super) data_quota_bytes: Patch, + #[serde(default, deserialize_with = "patch_field")] + pub(super) max_unique_ips: Patch, } #[derive(Default, Deserialize)] @@ -530,6 +536,20 @@ pub(super) fn parse_optional_expiration( Ok(Some(parsed.with_timezone(&Utc))) } +pub(super) fn parse_patch_expiration( + value: &Patch, +) -> Result>, ApiFailure> { + match value { + Patch::Unchanged => Ok(Patch::Unchanged), + Patch::Remove => Ok(Patch::Remove), + Patch::Set(raw) => { + let parsed = DateTime::parse_from_rfc3339(raw) + .map_err(|_| ApiFailure::bad_request("expiration_rfc3339 must be valid RFC3339"))?; + Ok(Patch::Set(parsed.with_timezone(&Utc))) + } + } +} + pub(super) fn is_valid_user_secret(secret: &str) -> bool { secret.len() == 32 && secret.chars().all(|c| c.is_ascii_hexdigit()) } diff --git a/src/api/patch.rs b/src/api/patch.rs index 6425af1..65cd191 100644 --- a/src/api/patch.rs +++ b/src/api/patch.rs @@ -41,6 +41,8 @@ where #[cfg(test)] mod tests { use super::*; + use crate::api::model::{PatchUserRequest, parse_patch_expiration}; + use chrono::{TimeZone, Utc}; use serde::Deserialize; #[derive(Deserialize)] @@ -76,4 +78,53 @@ mod tests { let h = parse(r#"{"value": 0}"#); assert!(matches!(h.value, Patch::Set(0))); } + + #[test] + fn parse_patch_expiration_passes_unchanged_and_remove_through() { + assert!(matches!( + parse_patch_expiration(&Patch::Unchanged), + Ok(Patch::Unchanged) + )); + assert!(matches!( + parse_patch_expiration(&Patch::Remove), + Ok(Patch::Remove) + )); + } + + #[test] + fn parse_patch_expiration_parses_set_value() { + let parsed = + parse_patch_expiration(&Patch::Set("2030-01-02T03:04:05Z".into())).expect("valid"); + match parsed { + Patch::Set(dt) => { + assert_eq!(dt, Utc.with_ymd_and_hms(2030, 1, 2, 3, 4, 5).unwrap()); + } + other => panic!("expected Patch::Set, got {:?}", other), + } + } + + #[test] + fn parse_patch_expiration_rejects_invalid_set_value() { + assert!(parse_patch_expiration(&Patch::Set("not-a-date".into())).is_err()); + } + + #[test] + fn patch_user_request_deserializes_mixed_states() { + let raw = r#"{ + "secret": "00112233445566778899aabbccddeeff", + "max_tcp_conns": 0, + "max_unique_ips": null, + "data_quota_bytes": 1024 + }"#; + let req: PatchUserRequest = serde_json::from_str(raw).expect("valid json"); + assert_eq!( + req.secret.as_deref(), + Some("00112233445566778899aabbccddeeff") + ); + assert!(matches!(req.max_tcp_conns, Patch::Set(0))); + assert!(matches!(req.max_unique_ips, Patch::Remove)); + assert!(matches!(req.data_quota_bytes, Patch::Set(1024))); + assert!(matches!(req.expiration_rfc3339, Patch::Unchanged)); + assert!(matches!(req.user_ad_tag, Patch::Unchanged)); + } } diff --git a/src/api/users.rs b/src/api/users.rs index 6b20b85..ef8f10a 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -14,8 +14,9 @@ use super::config_store::{ use super::model::{ ApiFailure, CreateUserRequest, CreateUserResponse, PatchUserRequest, RotateSecretRequest, UserInfo, UserLinks, is_valid_ad_tag, is_valid_user_secret, is_valid_username, - parse_optional_expiration, random_user_secret, + parse_optional_expiration, parse_patch_expiration, random_user_secret, }; +use super::patch::Patch; pub(super) async fn create_user( body: CreateUserRequest, @@ -182,14 +183,14 @@ pub(super) async fn patch_user( "secret must be exactly 32 hex characters", )); } - if let Some(ad_tag) = body.user_ad_tag.as_ref() + if let Patch::Set(ad_tag) = &body.user_ad_tag && !is_valid_ad_tag(ad_tag) { return Err(ApiFailure::bad_request( "user_ad_tag must be exactly 32 hex characters", )); } - let expiration = parse_optional_expiration(body.expiration_rfc3339.as_deref())?; + let expiration = parse_patch_expiration(&body.expiration_rfc3339)?; let _guard = shared.mutation_lock.lock().await; let mut cfg = load_config_from_disk(&shared.config_path).await?; ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?; @@ -205,38 +206,71 @@ pub(super) async fn patch_user( if let Some(secret) = body.secret { cfg.access.users.insert(user.to_string(), secret); } - if let Some(ad_tag) = body.user_ad_tag { - cfg.access.user_ad_tags.insert(user.to_string(), ad_tag); + match body.user_ad_tag { + Patch::Unchanged => {} + Patch::Remove => { + cfg.access.user_ad_tags.remove(user); + } + Patch::Set(ad_tag) => { + cfg.access.user_ad_tags.insert(user.to_string(), ad_tag); + } } - if let Some(limit) = body.max_tcp_conns { - cfg.access - .user_max_tcp_conns - .insert(user.to_string(), limit); + match body.max_tcp_conns { + Patch::Unchanged => {} + Patch::Remove => { + cfg.access.user_max_tcp_conns.remove(user); + } + Patch::Set(limit) => { + cfg.access + .user_max_tcp_conns + .insert(user.to_string(), limit); + } } - if let Some(expiration) = expiration { - cfg.access - .user_expirations - .insert(user.to_string(), expiration); + match expiration { + Patch::Unchanged => {} + Patch::Remove => { + cfg.access.user_expirations.remove(user); + } + Patch::Set(expiration) => { + cfg.access + .user_expirations + .insert(user.to_string(), expiration); + } } - if let Some(quota) = body.data_quota_bytes { - cfg.access.user_data_quota.insert(user.to_string(), quota); - } - - let mut updated_limit = None; - if let Some(limit) = body.max_unique_ips { - cfg.access - .user_max_unique_ips - .insert(user.to_string(), limit); - updated_limit = Some(limit); + match body.data_quota_bytes { + Patch::Unchanged => {} + Patch::Remove => { + cfg.access.user_data_quota.remove(user); + } + Patch::Set(quota) => { + cfg.access.user_data_quota.insert(user.to_string(), quota); + } } + // Capture how the per-user IP limit changed, so the in-memory ip_tracker + // can be synced (set or removed) after the config is persisted. + let max_unique_ips_change = match body.max_unique_ips { + Patch::Unchanged => None, + Patch::Remove => { + cfg.access.user_max_unique_ips.remove(user); + Some(None) + } + Patch::Set(limit) => { + cfg.access + .user_max_unique_ips + .insert(user.to_string(), limit); + Some(Some(limit)) + } + }; cfg.validate() .map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?; let revision = save_config_to_disk(&shared.config_path, &cfg).await?; drop(_guard); - if let Some(limit) = updated_limit { - shared.ip_tracker.set_user_limit(user, limit).await; + match max_unique_ips_change { + Some(Some(limit)) => shared.ip_tracker.set_user_limit(user, limit).await, + Some(None) => shared.ip_tracker.remove_user_limit(user).await, + None => {} } let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); let users = users_from_config(