diff --git a/src/api/mod.rs b/src/api/mod.rs index 299d5a1..55d790f 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,5 +1,5 @@ use std::convert::Infallible; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; @@ -43,6 +43,8 @@ pub(super) struct ApiShared { pub(super) ip_tracker: Arc, pub(super) me_pool: Option>, pub(super) config_path: PathBuf, + pub(super) startup_detected_ip_v4: Option, + pub(super) startup_detected_ip_v6: Option, pub(super) mutation_lock: Arc>, pub(super) minimal_cache: Arc>>, pub(super) request_id: Arc, @@ -61,6 +63,8 @@ pub async fn serve( me_pool: Option>, config_rx: watch::Receiver>, config_path: PathBuf, + startup_detected_ip_v4: Option, + startup_detected_ip_v6: Option, ) { let listener = match TcpListener::bind(listen).await { Ok(listener) => listener, @@ -81,6 +85,8 @@ pub async fn serve( ip_tracker, me_pool, config_path, + startup_detected_ip_v4, + startup_detected_ip_v6, mutation_lock: Arc::new(Mutex::new(())), minimal_cache: Arc::new(Mutex::new(None)), request_id: Arc::new(AtomicU64::new(1)), @@ -212,7 +218,14 @@ async fn handle( } ("GET", "/v1/stats/users") | ("GET", "/v1/users") => { let revision = current_revision(&shared.config_path).await?; - let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let users = users_from_config( + &cfg, + &shared.stats, + &shared.ip_tracker, + shared.startup_detected_ip_v4, + shared.startup_detected_ip_v6, + ) + .await; Ok(success_response(StatusCode::OK, users, revision)) } ("POST", "/v1/users") => { @@ -238,7 +251,14 @@ async fn handle( { if method == Method::GET { let revision = current_revision(&shared.config_path).await?; - let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let users = users_from_config( + &cfg, + &shared.stats, + &shared.ip_tracker, + shared.startup_detected_ip_v4, + shared.startup_detected_ip_v6, + ) + .await; if let Some(user_info) = users.into_iter().find(|entry| entry.username == user) { return Ok(success_response(StatusCode::OK, user_info, revision)); diff --git a/src/api/users.rs b/src/api/users.rs index 9fc03e9..c907070 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -92,7 +92,14 @@ pub(super) async fn create_user( shared.ip_tracker.set_user_limit(&body.username, limit).await; } - let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let users = users_from_config( + &cfg, + &shared.stats, + &shared.ip_tracker, + shared.startup_detected_ip_v4, + shared.startup_detected_ip_v6, + ) + .await; let user = users .into_iter() .find(|entry| entry.username == body.username) @@ -106,7 +113,12 @@ pub(super) async fn create_user( current_connections: 0, active_unique_ips: 0, total_octets: 0, - links: build_user_links(&cfg, &secret), + links: build_user_links( + &cfg, + &secret, + shared.startup_detected_ip_v4, + shared.startup_detected_ip_v6, + ), }); Ok((CreateUserResponse { user, secret }, revision)) @@ -171,7 +183,14 @@ pub(super) async fn patch_user( if let Some(limit) = updated_limit { shared.ip_tracker.set_user_limit(user, limit).await; } - let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let users = users_from_config( + &cfg, + &shared.stats, + &shared.ip_tracker, + shared.startup_detected_ip_v4, + shared.startup_detected_ip_v6, + ) + .await; let user_info = users .into_iter() .find(|entry| entry.username == user) @@ -211,7 +230,14 @@ pub(super) async fn rotate_secret( let revision = save_config_to_disk(&shared.config_path, &cfg).await?; drop(_guard); - let users = users_from_config(&cfg, &shared.stats, &shared.ip_tracker).await; + let users = users_from_config( + &cfg, + &shared.stats, + &shared.ip_tracker, + shared.startup_detected_ip_v4, + shared.startup_detected_ip_v6, + ) + .await; let user_info = users .into_iter() .find(|entry| entry.username == user) @@ -270,6 +296,8 @@ pub(super) async fn users_from_config( cfg: &ProxyConfig, stats: &Stats, ip_tracker: &UserIpTracker, + startup_detected_ip_v4: Option, + startup_detected_ip_v6: Option, ) -> Vec { let ip_counts = ip_tracker .get_stats() @@ -287,7 +315,14 @@ pub(super) async fn users_from_config( .access .users .get(&username) - .map(|secret| build_user_links(cfg, secret)) + .map(|secret| { + build_user_links( + cfg, + secret, + startup_detected_ip_v4, + startup_detected_ip_v6, + ) + }) .unwrap_or(UserLinks { classic: Vec::new(), secure: Vec::new(), @@ -313,8 +348,13 @@ pub(super) async fn users_from_config( users } -fn build_user_links(cfg: &ProxyConfig, secret: &str) -> UserLinks { - let hosts = resolve_link_hosts(cfg); +fn build_user_links( + cfg: &ProxyConfig, + secret: &str, + startup_detected_ip_v4: Option, + startup_detected_ip_v6: Option, +) -> UserLinks { + let hosts = resolve_link_hosts(cfg, startup_detected_ip_v4, startup_detected_ip_v6); let port = cfg.general.links.public_port.unwrap_or(cfg.server.port); let tls_domains = resolve_tls_domains(cfg); @@ -353,7 +393,11 @@ fn build_user_links(cfg: &ProxyConfig, secret: &str) -> UserLinks { } } -fn resolve_link_hosts(cfg: &ProxyConfig) -> Vec { +fn resolve_link_hosts( + cfg: &ProxyConfig, + startup_detected_ip_v4: Option, + startup_detected_ip_v6: Option, +) -> Vec { if let Some(host) = cfg .general .links @@ -365,6 +409,17 @@ fn resolve_link_hosts(cfg: &ProxyConfig) -> Vec { return vec![host.to_string()]; } + let mut startup_hosts = Vec::new(); + if let Some(ip) = startup_detected_ip_v4 { + push_unique_host(&mut startup_hosts, &ip.to_string()); + } + if let Some(ip) = startup_detected_ip_v6 { + push_unique_host(&mut startup_hosts, &ip.to_string()); + } + if !startup_hosts.is_empty() { + return startup_hosts; + } + let mut hosts = Vec::new(); for listener in &cfg.server.listeners { if let Some(host) = listener diff --git a/src/main.rs b/src/main.rs index c4f0e68..1845fdb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1171,6 +1171,8 @@ async fn main() -> std::result::Result<(), Box> { let me_pool_api = me_pool.clone(); let config_rx_api = config_rx.clone(); let config_path_api = std::path::PathBuf::from(&config_path); + let startup_detected_ip_v4 = detected_ip_v4; + let startup_detected_ip_v6 = detected_ip_v6; tokio::spawn(async move { api::serve( listen, @@ -1179,6 +1181,8 @@ async fn main() -> std::result::Result<(), Box> { me_pool_api, config_rx_api, config_path_api, + startup_detected_ip_v4, + startup_detected_ip_v6, ) .await; });