STUN switch + Ad-tag fixes + DNS-overrides

This commit is contained in:
Alexey 2026-02-27 15:59:27 +03:00
parent eba158ff8b
commit ac064fe773
No known key found for this signature in database
12 changed files with 530 additions and 56 deletions

View File

@ -16,6 +16,7 @@
//! | `general` | `me_pool_drain_ttl_secs` | Applied on next ME map update |
//! | `general` | `me_pool_min_fresh_ratio` | Applied on next ME map update |
//! | `general` | `me_reinit_drain_timeout_secs`| Applied on next ME map update |
//! | `network` | `dns_overrides` | Applied immediately |
//! | `access` | All user/quota fields | Effective immediately |
//!
//! Fields that require re-binding sockets (`server.port`, `censorship.*`,
@ -39,6 +40,7 @@ use super::load::ProxyConfig;
pub struct HotFields {
pub log_level: LogLevel,
pub ad_tag: Option<String>,
pub dns_overrides: Vec<String>,
pub middle_proxy_pool_size: usize,
pub desync_all_full: bool,
pub update_every_secs: u64,
@ -58,6 +60,7 @@ impl HotFields {
Self {
log_level: cfg.general.log_level.clone(),
ad_tag: cfg.general.ad_tag.clone(),
dns_overrides: cfg.network.dns_overrides.clone(),
middle_proxy_pool_size: cfg.general.middle_proxy_pool_size,
desync_all_full: cfg.general.desync_all_full,
update_every_secs: cfg.general.effective_update_every_secs(),
@ -189,6 +192,13 @@ fn log_changes(
);
}
if old_hot.dns_overrides != new_hot.dns_overrides {
info!(
"config reload: network.dns_overrides updated ({} entries)",
new_hot.dns_overrides.len()
);
}
if old_hot.middle_proxy_pool_size != new_hot.middle_proxy_pool_size {
info!(
"config reload: middle_proxy_pool_size: {} → {}",
@ -354,6 +364,16 @@ fn reload_config(
return;
}
if old_hot.dns_overrides != new_hot.dns_overrides
&& let Err(e) = crate::network::dns_overrides::install_entries(&new_hot.dns_overrides)
{
error!(
"config reload: invalid network.dns_overrides: {}; keeping old config",
e
);
return;
}
warn_non_hot_changes(&old_cfg, &new_cfg);
log_changes(&old_hot, &new_hot, &new_cfg, log_tx, detected_ip_v4, detected_ip_v6);
config_tx.send(Arc::new(new_cfg)).ok();

View File

@ -75,6 +75,23 @@ fn push_unique_nonempty(target: &mut Vec<String>, value: String) {
}
}
fn is_valid_ad_tag(tag: &str) -> bool {
tag.len() == 32 && tag.chars().all(|ch| ch.is_ascii_hexdigit())
}
fn sanitize_ad_tag(ad_tag: &mut Option<String>) {
let Some(tag) = ad_tag.as_ref() else {
return;
};
if !is_valid_ad_tag(tag) {
warn!(
"Invalid general.ad_tag value, expected exactly 32 hex chars; ad_tag is disabled"
);
*ad_tag = None;
}
}
// ============= Main Config =============
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
@ -184,6 +201,8 @@ impl ProxyConfig {
}
}
sanitize_ad_tag(&mut config.general.ad_tag);
if let Some(update_every) = config.general.update_every {
if update_every == 0 {
return Err(ProxyError::Config(
@ -380,6 +399,7 @@ impl ProxyConfig {
}
validate_network_cfg(&mut config.network)?;
crate::network::dns_overrides::validate_entries(&config.network.dns_overrides)?;
if config.general.use_middle_proxy && config.network.ipv6 == Some(true) {
warn!("IPv6 with Middle Proxy is experimental and may cause KDF address mismatch; consider disabling IPv6 or ME");
@ -482,14 +502,18 @@ impl ProxyConfig {
if let Some(tag) = &self.general.ad_tag {
let zeros = "00000000000000000000000000000000";
if !is_valid_ad_tag(tag) {
return Err(ProxyError::Config(
"general.ad_tag must be exactly 32 hex characters".to_string(),
));
}
if tag == zeros {
warn!("ad_tag is all zeros; register a valid proxy tag via @MTProxybot to enable sponsored channel");
}
if tag.len() != 32 || tag.chars().any(|c| !c.is_ascii_hexdigit()) {
warn!("ad_tag is not a 32-char hex string; ensure you use value issued by @MTProxybot");
}
}
crate::network::dns_overrides::validate_entries(&self.network.dns_overrides)?;
Ok(())
}
}
@ -509,6 +533,7 @@ mod tests {
let cfg: ProxyConfig = toml::from_str(toml).unwrap();
assert_eq!(cfg.network.ipv6, default_network_ipv6());
assert_eq!(cfg.network.stun_use, default_true());
assert_eq!(cfg.network.stun_tcp_fallback, default_stun_tcp_fallback());
assert_eq!(
cfg.general.middle_proxy_warm_standby,
@ -532,6 +557,7 @@ mod tests {
fn impl_defaults_are_sourced_from_default_helpers() {
let network = NetworkConfig::default();
assert_eq!(network.ipv6, default_network_ipv6());
assert_eq!(network.stun_use, default_true());
assert_eq!(network.stun_tcp_fallback, default_stun_tcp_fallback());
let general = GeneralConfig::default();
@ -934,4 +960,87 @@ mod tests {
assert_eq!(cfg.general.me_reinit_drain_timeout_secs, 90);
let _ = std::fs::remove_file(path);
}
#[test]
fn invalid_ad_tag_is_disabled_during_load() {
let toml = r#"
[general]
ad_tag = "not_hex"
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_invalid_ad_tag_test.toml");
std::fs::write(&path, toml).unwrap();
let cfg = ProxyConfig::load(&path).unwrap();
assert!(cfg.general.ad_tag.is_none());
let _ = std::fs::remove_file(path);
}
#[test]
fn valid_ad_tag_is_preserved_during_load() {
let toml = r#"
[general]
ad_tag = "00112233445566778899aabbccddeeff"
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_valid_ad_tag_test.toml");
std::fs::write(&path, toml).unwrap();
let cfg = ProxyConfig::load(&path).unwrap();
assert_eq!(
cfg.general.ad_tag.as_deref(),
Some("00112233445566778899aabbccddeeff")
);
let _ = std::fs::remove_file(path);
}
#[test]
fn invalid_dns_override_is_rejected() {
let toml = r#"
[network]
dns_overrides = ["example.com:443:2001:db8::10"]
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_invalid_dns_override_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("must be bracketed"));
let _ = std::fs::remove_file(path);
}
#[test]
fn valid_dns_override_is_accepted() {
let toml = r#"
[network]
dns_overrides = ["example.com:443:127.0.0.1", "example.net:443:[2001:db8::10]"]
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_valid_dns_override_test.toml");
std::fs::write(&path, toml).unwrap();
let cfg = ProxyConfig::load(&path).unwrap();
assert_eq!(cfg.network.dns_overrides.len(), 2);
let _ = std::fs::remove_file(path);
}
}

View File

@ -97,6 +97,11 @@ pub struct NetworkConfig {
#[serde(default)]
pub multipath: bool,
/// Global switch for STUN probing.
/// When false, STUN is fully disabled and only non-STUN detection remains.
#[serde(default = "default_true")]
pub stun_use: bool,
/// STUN servers list for public IP discovery.
#[serde(default = "default_stun_servers")]
pub stun_servers: Vec<String>,
@ -112,6 +117,11 @@ pub struct NetworkConfig {
/// Cache file path for detected public IP.
#[serde(default = "default_cache_public_ip_path")]
pub cache_public_ip_path: String,
/// Runtime DNS overrides in `host:port:ip` format.
/// IPv6 IP values must be bracketed: `[2001:db8::1]`.
#[serde(default)]
pub dns_overrides: Vec<String>,
}
impl Default for NetworkConfig {
@ -121,10 +131,12 @@ impl Default for NetworkConfig {
ipv6: default_network_ipv6(),
prefer: default_prefer_4(),
multipath: false,
stun_use: default_true(),
stun_servers: default_stun_servers(),
stun_tcp_fallback: default_stun_tcp_fallback(),
http_ip_detect_urls: default_http_ip_detect_urls(),
cache_public_ip_path: default_cache_public_ip_path(),
dns_overrides: Vec::new(),
}
}
}

View File

@ -193,6 +193,11 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
std::process::exit(1);
}
if let Err(e) = crate::network::dns_overrides::install_entries(&config.network.dns_overrides) {
eprintln!("[telemt] Invalid network.dns_overrides: {}", e);
std::process::exit(1);
}
let has_rust_log = std::env::var("RUST_LOG").is_ok();
let effective_log_level = if cli_silent {
LogLevel::Silent
@ -403,6 +408,12 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
if !config.access.user_max_unique_ips.is_empty() {
info!("IP limits configured for {} users", config.access.user_max_unique_ips.len());
}
if !config.network.dns_overrides.is_empty() {
info!(
"Runtime DNS overrides configured: {} entries",
config.network.dns_overrides.len()
);
}
// Connection concurrency limit
let max_connections = Arc::new(Semaphore::new(10_000));
@ -417,14 +428,17 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
// =====================================================================
let me_pool: Option<Arc<MePool>> = if use_middle_proxy {
info!("=== Middle Proxy Mode ===");
let me_nat_probe = config.general.middle_proxy_nat_probe && config.network.stun_use;
if config.general.middle_proxy_nat_probe && !config.network.stun_use {
info!("Middle-proxy STUN probing disabled by network.stun_use=false");
}
// ad_tag (proxy_tag) for advertising
let proxy_tag = config.general.ad_tag.as_ref().map(|tag| {
hex::decode(tag).unwrap_or_else(|_| {
warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty");
Vec::new()
})
});
let proxy_tag = config
.general
.ad_tag
.as_ref()
.map(|tag| hex::decode(tag).expect("general.ad_tag must be validated before startup"));
// =============================================================
// CRITICAL: Download Telegram proxy-secret (NOT user secret!)
@ -484,7 +498,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
proxy_tag,
proxy_secret,
config.general.middle_proxy_nat_ip,
config.general.middle_proxy_nat_probe,
me_nat_probe,
None,
config.network.stun_servers.clone(),
config.general.stun_nat_probe_concurrency,
@ -1037,9 +1051,18 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let stats = stats.clone();
let beobachten = beobachten.clone();
let config_rx_metrics = config_rx.clone();
let ip_tracker_metrics = ip_tracker.clone();
let whitelist = config.server.metrics_whitelist.clone();
tokio::spawn(async move {
metrics::serve(port, stats, beobachten, config_rx_metrics, whitelist).await;
metrics::serve(
port,
stats,
beobachten,
ip_tracker_metrics,
config_rx_metrics,
whitelist,
)
.await;
});
}

View File

@ -1,4 +1,5 @@
use std::convert::Infallible;
use std::collections::{BTreeSet, HashMap};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
@ -13,6 +14,7 @@ use tokio::net::TcpListener;
use tracing::{info, warn, debug};
use crate::config::ProxyConfig;
use crate::ip_tracker::UserIpTracker;
use crate::stats::beobachten::BeobachtenStore;
use crate::stats::Stats;
@ -20,6 +22,7 @@ pub async fn serve(
port: u16,
stats: Arc<Stats>,
beobachten: Arc<BeobachtenStore>,
ip_tracker: Arc<UserIpTracker>,
config_rx: tokio::sync::watch::Receiver<Arc<ProxyConfig>>,
whitelist: Vec<IpNetwork>,
) {
@ -49,13 +52,15 @@ pub async fn serve(
let stats = stats.clone();
let beobachten = beobachten.clone();
let ip_tracker = ip_tracker.clone();
let config_rx_conn = config_rx.clone();
tokio::spawn(async move {
let svc = service_fn(move |req| {
let stats = stats.clone();
let beobachten = beobachten.clone();
let ip_tracker = ip_tracker.clone();
let config = config_rx_conn.borrow().clone();
async move { handle(req, &stats, &beobachten, &config) }
async move { handle(req, &stats, &beobachten, &ip_tracker, &config).await }
});
if let Err(e) = http1::Builder::new()
.serve_connection(hyper_util::rt::TokioIo::new(stream), svc)
@ -67,14 +72,15 @@ pub async fn serve(
}
}
fn handle<B>(
async fn handle<B>(
req: Request<B>,
stats: &Stats,
beobachten: &BeobachtenStore,
ip_tracker: &UserIpTracker,
config: &ProxyConfig,
) -> Result<Response<Full<Bytes>>, Infallible> {
if req.uri().path() == "/metrics" {
let body = render_metrics(stats);
let body = render_metrics(stats, config, ip_tracker).await;
let resp = Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/plain; version=0.0.4; charset=utf-8")
@ -109,7 +115,7 @@ fn render_beobachten(beobachten: &BeobachtenStore, config: &ProxyConfig) -> Stri
beobachten.snapshot_text(ttl)
}
fn render_metrics(stats: &Stats) -> String {
async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIpTracker) -> String {
use std::fmt::Write;
let mut out = String::with_capacity(4096);
@ -349,6 +355,41 @@ fn render_metrics(stats: &Stats) -> String {
let _ = writeln!(out, "telemt_user_msgs_to_client{{user=\"{}\"}} {}", user, s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed));
}
let ip_stats = ip_tracker.get_stats().await;
let ip_counts: HashMap<String, usize> = ip_stats
.into_iter()
.map(|(user, count, _)| (user, count))
.collect();
let mut unique_users = BTreeSet::new();
unique_users.extend(config.access.user_max_unique_ips.keys().cloned());
unique_users.extend(ip_counts.keys().cloned());
let _ = writeln!(out, "# HELP telemt_user_unique_ips_current Per-user current number of unique active IPs");
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_current gauge");
let _ = writeln!(out, "# HELP telemt_user_unique_ips_limit Per-user configured unique IP limit (0 means unlimited)");
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_limit gauge");
let _ = writeln!(out, "# HELP telemt_user_unique_ips_utilization Per-user unique IP usage ratio (0 for unlimited)");
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_utilization gauge");
for user in unique_users {
let current = ip_counts.get(&user).copied().unwrap_or(0);
let limit = config.access.user_max_unique_ips.get(&user).copied().unwrap_or(0);
let utilization = if limit > 0 {
current as f64 / limit as f64
} else {
0.0
};
let _ = writeln!(out, "telemt_user_unique_ips_current{{user=\"{}\"}} {}", user, current);
let _ = writeln!(out, "telemt_user_unique_ips_limit{{user=\"{}\"}} {}", user, limit);
let _ = writeln!(
out,
"telemt_user_unique_ips_utilization{{user=\"{}\"}} {:.6}",
user,
utilization
);
}
out
}
@ -358,9 +399,16 @@ mod tests {
use std::net::IpAddr;
use http_body_util::BodyExt;
#[test]
fn test_render_metrics_format() {
#[tokio::test]
async fn test_render_metrics_format() {
let stats = Arc::new(Stats::new());
let tracker = UserIpTracker::new();
let mut config = ProxyConfig::default();
config
.access
.user_max_unique_ips
.insert("alice".to_string(), 4);
stats.increment_connects_all();
stats.increment_connects_all();
stats.increment_connects_bad();
@ -372,8 +420,12 @@ mod tests {
stats.increment_user_msgs_from("alice");
stats.increment_user_msgs_to("alice");
stats.increment_user_msgs_to("alice");
tracker
.check_and_add("alice", "203.0.113.10".parse().unwrap())
.await
.unwrap();
let output = render_metrics(&stats);
let output = render_metrics(&stats, &config, &tracker).await;
assert!(output.contains("telemt_connections_total 2"));
assert!(output.contains("telemt_connections_bad_total 1"));
@ -384,22 +436,29 @@ mod tests {
assert!(output.contains("telemt_user_octets_to_client{user=\"alice\"} 2048"));
assert!(output.contains("telemt_user_msgs_from_client{user=\"alice\"} 1"));
assert!(output.contains("telemt_user_msgs_to_client{user=\"alice\"} 2"));
assert!(output.contains("telemt_user_unique_ips_current{user=\"alice\"} 1"));
assert!(output.contains("telemt_user_unique_ips_limit{user=\"alice\"} 4"));
assert!(output.contains("telemt_user_unique_ips_utilization{user=\"alice\"} 0.250000"));
}
#[test]
fn test_render_empty_stats() {
#[tokio::test]
async fn test_render_empty_stats() {
let stats = Stats::new();
let output = render_metrics(&stats);
let tracker = UserIpTracker::new();
let config = ProxyConfig::default();
let output = render_metrics(&stats, &config, &tracker).await;
assert!(output.contains("telemt_connections_total 0"));
assert!(output.contains("telemt_connections_bad_total 0"));
assert!(output.contains("telemt_handshake_timeouts_total 0"));
assert!(!output.contains("user="));
}
#[test]
fn test_render_has_type_annotations() {
#[tokio::test]
async fn test_render_has_type_annotations() {
let stats = Stats::new();
let output = render_metrics(&stats);
let tracker = UserIpTracker::new();
let config = ProxyConfig::default();
let output = render_metrics(&stats, &config, &tracker).await;
assert!(output.contains("# TYPE telemt_uptime_seconds gauge"));
assert!(output.contains("# TYPE telemt_connections_total counter"));
assert!(output.contains("# TYPE telemt_connections_bad_total counter"));
@ -408,12 +467,16 @@ mod tests {
assert!(output.contains(
"# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge"
));
assert!(output.contains("# TYPE telemt_user_unique_ips_current gauge"));
assert!(output.contains("# TYPE telemt_user_unique_ips_limit gauge"));
assert!(output.contains("# TYPE telemt_user_unique_ips_utilization gauge"));
}
#[tokio::test]
async fn test_endpoint_integration() {
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let tracker = UserIpTracker::new();
let mut config = ProxyConfig::default();
stats.increment_connects_all();
stats.increment_connects_all();
@ -423,7 +486,7 @@ mod tests {
.uri("/metrics")
.body(())
.unwrap();
let resp = handle(req, &stats, &beobachten, &config).unwrap();
let resp = handle(req, &stats, &beobachten, &tracker, &config).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = resp.into_body().collect().await.unwrap().to_bytes();
assert!(std::str::from_utf8(body.as_ref()).unwrap().contains("telemt_connections_total 3"));
@ -439,7 +502,9 @@ mod tests {
.uri("/beobachten")
.body(())
.unwrap();
let resp_beob = handle(req_beob, &stats, &beobachten, &config).unwrap();
let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config)
.await
.unwrap();
assert_eq!(resp_beob.status(), StatusCode::OK);
let body_beob = resp_beob.into_body().collect().await.unwrap().to_bytes();
let beob_text = std::str::from_utf8(body_beob.as_ref()).unwrap();
@ -450,7 +515,9 @@ mod tests {
.uri("/other")
.body(())
.unwrap();
let resp404 = handle(req404, &stats, &beobachten, &config).unwrap();
let resp404 = handle(req404, &stats, &beobachten, &tracker, &config)
.await
.unwrap();
assert_eq!(resp404.status(), StatusCode::NOT_FOUND);
}
}

View File

@ -0,0 +1,197 @@
//! Runtime DNS overrides for `host:port` targets.
use std::collections::HashMap;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::sync::{OnceLock, RwLock};
use crate::error::{ProxyError, Result};
type OverrideMap = HashMap<(String, u16), IpAddr>;
static DNS_OVERRIDES: OnceLock<RwLock<OverrideMap>> = OnceLock::new();
fn overrides_store() -> &'static RwLock<OverrideMap> {
DNS_OVERRIDES.get_or_init(|| RwLock::new(HashMap::new()))
}
fn parse_ip_spec(ip_spec: &str) -> Result<IpAddr> {
if ip_spec.starts_with('[') && ip_spec.ends_with(']') {
let inner = &ip_spec[1..ip_spec.len() - 1];
let ipv6 = inner.parse::<Ipv6Addr>().map_err(|_| {
ProxyError::Config(format!(
"network.dns_overrides IPv6 override is invalid: '{ip_spec}'"
))
})?;
return Ok(IpAddr::V6(ipv6));
}
let ip = ip_spec.parse::<IpAddr>().map_err(|_| {
ProxyError::Config(format!(
"network.dns_overrides IP is invalid: '{ip_spec}'"
))
})?;
if matches!(ip, IpAddr::V6(_)) {
return Err(ProxyError::Config(format!(
"network.dns_overrides IPv6 must be bracketed: '{ip_spec}'"
)));
}
Ok(ip)
}
fn parse_entry(entry: &str) -> Result<((String, u16), IpAddr)> {
let trimmed = entry.trim();
if trimmed.is_empty() {
return Err(ProxyError::Config(
"network.dns_overrides entry cannot be empty".to_string(),
));
}
let first_sep = trimmed.find(':').ok_or_else(|| {
ProxyError::Config(format!(
"network.dns_overrides entry must use host:port:ip format: '{trimmed}'"
))
})?;
let second_sep = trimmed[first_sep + 1..]
.find(':')
.map(|idx| first_sep + 1 + idx)
.ok_or_else(|| {
ProxyError::Config(format!(
"network.dns_overrides entry must use host:port:ip format: '{trimmed}'"
))
})?;
let host = trimmed[..first_sep].trim();
let port_str = trimmed[first_sep + 1..second_sep].trim();
let ip_str = trimmed[second_sep + 1..].trim();
if host.is_empty() {
return Err(ProxyError::Config(format!(
"network.dns_overrides host cannot be empty: '{trimmed}'"
)));
}
if host.contains(':') {
return Err(ProxyError::Config(format!(
"network.dns_overrides host must be a domain name without ':' in this format: '{trimmed}'"
)));
}
let port = port_str.parse::<u16>().map_err(|_| {
ProxyError::Config(format!(
"network.dns_overrides port is invalid: '{trimmed}'"
))
})?;
let ip = parse_ip_spec(ip_str)?;
Ok(((host.to_ascii_lowercase(), port), ip))
}
fn parse_entries(entries: &[String]) -> Result<OverrideMap> {
let mut parsed = HashMap::new();
for entry in entries {
let (key, ip) = parse_entry(entry)?;
parsed.insert(key, ip);
}
Ok(parsed)
}
/// Validate `network.dns_overrides` entries without updating runtime state.
pub fn validate_entries(entries: &[String]) -> Result<()> {
let _ = parse_entries(entries)?;
Ok(())
}
/// Replace runtime DNS overrides with a new validated snapshot.
pub fn install_entries(entries: &[String]) -> Result<()> {
let parsed = parse_entries(entries)?;
let mut guard = overrides_store()
.write()
.map_err(|_| ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string()))?;
*guard = parsed;
Ok(())
}
/// Resolve a hostname override for `(host, port)` if present.
pub fn resolve(host: &str, port: u16) -> Option<IpAddr> {
let key = (host.to_ascii_lowercase(), port);
overrides_store()
.read()
.ok()
.and_then(|guard| guard.get(&key).copied())
}
/// Resolve a hostname override and construct a socket address when present.
pub fn resolve_socket_addr(host: &str, port: u16) -> Option<SocketAddr> {
resolve(host, port).map(|ip| SocketAddr::new(ip, port))
}
/// Parse a runtime endpoint in `host:port` format.
///
/// Supports:
/// - `example.com:443`
/// - `[2001:db8::1]:443`
pub fn split_host_port(endpoint: &str) -> Option<(String, u16)> {
if endpoint.starts_with('[') {
let bracket_end = endpoint.find(']')?;
if endpoint.as_bytes().get(bracket_end + 1) != Some(&b':') {
return None;
}
let host = endpoint[1..bracket_end].trim();
let port = endpoint[bracket_end + 2..].trim().parse::<u16>().ok()?;
if host.is_empty() {
return None;
}
return Some((host.to_ascii_lowercase(), port));
}
let split_idx = endpoint.rfind(':')?;
let host = endpoint[..split_idx].trim();
let port = endpoint[split_idx + 1..].trim().parse::<u16>().ok()?;
if host.is_empty() || host.contains(':') {
return None;
}
Some((host.to_ascii_lowercase(), port))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn validate_accepts_ipv4_and_bracketed_ipv6() {
let entries = vec![
"example.com:443:127.0.0.1".to_string(),
"example.net:8443:[2001:db8::10]".to_string(),
];
assert!(validate_entries(&entries).is_ok());
}
#[test]
fn validate_rejects_unbracketed_ipv6() {
let entries = vec!["example.net:443:2001:db8::10".to_string()];
let err = validate_entries(&entries).unwrap_err().to_string();
assert!(err.contains("must be bracketed"));
}
#[test]
fn install_and_resolve_are_case_insensitive_for_host() {
let entries = vec!["MyPetrovich.ru:8443:127.0.0.1".to_string()];
install_entries(&entries).unwrap();
let resolved = resolve("mypetrovich.ru", 8443);
assert_eq!(resolved, Some("127.0.0.1".parse().unwrap()));
}
#[test]
fn split_host_port_parses_supported_shapes() {
assert_eq!(
split_host_port("example.com:443"),
Some(("example.com".to_string(), 443))
);
assert_eq!(
split_host_port("[2001:db8::1]:443"),
Some(("2001:db8::1".to_string(), 443))
);
assert_eq!(split_host_port("2001:db8::1:443"), None);
}
}

View File

@ -1,3 +1,4 @@
pub mod dns_overrides;
pub mod probe;
pub mod stun;

View File

@ -68,7 +68,7 @@ pub async fn run_probe(
probe.ipv4_is_bogon = probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false);
probe.ipv6_is_bogon = probe.detected_ipv6.map(is_bogon_v6).unwrap_or(false);
let stun_res = if nat_probe {
let stun_res = if nat_probe && config.stun_use {
let servers = collect_stun_servers(config);
if servers.is_empty() {
warn!("STUN probe is enabled but network.stun_servers is empty");
@ -80,6 +80,9 @@ pub async fn run_probe(
)
.await
}
} else if nat_probe {
info!("STUN probe is disabled by network.stun_use=false");
DualStunResult::default()
} else {
DualStunResult::default()
};

View File

@ -7,6 +7,7 @@ use tokio::net::{lookup_host, UdpSocket};
use tokio::time::{timeout, Duration, sleep};
use crate::error::{ProxyError, Result};
use crate::network::dns_overrides::{resolve, split_host_port};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpFamily {
@ -198,6 +199,16 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<S
});
}
if let Some((host, port)) = split_host_port(stun_addr)
&& let Some(ip) = resolve(&host, port)
{
let addr = SocketAddr::new(ip, port);
return Ok(match (addr.is_ipv4(), family) {
(true, IpFamily::V4) | (false, IpFamily::V6) => Some(addr),
_ => None,
});
}
let mut addrs = lookup_host(stun_addr)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN resolve failed: {e}")))?;

View File

@ -10,6 +10,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use tracing::debug;
use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
@ -115,8 +116,10 @@ where
"Forwarding bad client to mask host"
);
// Connect to mask host
let mask_addr = format!("{}:{}", mask_host, mask_port);
// Apply runtime DNS override for mask target when configured.
let mask_addr = resolve_socket_addr(mask_host, mask_port)
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result {
Ok(Ok(stream)) => {

View File

@ -18,6 +18,7 @@ use x509_parser::prelude::FromDer;
use x509_parser::certificate::X509Certificate;
use crate::crypto::SecureRandom;
use crate::network::dns_overrides::resolve_socket_addr;
use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_HANDSHAKE};
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use crate::tls_front::types::{
@ -333,6 +334,17 @@ fn u24_bytes(value: usize) -> Option<[u8; 3]> {
])
}
async fn connect_with_dns_override(
host: &str,
port: u16,
connect_timeout: Duration,
) -> Result<TcpStream> {
if let Some(addr) = resolve_socket_addr(host, port) {
return Ok(timeout(connect_timeout, TcpStream::connect(addr)).await??);
}
Ok(timeout(connect_timeout, TcpStream::connect((host, port))).await??)
}
fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8>> {
if cert_chain_der.is_empty() {
return None;
@ -369,8 +381,7 @@ async fn fetch_via_raw_tls(
connect_timeout: Duration,
proxy_protocol: u8,
) -> Result<TlsFetchResult> {
let addr = format!("{host}:{port}");
let mut stream = timeout(connect_timeout, TcpStream::connect(addr)).await??;
let mut stream = connect_with_dns_override(host, port, connect_timeout).await?;
let rng = SecureRandom::new();
let client_hello = build_client_hello(sni, &rng);
@ -437,24 +448,31 @@ async fn fetch_via_rustls(
) -> Result<TlsFetchResult> {
// rustls handshake path for certificate and basic negotiated metadata.
let mut stream = if let Some(manager) = upstream {
// Resolve host to SocketAddr
if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await {
if let Some(addr) = resolve_socket_addr(host, port) {
match manager.connect(addr, None, None).await {
Ok(s) => s,
Err(e) => {
warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect");
connect_with_dns_override(host, port, connect_timeout).await?
}
}
} else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await {
if let Some(addr) = addrs.find(|a| a.is_ipv4()) {
match manager.connect(addr, None, None).await {
Ok(s) => s,
Err(e) => {
warn!(sni = %sni, error = %e, "Upstream connect failed, using direct connect");
timeout(connect_timeout, TcpStream::connect((host, port))).await??
connect_with_dns_override(host, port, connect_timeout).await?
}
}
} else {
timeout(connect_timeout, TcpStream::connect((host, port))).await??
connect_with_dns_override(host, port, connect_timeout).await?
}
} else {
timeout(connect_timeout, TcpStream::connect((host, port))).await??
connect_with_dns_override(host, port, connect_timeout).await?
}
} else {
timeout(connect_timeout, TcpStream::connect((host, port))).await??
connect_with_dns_override(host, port, connect_timeout).await?
};
if proxy_protocol > 0 {

View File

@ -17,6 +17,7 @@ use tracing::{debug, warn, info, trace};
use crate::config::{UpstreamConfig, UpstreamType};
use crate::error::{Result, ProxyError};
use crate::network::dns_overrides::{resolve_socket_addr, split_host_port};
use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
use crate::transport::socket::{create_outgoing_socket_bound, resolve_interface_ip};
use crate::transport::socks::{connect_socks4, connect_socks5};
@ -209,6 +210,31 @@ impl UpstreamManager {
None
}
async fn connect_hostname_with_dns_override(
address: &str,
connect_timeout: Duration,
) -> Result<TcpStream> {
if let Some((host, port)) = split_host_port(address)
&& let Some(addr) = resolve_socket_addr(&host, port)
{
return match tokio::time::timeout(connect_timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => Ok(stream),
Ok(Err(e)) => Err(ProxyError::Io(e)),
Err(_) => Err(ProxyError::ConnectionTimeout {
addr: addr.to_string(),
}),
};
}
match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await {
Ok(Ok(stream)) => Ok(stream),
Ok(Err(e)) => Err(ProxyError::Io(e)),
Err(_) => Err(ProxyError::ConnectionTimeout {
addr: address.to_string(),
}),
}
}
/// Select upstream using latency-weighted random selection.
async fn select_upstream(&self, dc_idx: Option<i16>, scope: Option<&str>) -> Option<usize> {
let upstreams = self.upstreams.read().await;
@ -433,15 +459,7 @@ impl UpstreamManager {
if interface.is_some() {
warn!("SOCKS4 interface binding is not supported for hostname addresses, ignoring");
}
match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await {
Ok(Ok(stream)) => stream,
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: address.clone(),
});
}
}
Self::connect_hostname_with_dns_override(address, connect_timeout).await?
};
// replace socks user_id with config.selected_scope, if set
@ -503,15 +521,7 @@ impl UpstreamManager {
if interface.is_some() {
warn!("SOCKS5 interface binding is not supported for hostname addresses, ignoring");
}
match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await {
Ok(Ok(stream)) => stream,
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: address.clone(),
});
}
}
Self::connect_hostname_with_dns_override(address, connect_timeout).await?
};
debug!(config = ?config, "Socks5 connection");