diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index c949104..acc64cd 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -16,6 +16,7 @@ //! | `general` | `me_pool_drain_ttl_secs` | Applied on next ME map update | //! | `general` | `me_pool_min_fresh_ratio` | Applied on next ME map update | //! | `general` | `me_reinit_drain_timeout_secs`| Applied on next ME map update | +//! | `network` | `dns_overrides` | Applied immediately | //! | `access` | All user/quota fields | Effective immediately | //! //! Fields that require re-binding sockets (`server.port`, `censorship.*`, @@ -39,6 +40,7 @@ use super::load::ProxyConfig; pub struct HotFields { pub log_level: LogLevel, pub ad_tag: Option, + pub dns_overrides: Vec, pub middle_proxy_pool_size: usize, pub desync_all_full: bool, pub update_every_secs: u64, @@ -58,6 +60,7 @@ impl HotFields { Self { log_level: cfg.general.log_level.clone(), ad_tag: cfg.general.ad_tag.clone(), + dns_overrides: cfg.network.dns_overrides.clone(), middle_proxy_pool_size: cfg.general.middle_proxy_pool_size, desync_all_full: cfg.general.desync_all_full, update_every_secs: cfg.general.effective_update_every_secs(), @@ -189,6 +192,13 @@ fn log_changes( ); } + if old_hot.dns_overrides != new_hot.dns_overrides { + info!( + "config reload: network.dns_overrides updated ({} entries)", + new_hot.dns_overrides.len() + ); + } + if old_hot.middle_proxy_pool_size != new_hot.middle_proxy_pool_size { info!( "config reload: middle_proxy_pool_size: {} → {}", @@ -354,6 +364,16 @@ fn reload_config( return; } + if old_hot.dns_overrides != new_hot.dns_overrides + && let Err(e) = crate::network::dns_overrides::install_entries(&new_hot.dns_overrides) + { + error!( + "config reload: invalid network.dns_overrides: {}; keeping old config", + e + ); + return; + } + warn_non_hot_changes(&old_cfg, &new_cfg); log_changes(&old_hot, &new_hot, &new_cfg, log_tx, detected_ip_v4, detected_ip_v6); config_tx.send(Arc::new(new_cfg)).ok(); diff --git a/src/config/load.rs b/src/config/load.rs index 4e0e104..c1bbdef 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -75,6 +75,23 @@ fn push_unique_nonempty(target: &mut Vec, value: String) { } } +fn is_valid_ad_tag(tag: &str) -> bool { + tag.len() == 32 && tag.chars().all(|ch| ch.is_ascii_hexdigit()) +} + +fn sanitize_ad_tag(ad_tag: &mut Option) { + let Some(tag) = ad_tag.as_ref() else { + return; + }; + + if !is_valid_ad_tag(tag) { + warn!( + "Invalid general.ad_tag value, expected exactly 32 hex chars; ad_tag is disabled" + ); + *ad_tag = None; + } +} + // ============= Main Config ============= #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -184,6 +201,8 @@ impl ProxyConfig { } } + sanitize_ad_tag(&mut config.general.ad_tag); + if let Some(update_every) = config.general.update_every { if update_every == 0 { return Err(ProxyError::Config( @@ -380,6 +399,7 @@ impl ProxyConfig { } validate_network_cfg(&mut config.network)?; + crate::network::dns_overrides::validate_entries(&config.network.dns_overrides)?; if config.general.use_middle_proxy && config.network.ipv6 == Some(true) { warn!("IPv6 with Middle Proxy is experimental and may cause KDF address mismatch; consider disabling IPv6 or ME"); @@ -482,14 +502,18 @@ impl ProxyConfig { if let Some(tag) = &self.general.ad_tag { let zeros = "00000000000000000000000000000000"; + if !is_valid_ad_tag(tag) { + return Err(ProxyError::Config( + "general.ad_tag must be exactly 32 hex characters".to_string(), + )); + } if tag == zeros { warn!("ad_tag is all zeros; register a valid proxy tag via @MTProxybot to enable sponsored channel"); } - if tag.len() != 32 || tag.chars().any(|c| !c.is_ascii_hexdigit()) { - warn!("ad_tag is not a 32-char hex string; ensure you use value issued by @MTProxybot"); - } } + crate::network::dns_overrides::validate_entries(&self.network.dns_overrides)?; + Ok(()) } } @@ -509,6 +533,7 @@ mod tests { let cfg: ProxyConfig = toml::from_str(toml).unwrap(); assert_eq!(cfg.network.ipv6, default_network_ipv6()); + assert_eq!(cfg.network.stun_use, default_true()); assert_eq!(cfg.network.stun_tcp_fallback, default_stun_tcp_fallback()); assert_eq!( cfg.general.middle_proxy_warm_standby, @@ -532,6 +557,7 @@ mod tests { fn impl_defaults_are_sourced_from_default_helpers() { let network = NetworkConfig::default(); assert_eq!(network.ipv6, default_network_ipv6()); + assert_eq!(network.stun_use, default_true()); assert_eq!(network.stun_tcp_fallback, default_stun_tcp_fallback()); let general = GeneralConfig::default(); @@ -934,4 +960,87 @@ mod tests { assert_eq!(cfg.general.me_reinit_drain_timeout_secs, 90); let _ = std::fs::remove_file(path); } + + #[test] + fn invalid_ad_tag_is_disabled_during_load() { + let toml = r#" + [general] + ad_tag = "not_hex" + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_invalid_ad_tag_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert!(cfg.general.ad_tag.is_none()); + let _ = std::fs::remove_file(path); + } + + #[test] + fn valid_ad_tag_is_preserved_during_load() { + let toml = r#" + [general] + ad_tag = "00112233445566778899aabbccddeeff" + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_valid_ad_tag_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.general.ad_tag.as_deref(), + Some("00112233445566778899aabbccddeeff") + ); + let _ = std::fs::remove_file(path); + } + + #[test] + fn invalid_dns_override_is_rejected() { + let toml = r#" + [network] + dns_overrides = ["example.com:443:2001:db8::10"] + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_invalid_dns_override_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("must be bracketed")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn valid_dns_override_is_accepted() { + let toml = r#" + [network] + dns_overrides = ["example.com:443:127.0.0.1", "example.net:443:[2001:db8::10]"] + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_valid_dns_override_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!(cfg.network.dns_overrides.len(), 2); + let _ = std::fs::remove_file(path); + } } diff --git a/src/config/types.rs b/src/config/types.rs index 68086be..7d9f13a 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -97,6 +97,11 @@ pub struct NetworkConfig { #[serde(default)] pub multipath: bool, + /// Global switch for STUN probing. + /// When false, STUN is fully disabled and only non-STUN detection remains. + #[serde(default = "default_true")] + pub stun_use: bool, + /// STUN servers list for public IP discovery. #[serde(default = "default_stun_servers")] pub stun_servers: Vec, @@ -112,6 +117,11 @@ pub struct NetworkConfig { /// Cache file path for detected public IP. #[serde(default = "default_cache_public_ip_path")] pub cache_public_ip_path: String, + + /// Runtime DNS overrides in `host:port:ip` format. + /// IPv6 IP values must be bracketed: `[2001:db8::1]`. + #[serde(default)] + pub dns_overrides: Vec, } impl Default for NetworkConfig { @@ -121,10 +131,12 @@ impl Default for NetworkConfig { ipv6: default_network_ipv6(), prefer: default_prefer_4(), multipath: false, + stun_use: default_true(), stun_servers: default_stun_servers(), stun_tcp_fallback: default_stun_tcp_fallback(), http_ip_detect_urls: default_http_ip_detect_urls(), cache_public_ip_path: default_cache_public_ip_path(), + dns_overrides: Vec::new(), } } } diff --git a/src/main.rs b/src/main.rs index 95f7e5a..7389117 100644 --- a/src/main.rs +++ b/src/main.rs @@ -193,6 +193,11 @@ async fn main() -> std::result::Result<(), Box> { std::process::exit(1); } + if let Err(e) = crate::network::dns_overrides::install_entries(&config.network.dns_overrides) { + eprintln!("[telemt] Invalid network.dns_overrides: {}", e); + std::process::exit(1); + } + let has_rust_log = std::env::var("RUST_LOG").is_ok(); let effective_log_level = if cli_silent { LogLevel::Silent @@ -403,6 +408,12 @@ async fn main() -> std::result::Result<(), Box> { if !config.access.user_max_unique_ips.is_empty() { info!("IP limits configured for {} users", config.access.user_max_unique_ips.len()); } + if !config.network.dns_overrides.is_empty() { + info!( + "Runtime DNS overrides configured: {} entries", + config.network.dns_overrides.len() + ); + } // Connection concurrency limit let max_connections = Arc::new(Semaphore::new(10_000)); @@ -417,14 +428,17 @@ async fn main() -> std::result::Result<(), Box> { // ===================================================================== let me_pool: Option> = if use_middle_proxy { info!("=== Middle Proxy Mode ==="); + let me_nat_probe = config.general.middle_proxy_nat_probe && config.network.stun_use; + if config.general.middle_proxy_nat_probe && !config.network.stun_use { + info!("Middle-proxy STUN probing disabled by network.stun_use=false"); + } // ad_tag (proxy_tag) for advertising - let proxy_tag = config.general.ad_tag.as_ref().map(|tag| { - hex::decode(tag).unwrap_or_else(|_| { - warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty"); - Vec::new() - }) - }); + let proxy_tag = config + .general + .ad_tag + .as_ref() + .map(|tag| hex::decode(tag).expect("general.ad_tag must be validated before startup")); // ============================================================= // CRITICAL: Download Telegram proxy-secret (NOT user secret!) @@ -484,7 +498,7 @@ async fn main() -> std::result::Result<(), Box> { proxy_tag, proxy_secret, config.general.middle_proxy_nat_ip, - config.general.middle_proxy_nat_probe, + me_nat_probe, None, config.network.stun_servers.clone(), config.general.stun_nat_probe_concurrency, @@ -1037,9 +1051,18 @@ async fn main() -> std::result::Result<(), Box> { let stats = stats.clone(); let beobachten = beobachten.clone(); let config_rx_metrics = config_rx.clone(); + let ip_tracker_metrics = ip_tracker.clone(); let whitelist = config.server.metrics_whitelist.clone(); tokio::spawn(async move { - metrics::serve(port, stats, beobachten, config_rx_metrics, whitelist).await; + metrics::serve( + port, + stats, + beobachten, + ip_tracker_metrics, + config_rx_metrics, + whitelist, + ) + .await; }); } diff --git a/src/metrics.rs b/src/metrics.rs index 08abb2d..63b337b 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1,4 +1,5 @@ use std::convert::Infallible; +use std::collections::{BTreeSet, HashMap}; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -13,6 +14,7 @@ use tokio::net::TcpListener; use tracing::{info, warn, debug}; use crate::config::ProxyConfig; +use crate::ip_tracker::UserIpTracker; use crate::stats::beobachten::BeobachtenStore; use crate::stats::Stats; @@ -20,6 +22,7 @@ pub async fn serve( port: u16, stats: Arc, beobachten: Arc, + ip_tracker: Arc, config_rx: tokio::sync::watch::Receiver>, whitelist: Vec, ) { @@ -49,13 +52,15 @@ pub async fn serve( let stats = stats.clone(); let beobachten = beobachten.clone(); + let ip_tracker = ip_tracker.clone(); let config_rx_conn = config_rx.clone(); tokio::spawn(async move { let svc = service_fn(move |req| { let stats = stats.clone(); let beobachten = beobachten.clone(); + let ip_tracker = ip_tracker.clone(); let config = config_rx_conn.borrow().clone(); - async move { handle(req, &stats, &beobachten, &config) } + async move { handle(req, &stats, &beobachten, &ip_tracker, &config).await } }); if let Err(e) = http1::Builder::new() .serve_connection(hyper_util::rt::TokioIo::new(stream), svc) @@ -67,14 +72,15 @@ pub async fn serve( } } -fn handle( +async fn handle( req: Request, stats: &Stats, beobachten: &BeobachtenStore, + ip_tracker: &UserIpTracker, config: &ProxyConfig, ) -> Result>, Infallible> { if req.uri().path() == "/metrics" { - let body = render_metrics(stats); + let body = render_metrics(stats, config, ip_tracker).await; let resp = Response::builder() .status(StatusCode::OK) .header("content-type", "text/plain; version=0.0.4; charset=utf-8") @@ -109,7 +115,7 @@ fn render_beobachten(beobachten: &BeobachtenStore, config: &ProxyConfig) -> Stri beobachten.snapshot_text(ttl) } -fn render_metrics(stats: &Stats) -> String { +async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIpTracker) -> String { use std::fmt::Write; let mut out = String::with_capacity(4096); @@ -349,6 +355,41 @@ fn render_metrics(stats: &Stats) -> String { let _ = writeln!(out, "telemt_user_msgs_to_client{{user=\"{}\"}} {}", user, s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed)); } + let ip_stats = ip_tracker.get_stats().await; + let ip_counts: HashMap = ip_stats + .into_iter() + .map(|(user, count, _)| (user, count)) + .collect(); + + let mut unique_users = BTreeSet::new(); + unique_users.extend(config.access.user_max_unique_ips.keys().cloned()); + unique_users.extend(ip_counts.keys().cloned()); + + let _ = writeln!(out, "# HELP telemt_user_unique_ips_current Per-user current number of unique active IPs"); + let _ = writeln!(out, "# TYPE telemt_user_unique_ips_current gauge"); + let _ = writeln!(out, "# HELP telemt_user_unique_ips_limit Per-user configured unique IP limit (0 means unlimited)"); + let _ = writeln!(out, "# TYPE telemt_user_unique_ips_limit gauge"); + let _ = writeln!(out, "# HELP telemt_user_unique_ips_utilization Per-user unique IP usage ratio (0 for unlimited)"); + let _ = writeln!(out, "# TYPE telemt_user_unique_ips_utilization gauge"); + + for user in unique_users { + let current = ip_counts.get(&user).copied().unwrap_or(0); + let limit = config.access.user_max_unique_ips.get(&user).copied().unwrap_or(0); + let utilization = if limit > 0 { + current as f64 / limit as f64 + } else { + 0.0 + }; + let _ = writeln!(out, "telemt_user_unique_ips_current{{user=\"{}\"}} {}", user, current); + let _ = writeln!(out, "telemt_user_unique_ips_limit{{user=\"{}\"}} {}", user, limit); + let _ = writeln!( + out, + "telemt_user_unique_ips_utilization{{user=\"{}\"}} {:.6}", + user, + utilization + ); + } + out } @@ -358,9 +399,16 @@ mod tests { use std::net::IpAddr; use http_body_util::BodyExt; - #[test] - fn test_render_metrics_format() { + #[tokio::test] + async fn test_render_metrics_format() { let stats = Arc::new(Stats::new()); + let tracker = UserIpTracker::new(); + let mut config = ProxyConfig::default(); + config + .access + .user_max_unique_ips + .insert("alice".to_string(), 4); + stats.increment_connects_all(); stats.increment_connects_all(); stats.increment_connects_bad(); @@ -372,8 +420,12 @@ mod tests { stats.increment_user_msgs_from("alice"); stats.increment_user_msgs_to("alice"); stats.increment_user_msgs_to("alice"); + tracker + .check_and_add("alice", "203.0.113.10".parse().unwrap()) + .await + .unwrap(); - let output = render_metrics(&stats); + let output = render_metrics(&stats, &config, &tracker).await; assert!(output.contains("telemt_connections_total 2")); assert!(output.contains("telemt_connections_bad_total 1")); @@ -384,22 +436,29 @@ mod tests { assert!(output.contains("telemt_user_octets_to_client{user=\"alice\"} 2048")); assert!(output.contains("telemt_user_msgs_from_client{user=\"alice\"} 1")); assert!(output.contains("telemt_user_msgs_to_client{user=\"alice\"} 2")); + assert!(output.contains("telemt_user_unique_ips_current{user=\"alice\"} 1")); + assert!(output.contains("telemt_user_unique_ips_limit{user=\"alice\"} 4")); + assert!(output.contains("telemt_user_unique_ips_utilization{user=\"alice\"} 0.250000")); } - #[test] - fn test_render_empty_stats() { + #[tokio::test] + async fn test_render_empty_stats() { let stats = Stats::new(); - let output = render_metrics(&stats); + let tracker = UserIpTracker::new(); + let config = ProxyConfig::default(); + let output = render_metrics(&stats, &config, &tracker).await; assert!(output.contains("telemt_connections_total 0")); assert!(output.contains("telemt_connections_bad_total 0")); assert!(output.contains("telemt_handshake_timeouts_total 0")); assert!(!output.contains("user=")); } - #[test] - fn test_render_has_type_annotations() { + #[tokio::test] + async fn test_render_has_type_annotations() { let stats = Stats::new(); - let output = render_metrics(&stats); + let tracker = UserIpTracker::new(); + let config = ProxyConfig::default(); + let output = render_metrics(&stats, &config, &tracker).await; assert!(output.contains("# TYPE telemt_uptime_seconds gauge")); assert!(output.contains("# TYPE telemt_connections_total counter")); assert!(output.contains("# TYPE telemt_connections_bad_total counter")); @@ -408,12 +467,16 @@ mod tests { assert!(output.contains( "# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge" )); + assert!(output.contains("# TYPE telemt_user_unique_ips_current gauge")); + assert!(output.contains("# TYPE telemt_user_unique_ips_limit gauge")); + assert!(output.contains("# TYPE telemt_user_unique_ips_utilization gauge")); } #[tokio::test] async fn test_endpoint_integration() { let stats = Arc::new(Stats::new()); let beobachten = Arc::new(BeobachtenStore::new()); + let tracker = UserIpTracker::new(); let mut config = ProxyConfig::default(); stats.increment_connects_all(); stats.increment_connects_all(); @@ -423,7 +486,7 @@ mod tests { .uri("/metrics") .body(()) .unwrap(); - let resp = handle(req, &stats, &beobachten, &config).unwrap(); + let resp = handle(req, &stats, &beobachten, &tracker, &config).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); let body = resp.into_body().collect().await.unwrap().to_bytes(); assert!(std::str::from_utf8(body.as_ref()).unwrap().contains("telemt_connections_total 3")); @@ -439,7 +502,9 @@ mod tests { .uri("/beobachten") .body(()) .unwrap(); - let resp_beob = handle(req_beob, &stats, &beobachten, &config).unwrap(); + let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config) + .await + .unwrap(); assert_eq!(resp_beob.status(), StatusCode::OK); let body_beob = resp_beob.into_body().collect().await.unwrap().to_bytes(); let beob_text = std::str::from_utf8(body_beob.as_ref()).unwrap(); @@ -450,7 +515,9 @@ mod tests { .uri("/other") .body(()) .unwrap(); - let resp404 = handle(req404, &stats, &beobachten, &config).unwrap(); + let resp404 = handle(req404, &stats, &beobachten, &tracker, &config) + .await + .unwrap(); assert_eq!(resp404.status(), StatusCode::NOT_FOUND); } } diff --git a/src/network/dns_overrides.rs b/src/network/dns_overrides.rs new file mode 100644 index 0000000..447863a --- /dev/null +++ b/src/network/dns_overrides.rs @@ -0,0 +1,197 @@ +//! Runtime DNS overrides for `host:port` targets. + +use std::collections::HashMap; +use std::net::{IpAddr, Ipv6Addr, SocketAddr}; +use std::sync::{OnceLock, RwLock}; + +use crate::error::{ProxyError, Result}; + +type OverrideMap = HashMap<(String, u16), IpAddr>; + +static DNS_OVERRIDES: OnceLock> = OnceLock::new(); + +fn overrides_store() -> &'static RwLock { + DNS_OVERRIDES.get_or_init(|| RwLock::new(HashMap::new())) +} + +fn parse_ip_spec(ip_spec: &str) -> Result { + if ip_spec.starts_with('[') && ip_spec.ends_with(']') { + let inner = &ip_spec[1..ip_spec.len() - 1]; + let ipv6 = inner.parse::().map_err(|_| { + ProxyError::Config(format!( + "network.dns_overrides IPv6 override is invalid: '{ip_spec}'" + )) + })?; + return Ok(IpAddr::V6(ipv6)); + } + + let ip = ip_spec.parse::().map_err(|_| { + ProxyError::Config(format!( + "network.dns_overrides IP is invalid: '{ip_spec}'" + )) + })?; + if matches!(ip, IpAddr::V6(_)) { + return Err(ProxyError::Config(format!( + "network.dns_overrides IPv6 must be bracketed: '{ip_spec}'" + ))); + } + Ok(ip) +} + +fn parse_entry(entry: &str) -> Result<((String, u16), IpAddr)> { + let trimmed = entry.trim(); + if trimmed.is_empty() { + return Err(ProxyError::Config( + "network.dns_overrides entry cannot be empty".to_string(), + )); + } + + let first_sep = trimmed.find(':').ok_or_else(|| { + ProxyError::Config(format!( + "network.dns_overrides entry must use host:port:ip format: '{trimmed}'" + )) + })?; + let second_sep = trimmed[first_sep + 1..] + .find(':') + .map(|idx| first_sep + 1 + idx) + .ok_or_else(|| { + ProxyError::Config(format!( + "network.dns_overrides entry must use host:port:ip format: '{trimmed}'" + )) + })?; + + let host = trimmed[..first_sep].trim(); + let port_str = trimmed[first_sep + 1..second_sep].trim(); + let ip_str = trimmed[second_sep + 1..].trim(); + + if host.is_empty() { + return Err(ProxyError::Config(format!( + "network.dns_overrides host cannot be empty: '{trimmed}'" + ))); + } + if host.contains(':') { + return Err(ProxyError::Config(format!( + "network.dns_overrides host must be a domain name without ':' in this format: '{trimmed}'" + ))); + } + + let port = port_str.parse::().map_err(|_| { + ProxyError::Config(format!( + "network.dns_overrides port is invalid: '{trimmed}'" + )) + })?; + let ip = parse_ip_spec(ip_str)?; + + Ok(((host.to_ascii_lowercase(), port), ip)) +} + +fn parse_entries(entries: &[String]) -> Result { + let mut parsed = HashMap::new(); + for entry in entries { + let (key, ip) = parse_entry(entry)?; + parsed.insert(key, ip); + } + Ok(parsed) +} + +/// Validate `network.dns_overrides` entries without updating runtime state. +pub fn validate_entries(entries: &[String]) -> Result<()> { + let _ = parse_entries(entries)?; + Ok(()) +} + +/// Replace runtime DNS overrides with a new validated snapshot. +pub fn install_entries(entries: &[String]) -> Result<()> { + let parsed = parse_entries(entries)?; + let mut guard = overrides_store() + .write() + .map_err(|_| ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string()))?; + *guard = parsed; + Ok(()) +} + +/// Resolve a hostname override for `(host, port)` if present. +pub fn resolve(host: &str, port: u16) -> Option { + let key = (host.to_ascii_lowercase(), port); + overrides_store() + .read() + .ok() + .and_then(|guard| guard.get(&key).copied()) +} + +/// Resolve a hostname override and construct a socket address when present. +pub fn resolve_socket_addr(host: &str, port: u16) -> Option { + resolve(host, port).map(|ip| SocketAddr::new(ip, port)) +} + +/// Parse a runtime endpoint in `host:port` format. +/// +/// Supports: +/// - `example.com:443` +/// - `[2001:db8::1]:443` +pub fn split_host_port(endpoint: &str) -> Option<(String, u16)> { + if endpoint.starts_with('[') { + let bracket_end = endpoint.find(']')?; + if endpoint.as_bytes().get(bracket_end + 1) != Some(&b':') { + return None; + } + let host = endpoint[1..bracket_end].trim(); + let port = endpoint[bracket_end + 2..].trim().parse::().ok()?; + if host.is_empty() { + return None; + } + return Some((host.to_ascii_lowercase(), port)); + } + + let split_idx = endpoint.rfind(':')?; + let host = endpoint[..split_idx].trim(); + let port = endpoint[split_idx + 1..].trim().parse::().ok()?; + if host.is_empty() || host.contains(':') { + return None; + } + + Some((host.to_ascii_lowercase(), port)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn validate_accepts_ipv4_and_bracketed_ipv6() { + let entries = vec![ + "example.com:443:127.0.0.1".to_string(), + "example.net:8443:[2001:db8::10]".to_string(), + ]; + assert!(validate_entries(&entries).is_ok()); + } + + #[test] + fn validate_rejects_unbracketed_ipv6() { + let entries = vec!["example.net:443:2001:db8::10".to_string()]; + let err = validate_entries(&entries).unwrap_err().to_string(); + assert!(err.contains("must be bracketed")); + } + + #[test] + fn install_and_resolve_are_case_insensitive_for_host() { + let entries = vec!["MyPetrovich.ru:8443:127.0.0.1".to_string()]; + install_entries(&entries).unwrap(); + + let resolved = resolve("mypetrovich.ru", 8443); + assert_eq!(resolved, Some("127.0.0.1".parse().unwrap())); + } + + #[test] + fn split_host_port_parses_supported_shapes() { + assert_eq!( + split_host_port("example.com:443"), + Some(("example.com".to_string(), 443)) + ); + assert_eq!( + split_host_port("[2001:db8::1]:443"), + Some(("2001:db8::1".to_string(), 443)) + ); + assert_eq!(split_host_port("2001:db8::1:443"), None); + } +} diff --git a/src/network/mod.rs b/src/network/mod.rs index 78a1040..e57622d 100644 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -1,3 +1,4 @@ +pub mod dns_overrides; pub mod probe; pub mod stun; diff --git a/src/network/probe.rs b/src/network/probe.rs index 6e84682..2ceeb2c 100644 --- a/src/network/probe.rs +++ b/src/network/probe.rs @@ -68,7 +68,7 @@ pub async fn run_probe( probe.ipv4_is_bogon = probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false); probe.ipv6_is_bogon = probe.detected_ipv6.map(is_bogon_v6).unwrap_or(false); - let stun_res = if nat_probe { + let stun_res = if nat_probe && config.stun_use { let servers = collect_stun_servers(config); if servers.is_empty() { warn!("STUN probe is enabled but network.stun_servers is empty"); @@ -80,6 +80,9 @@ pub async fn run_probe( ) .await } + } else if nat_probe { + info!("STUN probe is disabled by network.stun_use=false"); + DualStunResult::default() } else { DualStunResult::default() }; diff --git a/src/network/stun.rs b/src/network/stun.rs index 5bda495..bb5a873 100644 --- a/src/network/stun.rs +++ b/src/network/stun.rs @@ -7,6 +7,7 @@ use tokio::net::{lookup_host, UdpSocket}; use tokio::time::{timeout, Duration, sleep}; use crate::error::{ProxyError, Result}; +use crate::network::dns_overrides::{resolve, split_host_port}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum IpFamily { @@ -198,6 +199,16 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result Some(addr), + _ => None, + }); + } + let mut addrs = lookup_host(stun_addr) .await .map_err(|e| ProxyError::Proxy(format!("STUN resolve failed: {e}")))?; diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index d12cf41..8f19b40 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -10,6 +10,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::time::timeout; use tracing::debug; use crate::config::ProxyConfig; +use crate::network::dns_overrides::resolve_socket_addr; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; @@ -115,8 +116,10 @@ where "Forwarding bad client to mask host" ); - // Connect to mask host - let mask_addr = format!("{}:{}", mask_host, mask_port); + // Apply runtime DNS override for mask target when configured. + let mask_addr = resolve_socket_addr(mask_host, mask_port) + .map(|addr| addr.to_string()) + .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port)); let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { Ok(Ok(stream)) => { diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 561d4cc..ba80332 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -18,6 +18,7 @@ use x509_parser::prelude::FromDer; use x509_parser::certificate::X509Certificate; use crate::crypto::SecureRandom; +use crate::network::dns_overrides::resolve_socket_addr; use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_HANDSHAKE}; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; use crate::tls_front::types::{ @@ -333,6 +334,17 @@ fn u24_bytes(value: usize) -> Option<[u8; 3]> { ]) } +async fn connect_with_dns_override( + host: &str, + port: u16, + connect_timeout: Duration, +) -> Result { + if let Some(addr) = resolve_socket_addr(host, port) { + return Ok(timeout(connect_timeout, TcpStream::connect(addr)).await??); + } + Ok(timeout(connect_timeout, TcpStream::connect((host, port))).await??) +} + fn encode_tls13_certificate_message(cert_chain_der: &[Vec]) -> Option> { if cert_chain_der.is_empty() { return None; @@ -369,8 +381,7 @@ async fn fetch_via_raw_tls( connect_timeout: Duration, proxy_protocol: u8, ) -> Result { - let addr = format!("{host}:{port}"); - let mut stream = timeout(connect_timeout, TcpStream::connect(addr)).await??; + let mut stream = connect_with_dns_override(host, port, connect_timeout).await?; let rng = SecureRandom::new(); let client_hello = build_client_hello(sni, &rng); @@ -437,24 +448,31 @@ async fn fetch_via_rustls( ) -> Result { // rustls handshake path for certificate and basic negotiated metadata. let mut stream = if let Some(manager) = upstream { - // Resolve host to SocketAddr - if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await { + if let Some(addr) = resolve_socket_addr(host, port) { + match manager.connect(addr, None, None).await { + Ok(s) => s, + Err(e) => { + warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect"); + connect_with_dns_override(host, port, connect_timeout).await? + } + } + } else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await { if let Some(addr) = addrs.find(|a| a.is_ipv4()) { match manager.connect(addr, None, None).await { Ok(s) => s, Err(e) => { warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect"); - timeout(connect_timeout, TcpStream::connect((host, port))).await?? + connect_with_dns_override(host, port, connect_timeout).await? } } } else { - timeout(connect_timeout, TcpStream::connect((host, port))).await?? + connect_with_dns_override(host, port, connect_timeout).await? } } else { - timeout(connect_timeout, TcpStream::connect((host, port))).await?? + connect_with_dns_override(host, port, connect_timeout).await? } } else { - timeout(connect_timeout, TcpStream::connect((host, port))).await?? + connect_with_dns_override(host, port, connect_timeout).await? }; if proxy_protocol > 0 { diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index e2198a8..a442597 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -17,6 +17,7 @@ use tracing::{debug, warn, info, trace}; use crate::config::{UpstreamConfig, UpstreamType}; use crate::error::{Result, ProxyError}; +use crate::network::dns_overrides::{resolve_socket_addr, split_host_port}; use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT}; use crate::transport::socket::{create_outgoing_socket_bound, resolve_interface_ip}; use crate::transport::socks::{connect_socks4, connect_socks5}; @@ -209,6 +210,31 @@ impl UpstreamManager { None } + async fn connect_hostname_with_dns_override( + address: &str, + connect_timeout: Duration, + ) -> Result { + if let Some((host, port)) = split_host_port(address) + && let Some(addr) = resolve_socket_addr(&host, port) + { + return match tokio::time::timeout(connect_timeout, TcpStream::connect(addr)).await { + Ok(Ok(stream)) => Ok(stream), + Ok(Err(e)) => Err(ProxyError::Io(e)), + Err(_) => Err(ProxyError::ConnectionTimeout { + addr: addr.to_string(), + }), + }; + } + + match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await { + Ok(Ok(stream)) => Ok(stream), + Ok(Err(e)) => Err(ProxyError::Io(e)), + Err(_) => Err(ProxyError::ConnectionTimeout { + addr: address.to_string(), + }), + } + } + /// Select upstream using latency-weighted random selection. async fn select_upstream(&self, dc_idx: Option, scope: Option<&str>) -> Option { let upstreams = self.upstreams.read().await; @@ -433,15 +459,7 @@ impl UpstreamManager { if interface.is_some() { warn!("SOCKS4 interface binding is not supported for hostname addresses, ignoring"); } - match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await { - Ok(Ok(stream)) => stream, - Ok(Err(e)) => return Err(ProxyError::Io(e)), - Err(_) => { - return Err(ProxyError::ConnectionTimeout { - addr: address.clone(), - }); - } - } + Self::connect_hostname_with_dns_override(address, connect_timeout).await? }; // replace socks user_id with config.selected_scope, if set @@ -503,15 +521,7 @@ impl UpstreamManager { if interface.is_some() { warn!("SOCKS5 interface binding is not supported for hostname addresses, ignoring"); } - match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await { - Ok(Ok(stream)) => stream, - Ok(Err(e)) => return Err(ProxyError::Io(e)), - Err(_) => { - return Err(ProxyError::ConnectionTimeout { - addr: address.clone(), - }); - } - } + Self::connect_hostname_with_dns_override(address, connect_timeout).await? }; debug!(config = ?config, "Socks5 connection");