diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d1f22c1..571a400 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,14 @@ -## Pull Requests - Rules +# Pull Requests - Rules +## General - ONLY signed and verified commits - ONLY from your name - DO NOT commit with `codex` or `claude` as author/commiter - PREFER `flow` branch for development, not `main` + +## AI +We are not against modern tools, like AI, where you act as a principal or architect, but we consider it important: + +- you really understand what you're doing +- you understand the relationships and dependencies of the components being modified +- you understand the architecture of Telegram MTProto, MTProxy, Middle-End KDF at least generically +- you DO NOT commit for the sake of commits, but to help the community, core-developers and ordinary users diff --git a/Cargo.toml b/Cargo.toml index 72ac944..5fce860 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.0.4" +version = "3.0.6" edition = "2024" [dependencies] @@ -24,6 +24,7 @@ zeroize = { version = "1.8", features = ["derive"] } # Network socket2 = { version = "0.5", features = ["all"] } +nix = { version = "0.28", default-features = false, features = ["net"] } # Serialization serde = { version = "1.0", features = ["derive"] } @@ -47,13 +48,18 @@ regex = "1.11" crossbeam-queue = "0.3" num-bigint = "0.4" num-traits = "0.2" +anyhow = "1.0" # HTTP reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false } +notify = { version = "6", features = ["macos_fsevent"] } hyper = { version = "1", features = ["server", "http1"] } hyper-util = { version = "0.1", features = ["tokio", "server-auto"] } http-body-util = "0.1" httpdate = "1.0" +tokio-rustls = { version = "0.26", default-features = false, features = ["tls12"] } +rustls = { version = "0.23", default-features = false, features = ["std", "tls12", "ring"] } +webpki-roots = "0.26" [dev-dependencies] tokio-test = "0.4" diff --git 
a/README.md b/README.md index cb7cd49..5fa1805 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ If you have expertise in asynchronous network applications, traffic analysis, re **This software is designed for Debian-based OS: in addition to Debian, these are Ubuntu, Mint, Kali, MX and many other Linux** 1. Download release ```bash -wget https://github.com/telemt/telemt/releases/latest/download/telemt +wget -qO- "https://github.com/telemt/telemt/releases/latest/download/telemt-$(uname -m)-linux-$(ldd --version 2>&1 | grep -iq musl && echo musl || echo gnu).tar.gz" | tar -xz ``` 2. Move to Bin Folder ```bash @@ -178,56 +178,97 @@ then Ctrl+X -> Y -> Enter to save ```toml # === General Settings === [general] -# prefer_ipv6 is deprecated; use [network].prefer -prefer_ipv6 = false fast_mode = true -use_middle_proxy = false -# ad_tag = "..." +use_middle_proxy = true +# ad_tag = "00000000000000000000000000000000" +# Path to proxy-secret binary (auto-downloaded if missing). +proxy_secret_path = "proxy-secret" # disable_colors = false # Disable colored output in logs (useful for files/systemd) -[network] -ipv4 = true -ipv6 = true # set false to disable, omit for auto -prefer = 4 # 4 or 6 -multipath = false +# === Log Level === +# Log level: debug | verbose | normal | silent +# Can be overridden with --silent or --log-level CLI flags +# RUST_LOG env var takes absolute priority over all of these +log_level = "normal" + +# === Middle Proxy - ME === +# Public IP override for ME KDF when behind NAT; leave unset to auto-detect. +# middle_proxy_nat_ip = "203.0.113.10" +# Enable STUN probing to discover public IP:port for ME. +middle_proxy_nat_probe = true +# Primary STUN server (host:port); defaults to Telegram STUN when empty. +middle_proxy_nat_stun = "stun.l.google.com:19302" +# Optional fallback STUN servers list. +middle_proxy_nat_stun_servers = ["stun1.l.google.com:19302", "stun2.l.google.com:19302"] +# Desired number of concurrent ME writers in pool. 
+middle_proxy_pool_size = 16 +# Pre-initialized warm-standby ME connections kept idle. +middle_proxy_warm_standby = 8 +# Ignore STUN/interface mismatch and keep ME enabled even if IP differs. +stun_iface_mismatch_ignore = false +# Keepalive padding frames - fl==4 +me_keepalive_enabled = true +me_keepalive_interval_secs = 25 # Period between keepalives +me_keepalive_jitter_secs = 5 # Jitter added to interval +me_keepalive_payload_random = true # Randomize 4-byte payload (vs zeros) +# Stagger extra ME connections on warmup to de-phase lifecycles. +me_warmup_stagger_enabled = true +me_warmup_step_delay_ms = 500 # Base delay between extra connects +me_warmup_step_jitter_ms = 300 # Jitter for warmup delay +# Reconnect policy knobs. +me_reconnect_max_concurrent_per_dc = 1 # Parallel reconnects per DC - EXPERIMENTAL! UNSTABLE! +me_reconnect_backoff_base_ms = 500 # Backoff start +me_reconnect_backoff_cap_ms = 30000 # Backoff cap +me_reconnect_fast_retry_count = 11 # Quick retries before backoff [general.modes] classic = false secure = false tls = true -# === Server Binding === -[server] -port = 443 -listen_addr_ipv4 = "0.0.0.0" -listen_addr_ipv6 = "::" -# metrics_port = 9090 -# metrics_whitelist = ["127.0.0.1", "::1"] - -# Listen on multiple interfaces/IPs (overrides listen_addr_*) -[[server.listeners]] -ip = "0.0.0.0" -# announce = "my.hostname.tld" # Optional: hostname for tg:// links -# OR -# announce = "1.2.3.4" # Optional: Public IP for tg:// links - -[[server.listeners]] -ip = "::" - -# Users to show in the startup log (tg:// links) [general.links] -show = ["hello"] # Only show links for user "hello" +show = "*" # show = ["alice", "bob"] # Only show links for alice and bob # show = "*" # Show links for all users # public_host = "proxy.example.com" # Host (IP or domain) for tg:// links # public_port = 443 # Port for tg:// links (default: server.port) +# === Network Parameters === +[network] +# Enable/disable families: true/false/auto(None) +ipv4 = true +ipv6 = false # 
UNSTABLE WITH ME +# prefer = 4 or 6 +prefer = 4 +multipath = false # EXPERIMENTAL! + +# === Server Binding === +[server] +port = 443 +listen_addr_ipv4 = "0.0.0.0" +listen_addr_ipv6 = "::" +# listen_unix_sock = "/var/run/telemt.sock" # Unix socket +# listen_unix_sock_perm = "0666" # Socket file permissions +# metrics_port = 9090 +# metrics_whitelist = ["127.0.0.1", "::1"] + +# Listen on multiple interfaces/IPs - IPv4 +[[server.listeners]] +ip = "0.0.0.0" + +# Listen on multiple interfaces/IPs - IPv6 +[[server.listeners]] +ip = "::" + # === Timeouts (in seconds) === [timeouts] -client_handshake = 15 +client_handshake = 30 tg_connect = 10 client_keepalive = 60 client_ack = 300 +# Quick ME reconnects for single-address DCs (count and per-attempt timeout, ms). +me_one_retry = 12 +me_one_timeout_ms = 1200 # === Anti-Censorship & Masking === [censorship] @@ -239,9 +280,9 @@ mask_port = 443 fake_cert_len = 2048 # === Access Control & Users === -# username "hello" is used for example [access] replay_check_len = 65536 +replay_window_secs = 1800 ignore_time_skew = false [access.users] @@ -251,28 +292,28 @@ hello = "00000000000000000000000000000000" # [access.user_max_tcp_conns] # hello = 50 +# [access.user_max_unique_ips] +# hello = 5 + # [access.user_data_quota] # hello = 1073741824 # 1 GB # === Upstreams & Routing === -# By default, direct connection is used, but you can add SOCKS proxy - -# Direct - Default [[upstreams]] type = "direct" enabled = true weight = 10 -# SOCKS5 # [[upstreams]] # type = "socks5" -# address = "127.0.0.1:9050" +# address = "127.0.0.1:1080" # enabled = false # weight = 1 # === DC Address Overrides === # [dc_overrides] # "203" = "91.105.192.100:443" + ``` ### Advanced #### Adtag diff --git a/config.toml b/config.toml index 45d6b75..2dc8937 100644 --- a/config.toml +++ b/config.toml @@ -1,16 +1,21 @@ # === General Settings === [general] -# prefer_ipv6 is deprecated; use [network].prefer instead -prefer_ipv6 = false fast_mode = true use_middle_proxy = 
true -#ad_tag = "00000000000000000000000000000000" +# ad_tag = "00000000000000000000000000000000" # Path to proxy-secret binary (auto-downloaded if missing). proxy_secret_path = "proxy-secret" +# disable_colors = false # Disable colored output in logs (useful for files/systemd) -# === Middle Proxy (ME) === +# === Log Level === +# Log level: debug | verbose | normal | silent +# Can be overridden with --silent or --log-level CLI flags +# RUST_LOG env var takes absolute priority over all of these +log_level = "normal" + +# === Middle Proxy - ME === # Public IP override for ME KDF when behind NAT; leave unset to auto-detect. -#middle_proxy_nat_ip = "203.0.113.10" +# middle_proxy_nat_ip = "203.0.113.10" # Enable STUN probing to discover public IP:port for ME. middle_proxy_nat_probe = true # Primary STUN server (host:port); defaults to Telegram STUN when empty. @@ -38,24 +43,27 @@ me_reconnect_backoff_base_ms = 500 # Backoff start me_reconnect_backoff_cap_ms = 30000 # Backoff cap me_reconnect_fast_retry_count = 11 # Quick retries before backoff -[network] -# Enable/disable families; ipv6 = true/false/auto(None) -ipv4 = true -ipv6 = true -# prefer = 4 or 6 -prefer = 4 -multipath = false - -# Log level: debug | verbose | normal | silent -# Can be overridden with --silent or --log-level CLI flags -# RUST_LOG env var takes absolute priority over all of these -log_level = "normal" - [general.modes] classic = false secure = false tls = true +[general.links] +show = "*" +# show = ["alice", "bob"] # Only show links for alice and bob +# show = "*" # Show links for all users +# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links +# public_port = 443 # Port for tg:// links (default: server.port) + +# === Network Parameters === +[network] +# Enable/disable families: true/false/auto(None) +ipv4 = true +ipv6 = false # UNSTABLE WITH ME +# prefer = 4 or 6 +prefer = 4 +multipath = false # EXPERIMENTAL! 
+ # === Server Binding === [server] port = 443 @@ -63,23 +71,18 @@ listen_addr_ipv4 = "0.0.0.0" listen_addr_ipv6 = "::" # listen_unix_sock = "/var/run/telemt.sock" # Unix socket # listen_unix_sock_perm = "0666" # Socket file permissions +# proxy_protocol = false # Enable if behind HAProxy/nginx with PROXY protocol # metrics_port = 9090 # metrics_whitelist = ["127.0.0.1", "::1"] -# Listen on multiple interfaces/IPs (overrides listen_addr_*) +# Listen on multiple interfaces/IPs - IPv4 [[server.listeners]] ip = "0.0.0.0" -# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links +# Listen on multiple interfaces/IPs - IPv6 [[server.listeners]] ip = "::" -# Users to show in the startup log (tg:// links) -[general.links] -show = ["hello"] # Users to show in the startup log (tg:// links) -# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links -# public_port = 443 # Port for tg:// links (default: server.port) - # === Timeouts (in seconds) === [timeouts] client_handshake = 30 @@ -93,11 +96,14 @@ me_one_timeout_ms = 1200 # === Anti-Censorship & Masking === [censorship] tls_domain = "petrovich.ru" +# tls_domains = ["example.com", "cdn.example.net"] # Additional domains for EE links mask = true mask_port = 443 # mask_host = "petrovich.ru" # Defaults to tls_domain if not set # mask_unix_sock = "/var/run/nginx.sock" # Unix socket (mutually exclusive with mask_host) fake_cert_len = 2048 +# tls_emulation = false # Fetch real cert lengths and emulate TLS records +# tls_front_dir = "tlsfront" # Cache directory for TLS emulation # === Access Control & Users === [access] @@ -123,6 +129,8 @@ hello = "00000000000000000000000000000000" type = "direct" enabled = true weight = 10 +# interface = "192.168.1.100" # Bind outgoing to specific IP or iface name +# bind_addresses = ["192.168.1.100"] # List for round-robin binding (family must match target) # [[upstreams]] # type = "socks5" diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 19269a2..a022021 
100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -23,6 +23,10 @@ pub(crate) fn default_fake_cert_len() -> usize { 2048 } +pub(crate) fn default_tls_front_dir() -> String { + "tlsfront".to_string() +} + pub(crate) fn default_replay_check_len() -> usize { 65_536 } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs new file mode 100644 index 0000000..246f8a3 --- /dev/null +++ b/src/config/hot_reload.rs @@ -0,0 +1,433 @@ +//! Hot-reload: watches the config file via inotify (Linux) / FSEvents (macOS) +//! / ReadDirectoryChangesW (Windows) using the `notify` crate. +//! SIGHUP is also supported on Unix as an additional manual trigger. +//! +//! # What can be reloaded without restart +//! +//! | Section | Field | Effect | +//! |-----------|-------------------------------|-----------------------------------| +//! | `general` | `log_level` | Filter updated via `log_level_tx` | +//! | `general` | `ad_tag` | Passed on next connection | +//! | `general` | `middle_proxy_pool_size` | Passed on next connection | +//! | `general` | `me_keepalive_*` | Passed on next connection | +//! | `access` | All user/quota fields | Effective immediately | +//! +//! Fields that require re-binding sockets (`server.port`, `censorship.*`, +//! `network.*`, `use_middle_proxy`) are **not** applied; a warning is emitted. + +use std::net::IpAddr; +use std::path::PathBuf; +use std::sync::Arc; + +use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher}; +use tokio::sync::{mpsc, watch}; +use tracing::{error, info, warn}; + +use crate::config::LogLevel; +use super::load::ProxyConfig; + +// ── Hot fields ──────────────────────────────────────────────────────────────── + +/// Fields that are safe to swap without restarting listeners. 
+#[derive(Debug, Clone, PartialEq)] +pub struct HotFields { + pub log_level: LogLevel, + pub ad_tag: Option, + pub middle_proxy_pool_size: usize, + pub me_keepalive_enabled: bool, + pub me_keepalive_interval_secs: u64, + pub me_keepalive_jitter_secs: u64, + pub me_keepalive_payload_random: bool, + pub access: crate::config::AccessConfig, +} + +impl HotFields { + pub fn from_config(cfg: &ProxyConfig) -> Self { + Self { + log_level: cfg.general.log_level.clone(), + ad_tag: cfg.general.ad_tag.clone(), + middle_proxy_pool_size: cfg.general.middle_proxy_pool_size, + me_keepalive_enabled: cfg.general.me_keepalive_enabled, + me_keepalive_interval_secs: cfg.general.me_keepalive_interval_secs, + me_keepalive_jitter_secs: cfg.general.me_keepalive_jitter_secs, + me_keepalive_payload_random: cfg.general.me_keepalive_payload_random, + access: cfg.access.clone(), + } + } +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +/// Warn if any non-hot fields changed (require restart). +fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig) { + if old.server.port != new.server.port { + warn!( + "config reload: server.port changed ({} → {}); restart required", + old.server.port, new.server.port + ); + } + if old.censorship.tls_domain != new.censorship.tls_domain { + warn!( + "config reload: censorship.tls_domain changed ('{}' → '{}'); restart required", + old.censorship.tls_domain, new.censorship.tls_domain + ); + } + if old.network.ipv4 != new.network.ipv4 || old.network.ipv6 != new.network.ipv6 { + warn!("config reload: network.ipv4/ipv6 changed; restart required"); + } + if old.general.use_middle_proxy != new.general.use_middle_proxy { + warn!("config reload: use_middle_proxy changed; restart required"); + } +} + +/// Resolve the public host for link generation — mirrors the logic in main.rs. +/// +/// Priority: +/// 1. `[general.links] public_host` — explicit override in config +/// 2. 
`detected_ip_v4` — from STUN/interface probe at startup +/// 3. `detected_ip_v6` — fallback +/// 4. `"UNKNOWN"` — warn the user to set `public_host` +fn resolve_link_host( + cfg: &ProxyConfig, + detected_ip_v4: Option, + detected_ip_v6: Option, +) -> String { + if let Some(ref h) = cfg.general.links.public_host { + return h.clone(); + } + detected_ip_v4 + .or(detected_ip_v6) + .map(|ip| ip.to_string()) + .unwrap_or_else(|| { + warn!( + "config reload: could not determine public IP for proxy links. \ + Set [general.links] public_host in config." + ); + "UNKNOWN".to_string() + }) +} + +/// Print TG proxy links for a single user — mirrors print_proxy_links() in main.rs. +fn print_user_links(user: &str, secret: &str, host: &str, port: u16, cfg: &ProxyConfig) { + info!(target: "telemt::links", "--- New user: {} ---", user); + if cfg.general.modes.classic { + info!( + target: "telemt::links", + " Classic: tg://proxy?server={}&port={}&secret={}", + host, port, secret + ); + } + if cfg.general.modes.secure { + info!( + target: "telemt::links", + " DD: tg://proxy?server={}&port={}&secret=dd{}", + host, port, secret + ); + } + if cfg.general.modes.tls { + let mut domains = vec![cfg.censorship.tls_domain.clone()]; + for d in &cfg.censorship.tls_domains { + if !domains.contains(d) { + domains.push(d.clone()); + } + } + for domain in &domains { + let domain_hex = hex::encode(domain.as_bytes()); + info!( + target: "telemt::links", + " EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", + host, port, secret, domain_hex + ); + } + } + info!(target: "telemt::links", "--------------------"); +} + +/// Log all detected changes and emit TG links for new users. 
+fn log_changes( + old_hot: &HotFields, + new_hot: &HotFields, + new_cfg: &ProxyConfig, + log_tx: &watch::Sender, + detected_ip_v4: Option, + detected_ip_v6: Option, +) { + if old_hot.log_level != new_hot.log_level { + info!( + "config reload: log_level: '{}' → '{}'", + old_hot.log_level, new_hot.log_level + ); + log_tx.send(new_hot.log_level.clone()).ok(); + } + + if old_hot.ad_tag != new_hot.ad_tag { + info!( + "config reload: ad_tag: {} → {}", + old_hot.ad_tag.as_deref().unwrap_or("none"), + new_hot.ad_tag.as_deref().unwrap_or("none"), + ); + } + + if old_hot.middle_proxy_pool_size != new_hot.middle_proxy_pool_size { + info!( + "config reload: middle_proxy_pool_size: {} → {}", + old_hot.middle_proxy_pool_size, new_hot.middle_proxy_pool_size, + ); + } + + if old_hot.me_keepalive_enabled != new_hot.me_keepalive_enabled + || old_hot.me_keepalive_interval_secs != new_hot.me_keepalive_interval_secs + || old_hot.me_keepalive_jitter_secs != new_hot.me_keepalive_jitter_secs + || old_hot.me_keepalive_payload_random != new_hot.me_keepalive_payload_random + { + info!( + "config reload: me_keepalive: enabled={} interval={}s jitter={}s random_payload={}", + new_hot.me_keepalive_enabled, + new_hot.me_keepalive_interval_secs, + new_hot.me_keepalive_jitter_secs, + new_hot.me_keepalive_payload_random, + ); + } + + if old_hot.access.users != new_hot.access.users { + let mut added: Vec<&String> = new_hot.access.users.keys() + .filter(|u| !old_hot.access.users.contains_key(*u)) + .collect(); + added.sort(); + + let mut removed: Vec<&String> = old_hot.access.users.keys() + .filter(|u| !new_hot.access.users.contains_key(*u)) + .collect(); + removed.sort(); + + let mut changed: Vec<&String> = new_hot.access.users.keys() + .filter(|u| { + old_hot.access.users.get(*u) + .map(|s| s != &new_hot.access.users[*u]) + .unwrap_or(false) + }) + .collect(); + changed.sort(); + + if !added.is_empty() { + info!( + "config reload: users added: [{}]", + added.iter().map(|s| 
s.as_str()).collect::>().join(", ") + ); + let host = resolve_link_host(new_cfg, detected_ip_v4, detected_ip_v6); + let port = new_cfg.general.links.public_port.unwrap_or(new_cfg.server.port); + for user in &added { + if let Some(secret) = new_hot.access.users.get(*user) { + print_user_links(user, secret, &host, port, new_cfg); + } + } + } + if !removed.is_empty() { + info!( + "config reload: users removed: [{}]", + removed.iter().map(|s| s.as_str()).collect::>().join(", ") + ); + } + if !changed.is_empty() { + info!( + "config reload: users secret changed: [{}]", + changed.iter().map(|s| s.as_str()).collect::>().join(", ") + ); + } + } + + if old_hot.access.user_max_tcp_conns != new_hot.access.user_max_tcp_conns { + info!( + "config reload: user_max_tcp_conns updated ({} entries)", + new_hot.access.user_max_tcp_conns.len() + ); + } + if old_hot.access.user_expirations != new_hot.access.user_expirations { + info!( + "config reload: user_expirations updated ({} entries)", + new_hot.access.user_expirations.len() + ); + } + if old_hot.access.user_data_quota != new_hot.access.user_data_quota { + info!( + "config reload: user_data_quota updated ({} entries)", + new_hot.access.user_data_quota.len() + ); + } + if old_hot.access.user_max_unique_ips != new_hot.access.user_max_unique_ips { + info!( + "config reload: user_max_unique_ips updated ({} entries)", + new_hot.access.user_max_unique_ips.len() + ); + } +} + +/// Load config, validate, diff against current, and broadcast if changed. 
+fn reload_config( + config_path: &PathBuf, + config_tx: &watch::Sender>, + log_tx: &watch::Sender, + detected_ip_v4: Option, + detected_ip_v6: Option, +) { + let new_cfg = match ProxyConfig::load(config_path) { + Ok(c) => c, + Err(e) => { + error!("config reload: failed to parse {:?}: {}", config_path, e); + return; + } + }; + + if let Err(e) = new_cfg.validate() { + error!("config reload: validation failed: {}; keeping old config", e); + return; + } + + let old_cfg = config_tx.borrow().clone(); + let old_hot = HotFields::from_config(&old_cfg); + let new_hot = HotFields::from_config(&new_cfg); + + if old_hot == new_hot { + return; + } + + warn_non_hot_changes(&old_cfg, &new_cfg); + log_changes(&old_hot, &new_hot, &new_cfg, log_tx, detected_ip_v4, detected_ip_v6); + config_tx.send(Arc::new(new_cfg)).ok(); +} + +// ── Public API ──────────────────────────────────────────────────────────────── + +/// Spawn the hot-reload watcher task. +/// +/// Uses `notify` (inotify on Linux) to detect file changes instantly. +/// SIGHUP is also handled on Unix as an additional manual trigger. +/// +/// `detected_ip_v4` / `detected_ip_v6` are the IPs discovered during the +/// startup probe — used when generating proxy links for newly added users, +/// matching the same logic as the startup output. +pub fn spawn_config_watcher( + config_path: PathBuf, + initial: Arc, + detected_ip_v4: Option, + detected_ip_v6: Option, +) -> (watch::Receiver>, watch::Receiver) { + let initial_level = initial.general.log_level.clone(); + let (config_tx, config_rx) = watch::channel(initial); + let (log_tx, log_rx) = watch::channel(initial_level); + + // Bridge: sync notify callback → async task via mpsc. + let (notify_tx, mut notify_rx) = mpsc::channel::<()>(4); + + // Canonicalize the config path so it matches what notify returns in events + // (notify always gives absolute paths, but config_path may be relative). 
+ let config_path = match config_path.canonicalize() { + Ok(p) => p, + Err(_) => config_path.to_path_buf(), // file doesn't exist yet, use as-is + }; + + // Watch the parent directory rather than the file itself, because many + // editors (vim, nano, systemd-sysusers) write via rename, which would + // cause inotify to lose track of the original inode. + let watch_dir = config_path + .parent() + .unwrap_or_else(|| std::path::Path::new(".")) + .to_path_buf(); + + let config_file = config_path.clone(); + let tx_clone = notify_tx.clone(); + + let watcher_result = recommended_watcher(move |res: notify::Result| { + let Ok(event) = res else { return }; + + let is_our_file = event.paths.iter().any(|p| p == &config_file); + if !is_our_file { + return; + } + let relevant = matches!( + event.kind, + EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) + ); + if relevant { + let _ = tx_clone.try_send(()); + } + }); + + match watcher_result { + Ok(mut watcher) => { + match watcher.watch(&watch_dir, RecursiveMode::NonRecursive) { + Ok(()) => info!("config watcher: watching {:?} via inotify", config_path), + Err(e) => warn!( + "config watcher: failed to watch {:?}: {}; use SIGHUP to reload", + watch_dir, e + ), + } + + tokio::spawn(async move { + let _watcher = watcher; // keep alive + + #[cfg(unix)] + let mut sighup = { + use tokio::signal::unix::{SignalKind, signal}; + signal(SignalKind::hangup()).expect("Failed to register SIGHUP handler") + }; + + loop { + #[cfg(unix)] + tokio::select! { + msg = notify_rx.recv() => { + if msg.is_none() { break; } + } + _ = sighup.recv() => { + info!("SIGHUP received — reloading {:?}", config_path); + } + } + #[cfg(not(unix))] + if notify_rx.recv().await.is_none() { break; } + + // Debounce: drain extra events fired within 50ms. 
+ tokio::time::sleep(std::time::Duration::from_millis(50)).await; + while notify_rx.try_recv().is_ok() {} + + reload_config( + &config_path, + &config_tx, + &log_tx, + detected_ip_v4, + detected_ip_v6, + ); + } + }); + } + Err(e) => { + warn!( + "config watcher: inotify unavailable ({}); only SIGHUP will trigger reload", + e + ); + // Fall back to SIGHUP-only. + tokio::spawn(async move { + #[cfg(unix)] + { + use tokio::signal::unix::{SignalKind, signal}; + let mut sighup = signal(SignalKind::hangup()) + .expect("Failed to register SIGHUP handler"); + loop { + sighup.recv().await; + info!("SIGHUP received — reloading {:?}", config_path); + reload_config( + &config_path, + &config_tx, + &log_tx, + detected_ip_v4, + detected_ip_v6, + ); + } + } + #[cfg(not(unix))] + let _ = (config_tx, log_tx, config_path); + }); + } + } + + (config_rx, log_rx) +} diff --git a/src/config/load.rs b/src/config/load.rs index a2fc19b..60a6bc2 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -163,6 +163,21 @@ impl ProxyConfig { config.censorship.mask_host = Some(config.censorship.tls_domain.clone()); } + // Merge primary + extra TLS domains, deduplicate (primary always first). + if !config.censorship.tls_domains.is_empty() { + let mut all = Vec::with_capacity(1 + config.censorship.tls_domains.len()); + all.push(config.censorship.tls_domain.clone()); + for d in std::mem::take(&mut config.censorship.tls_domains) { + if !d.is_empty() && !all.contains(&d) { + all.push(d); + } + } + // keep primary as tls_domain; store remaining back to tls_domains + if all.len() > 1 { + config.censorship.tls_domains = all[1..].to_vec(); + } + } + // Migration: prefer_ipv6 -> network.prefer. if config.general.prefer_ipv6 { if config.network.prefer == 4 { @@ -180,7 +195,7 @@ impl ProxyConfig { validate_network_cfg(&mut config.network)?; // Random fake_cert_len only when default is in use. 
- if config.censorship.fake_cert_len == default_fake_cert_len() { + if !config.censorship.tls_emulation && config.censorship.fake_cert_len == default_fake_cert_len() { config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); } @@ -235,7 +250,7 @@ impl ProxyConfig { // Migration: Populate upstreams if empty (Default Direct). if config.upstreams.is_empty() { config.upstreams.push(UpstreamConfig { - upstream_type: UpstreamType::Direct { interface: None }, + upstream_type: UpstreamType::Direct { interface: None, bind_addresses: None }, weight: 1, enabled: true, scopes: String::new(), diff --git a/src/config/mod.rs b/src/config/mod.rs index a82d92b..c7187ad 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod defaults; mod types; mod load; +pub mod hot_reload; pub use load::ProxyConfig; pub use types::*; diff --git a/src/config/types.rs b/src/config/types.rs index 9f6467a..fa69c12 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -295,6 +295,11 @@ pub struct ServerConfig { #[serde(default)] pub listen_tcp: Option, + /// Accept HAProxy PROXY protocol headers on incoming connections. + /// When enabled, real client IPs are extracted from PROXY v1/v2 headers. + #[serde(default)] + pub proxy_protocol: bool, + #[serde(default)] pub metrics_port: Option, @@ -314,6 +319,7 @@ impl Default for ServerConfig { listen_unix_sock: None, listen_unix_sock_perm: None, listen_tcp: None, + proxy_protocol: false, metrics_port: None, metrics_whitelist: default_metrics_whitelist(), listeners: Vec::new(), @@ -362,6 +368,10 @@ pub struct AntiCensorshipConfig { #[serde(default = "default_tls_domain")] pub tls_domain: String, + /// Additional TLS domains for generating multiple proxy links. 
+ #[serde(default)] + pub tls_domains: Vec, + #[serde(default = "default_true")] pub mask: bool, @@ -376,22 +386,33 @@ pub struct AntiCensorshipConfig { #[serde(default = "default_fake_cert_len")] pub fake_cert_len: usize, + + /// Enable TLS certificate emulation using cached real certificates. + #[serde(default)] + pub tls_emulation: bool, + + /// Directory to store TLS front cache (on disk). + #[serde(default = "default_tls_front_dir")] + pub tls_front_dir: String, } impl Default for AntiCensorshipConfig { fn default() -> Self { Self { tls_domain: default_tls_domain(), + tls_domains: Vec::new(), mask: true, mask_host: None, mask_port: default_mask_port(), mask_unix_sock: None, fake_cert_len: default_fake_cert_len(), + tls_emulation: false, + tls_front_dir: default_tls_front_dir(), } } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct AccessConfig { #[serde(default)] pub users: HashMap, @@ -446,6 +467,8 @@ pub enum UpstreamType { Direct { #[serde(default)] interface: Option, + #[serde(default)] + bind_addresses: Option>, }, Socks4 { address: String, diff --git a/src/main.rs b/src/main.rs index d542b63..4b33c9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,9 +23,11 @@ mod proxy; mod stats; mod stream; mod transport; +mod tls_front; mod util; use crate::config::{LogLevel, ProxyConfig}; +use crate::config::hot_reload::spawn_config_watcher; use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe}; @@ -36,6 +38,7 @@ use crate::transport::middle_proxy::{ MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, format_sample_line, }; use crate::transport::{ListenOptions, UpstreamManager, create_listener}; +use crate::tls_front::TlsFrontCache; fn parse_cli() -> (String, bool, Option) { let mut config_path = "config.toml".to_string(); @@ -129,12 +132,22 @@ fn print_proxy_links(host: &str, 
port: u16, config: &ProxyConfig) { ); } if config.general.modes.tls { - let domain_hex = hex::encode(&config.censorship.tls_domain); - info!( - target: "telemt::links", - " EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", - host, port, secret, domain_hex - ); + let mut domains = Vec::with_capacity(1 + config.censorship.tls_domains.len()); + domains.push(config.censorship.tls_domain.clone()); + for d in &config.censorship.tls_domains { + if !domains.contains(d) { + domains.push(d.clone()); + } + } + + for domain in domains { + let domain_hex = hex::encode(&domain); + info!( + target: "telemt::links", + " EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", + host, port, secret, domain_hex + ); + } } } else { warn!(target: "telemt::links", "User '{}' in show_link not found", user_name); @@ -247,6 +260,46 @@ async fn main() -> std::result::Result<(), Box> { info!("IP limits configured for {} users", config.access.user_max_unique_ips.len()); } + // TLS front cache (optional emulation) + let mut tls_domains = Vec::with_capacity(1 + config.censorship.tls_domains.len()); + tls_domains.push(config.censorship.tls_domain.clone()); + for d in &config.censorship.tls_domains { + if !tls_domains.contains(d) { + tls_domains.push(d.clone()); + } + } + + let tls_cache: Option> = if config.censorship.tls_emulation { + let cache = Arc::new(TlsFrontCache::new( + &tls_domains, + config.censorship.fake_cert_len, + &config.censorship.tls_front_dir, + )); + + let cache_clone = cache.clone(); + let domains = tls_domains.clone(); + let port = config.censorship.mask_port; + tokio::spawn(async move { + for domain in domains { + match crate::tls_front::fetcher::fetch_real_tls( + &domain, + port, + &domain, + Duration::from_secs(5), + ) + .await + { + Ok(res) => cache_clone.update_from_fetch(&domain, res).await, + Err(e) => warn!(domain = %domain, error = %e, "TLS emulation fetch failed"), + } + } + }); + + Some(cache) + } else { + None + }; + // Connection concurrency limit let 
_max_connections = Arc::new(Semaphore::new(10_000)); @@ -604,6 +657,19 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai detected_ip_v4, detected_ip_v6 ); + // ── Hot-reload watcher ──────────────────────────────────────────────── + // Uses inotify to detect file changes instantly (SIGHUP also works). + // detected_ip_v4/v6 are passed so newly added users get correct TG links. + let (config_rx, mut log_level_rx): ( + tokio::sync::watch::Receiver>, + tokio::sync::watch::Receiver, + ) = spawn_config_watcher( + std::path::PathBuf::from(&config_path), + config.clone(), + detected_ip_v4, + detected_ip_v6, + ); + let mut listeners = Vec::new(); for listener_conf in &config.server.listeners { @@ -708,13 +774,14 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai has_unix_listener = true; - let config = config.clone(); + let mut config_rx_unix: tokio::sync::watch::Receiver> = config_rx.clone(); let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); let replay_checker = replay_checker.clone(); let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); let me_pool = me_pool.clone(); + let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); tokio::spawn(async move { @@ -726,20 +793,21 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let conn_id = unix_conn_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); let fake_peer = SocketAddr::from(([127, 0, 0, 1], (conn_id % 65535) as u16)); - let config = config.clone(); + let config = config_rx_unix.borrow_and_update().clone(); let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); let replay_checker = replay_checker.clone(); let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); let me_pool = me_pool.clone(); + let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); tokio::spawn(async move { if let Err(e) = 
crate::proxy::client::handle_client_stream( stream, fake_peer, config, stats, upstream_manager, replay_checker, buffer_pool, rng, - me_pool, ip_tracker, + me_pool, tls_cache, ip_tracker, ).await { debug!(error = %e, "Unix socket connection error"); } @@ -771,6 +839,20 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai .reload(runtime_filter) .expect("Failed to switch log filter"); + // Apply log_level changes from hot-reload to the tracing filter. + tokio::spawn(async move { + loop { + if log_level_rx.changed().await.is_err() { + break; + } + let level = log_level_rx.borrow_and_update().clone(); + let new_filter = tracing_subscriber::EnvFilter::new(level.to_filter_str()); + if let Err(e) = filter_handle.reload(new_filter) { + tracing::error!("config reload: failed to update log filter: {}", e); + } + } + }); + if let Some(port) = config.server.metrics_port { let stats = stats.clone(); let whitelist = config.server.metrics_whitelist.clone(); @@ -780,26 +862,28 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai } for listener in listeners { - let config = config.clone(); + let mut config_rx: tokio::sync::watch::Receiver> = config_rx.clone(); let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); let replay_checker = replay_checker.clone(); let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); let me_pool = me_pool.clone(); + let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); tokio::spawn(async move { loop { match listener.accept().await { Ok((stream, peer_addr)) => { - let config = config.clone(); + let config = config_rx.borrow_and_update().clone(); let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); let replay_checker = replay_checker.clone(); let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); let me_pool = me_pool.clone(); + let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); tokio::spawn(async 
move { @@ -813,12 +897,13 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai buffer_pool, rng, me_pool, + tls_cache, ip_tracker, ) .run() .await { - debug!(peer = %peer_addr, error = %e, "Connection error"); + warn!(peer = %peer_addr, error = %e, "Connection closed with error"); } }); } diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 520b6ea..39eb7e6 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -397,6 +397,84 @@ pub fn build_server_hello( response } +/// Extract SNI (server_name) from a TLS ClientHello. +pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { + if handshake.len() < 43 || handshake[0] != TLS_RECORD_HANDSHAKE { + return None; + } + + let mut pos = 5; // after record header + if handshake.get(pos).copied()? != 0x01 { + return None; // not ClientHello + } + + // Handshake length bytes + pos += 4; // type + len (3) + + // version (2) + random (32) + pos += 2 + 32; + if pos + 1 > handshake.len() { + return None; + } + + let session_id_len = *handshake.get(pos)? as usize; + pos += 1 + session_id_len; + if pos + 2 > handshake.len() { + return None; + } + + let cipher_suites_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; + pos += 2 + cipher_suites_len; + if pos + 1 > handshake.len() { + return None; + } + + let comp_len = *handshake.get(pos)? 
as usize; + pos += 1 + comp_len; + if pos + 2 > handshake.len() { + return None; + } + + let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; + pos += 2; + let ext_end = pos + ext_len; + if ext_end > handshake.len() { + return None; + } + + while pos + 4 <= ext_end { + let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]); + let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize; + pos += 4; + if pos + elen > ext_end { + break; + } + if etype == 0x0000 && elen >= 5 { + // server_name extension + let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; + let mut sn_pos = pos + 2; + let sn_end = std::cmp::min(sn_pos + list_len, pos + elen); + while sn_pos + 3 <= sn_end { + let name_type = handshake[sn_pos]; + let name_len = u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize; + sn_pos += 3; + if sn_pos + name_len > sn_end { + break; + } + if name_type == 0 && name_len > 0 { + if let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) { + return Some(host.to_string()); + } + } + sn_pos += name_len; + } + } + pos += elen; + } + + None +} + /// Check if bytes look like a TLS ClientHello pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { if first_bytes.len() < 3 { diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 87d6b52..8a8ae81 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -30,7 +30,8 @@ use crate::protocol::tls; use crate::stats::{ReplayChecker, Stats}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::middle_proxy::MePool; -use crate::transport::{UpstreamManager, configure_client_socket}; +use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol}; +use crate::tls_front::TlsFrontCache; use crate::proxy::direct_relay::handle_via_direct; use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; @@ -47,13 
+48,35 @@ pub async fn handle_client_stream( buffer_pool: Arc, rng: Arc, me_pool: Option>, + tls_cache: Option>, ip_tracker: Arc, ) -> Result<()> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { stats.increment_connects_all(); - debug!(peer = %peer, "New connection (generic stream)"); + let mut real_peer = peer; + + if config.server.proxy_protocol { + match parse_proxy_protocol(&mut stream, peer).await { + Ok(info) => { + debug!( + peer = %peer, + client = %info.src_addr, + version = info.version, + "PROXY protocol header parsed" + ); + real_peer = info.src_addr; + } + Err(e) => { + stats.increment_connects_bad(); + warn!(peer = %peer, error = %e, "Invalid PROXY protocol header"); + return Err(e); + } + } + } + + debug!(peer = %real_peer, "New connection (generic stream)"); let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake); let stats_for_timeout = stats.clone(); @@ -69,13 +92,13 @@ where stream.read_exact(&mut first_bytes).await?; let is_tls = tls::is_tls_handshake(&first_bytes[..3]); - debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); + debug!(peer = %real_peer, is_tls = is_tls, "Handshake type detected"); if is_tls { let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; if tls_len < 512 { - debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + debug!(peer = %real_peer, tls_len = tls_len, "TLS handshake too short"); stats.increment_connects_bad(); let (reader, writer) = tokio::io::split(stream); handle_bad_client(reader, writer, &first_bytes, &config).await; @@ -89,8 +112,8 @@ where let (read_half, write_half) = tokio::io::split(stream); let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( - &handshake, read_half, write_half, peer, - &config, &replay_checker, &rng, + &handshake, read_half, write_half, real_peer, + &config, &replay_checker, &rng, tls_cache.clone(), ).await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { 
reader, writer } => { @@ -107,7 +130,7 @@ where .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &mtproto_handshake, tls_reader, tls_writer, peer, + &mtproto_handshake, tls_reader, tls_writer, real_peer, &config, &replay_checker, true, ).await { HandshakeResult::Success(result) => result, @@ -123,12 +146,12 @@ where RunningClientHandler::handle_authenticated_static( crypto_reader, crypto_writer, success, upstream_manager, stats, config, buffer_pool, rng, me_pool, - local_addr, peer, ip_tracker.clone(), + local_addr, real_peer, ip_tracker.clone(), ), ))) } else { if !config.general.modes.classic && !config.general.modes.secure { - debug!(peer = %peer, "Non-TLS modes disabled"); + debug!(peer = %real_peer, "Non-TLS modes disabled"); stats.increment_connects_bad(); let (reader, writer) = tokio::io::split(stream); handle_bad_client(reader, writer, &first_bytes, &config).await; @@ -142,7 +165,7 @@ where let (read_half, write_half) = tokio::io::split(stream); let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &handshake, read_half, write_half, peer, + &handshake, read_half, write_half, real_peer, &config, &replay_checker, false, ).await { HandshakeResult::Success(result) => result, @@ -166,7 +189,7 @@ where rng, me_pool, local_addr, - peer, + real_peer, ip_tracker.clone(), ) ))) @@ -203,6 +226,7 @@ pub struct RunningClientHandler { buffer_pool: Arc, rng: Arc, me_pool: Option>, + tls_cache: Option>, ip_tracker: Arc, } @@ -217,6 +241,7 @@ impl ClientHandler { buffer_pool: Arc, rng: Arc, me_pool: Option>, + tls_cache: Option>, ip_tracker: Arc, ) -> RunningClientHandler { RunningClientHandler { @@ -229,6 +254,7 @@ impl ClientHandler { buffer_pool, rng, me_pool, + tls_cache, ip_tracker, } } @@ -275,6 +301,25 @@ impl RunningClientHandler { } async fn do_handshake(mut self) -> Result { + if self.config.server.proxy_protocol { + match 
parse_proxy_protocol(&mut self.stream, self.peer).await { + Ok(info) => { + debug!( + peer = %self.peer, + client = %info.src_addr, + version = info.version, + "PROXY protocol header parsed" + ); + self.peer = info.src_addr; + } + Err(e) => { + self.stats.increment_connects_bad(); + warn!(peer = %self.peer, error = %e, "Invalid PROXY protocol header"); + return Err(e); + } + } + } + let mut first_bytes = [0u8; 5]; self.stream.read_exact(&mut first_bytes).await?; @@ -327,6 +372,7 @@ impl RunningClientHandler { &config, &replay_checker, &self.rng, + self.tls_cache.clone(), ) .await { diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 0023b7a..8b61112 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -1,6 +1,7 @@ //! MTProto Handshake use std::net::SocketAddr; +use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace, info}; use zeroize::Zeroize; @@ -12,6 +13,7 @@ use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter}; use crate::error::{ProxyError, HandshakeResult}; use crate::stats::ReplayChecker; use crate::config::ProxyConfig; +use crate::tls_front::{TlsFrontCache, emulator}; /// Result of successful handshake /// @@ -55,6 +57,7 @@ pub async fn handle_tls_handshake( config: &ProxyConfig, replay_checker: &ReplayChecker, rng: &SecureRandom, + tls_cache: Option>, ) -> HandshakeResult<(FakeTlsReader, FakeTlsWriter, String), R, W> where R: AsyncRead + Unpin, @@ -102,13 +105,37 @@ where None => return HandshakeResult::BadClient { reader, writer }, }; - let response = tls::build_server_hello( - secret, - &validation.digest, - &validation.session_id, - config.censorship.fake_cert_len, - rng, - ); + let cached = if config.censorship.tls_emulation { + if let Some(cache) = tls_cache.as_ref() { + if let Some(sni) = tls::extract_sni_from_client_hello(handshake) { + Some(cache.get(&sni).await) + } else { + Some(cache.get(&config.censorship.tls_domain).await) + } + } else 
{ + None + } + } else { + None + }; + + let response = if let Some(cached_entry) = cached { + emulator::build_emulated_server_hello( + secret, + &validation.digest, + &validation.session_id, + &cached_entry, + rng, + ) + } else { + tls::build_server_hello( + secret, + &validation.digest, + &validation.session_id, + config.censorship.fake_cert_len, + rng, + ) + }; debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); diff --git a/src/tls_front/cache.rs b/src/tls_front/cache.rs new file mode 100644 index 0000000..3fddd07 --- /dev/null +++ b/src/tls_front/cache.rs @@ -0,0 +1,103 @@ +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::{SystemTime, Duration}; + +use tokio::sync::RwLock; +use tokio::time::sleep; +use tracing::{debug, warn, info}; + +use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsFetchResult}; + +/// Lightweight in-memory + optional on-disk cache for TLS fronting data. +#[derive(Debug)] +pub struct TlsFrontCache { + memory: RwLock>>, + default: Arc, + disk_path: PathBuf, +} + +impl TlsFrontCache { + pub fn new(domains: &[String], default_len: usize, disk_path: impl AsRef) -> Self { + let default_template = ParsedServerHello { + version: [0x03, 0x03], + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite: [0x13, 0x01], + compression: 0, + extensions: Vec::new(), + }; + + let default = Arc::new(CachedTlsData { + server_hello_template: default_template, + cert_info: None, + app_data_records_sizes: vec![default_len], + total_app_data_len: default_len, + fetched_at: SystemTime::now(), + domain: "default".to_string(), + }); + + let mut map = HashMap::new(); + for d in domains { + map.insert(d.clone(), default.clone()); + } + + Self { + memory: RwLock::new(map), + default, + disk_path: disk_path.as_ref().to_path_buf(), + } + } + + pub async fn get(&self, sni: &str) -> Arc { + let guard = self.memory.read().await; + guard.get(sni).cloned().unwrap_or_else(|| 
self.default.clone()) + } + + pub async fn set(&self, domain: &str, data: CachedTlsData) { + let mut guard = self.memory.write().await; + guard.insert(domain.to_string(), Arc::new(data)); + } + + /// Spawn background updater that periodically refreshes cached domains using provided fetcher. + pub fn spawn_updater( + self: Arc, + domains: Vec, + interval: Duration, + fetcher: F, + ) where + F: Fn(String) -> tokio::task::JoinHandle<()> + Send + Sync + 'static, + { + tokio::spawn(async move { + loop { + for domain in &domains { + fetcher(domain.clone()).await; + } + sleep(interval).await; + } + }); + } + + /// Replace cached entry from a fetch result. + pub async fn update_from_fetch(&self, domain: &str, fetched: TlsFetchResult) { + let data = CachedTlsData { + server_hello_template: fetched.server_hello_parsed, + cert_info: None, + app_data_records_sizes: fetched.app_data_records_sizes.clone(), + total_app_data_len: fetched.total_app_data_len, + fetched_at: SystemTime::now(), + domain: domain.to_string(), + }; + + self.set(domain, data).await; + debug!(domain = %domain, len = fetched.total_app_data_len, "TLS cache updated"); + } + + pub fn default_entry(&self) -> Arc { + self.default.clone() + } + + pub fn disk_path(&self) -> &Path { + &self.disk_path + } +} diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs new file mode 100644 index 0000000..8328884 --- /dev/null +++ b/src/tls_front/emulator.rs @@ -0,0 +1,104 @@ +use crate::crypto::{sha256_hmac, SecureRandom}; +use crate::protocol::constants::{ + TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION, +}; +use crate::protocol::tls::{TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key}; +use crate::tls_front::types::CachedTlsData; + +/// Build a ServerHello + CCS + ApplicationData sequence using cached TLS metadata. 
+pub fn build_emulated_server_hello( + secret: &[u8], + client_digest: &[u8; TLS_DIGEST_LEN], + session_id: &[u8], + cached: &CachedTlsData, + rng: &SecureRandom, +) -> Vec { + // --- ServerHello --- + let mut extensions = Vec::new(); + // KeyShare (x25519) + let key = gen_fake_x25519_key(rng); + extensions.extend_from_slice(&0x0033u16.to_be_bytes()); // key_share + extensions.extend_from_slice(&(2 + 2 + 32u16).to_be_bytes()); // len + extensions.extend_from_slice(&0x001du16.to_be_bytes()); // X25519 + extensions.extend_from_slice(&(32u16).to_be_bytes()); + extensions.extend_from_slice(&key); + // supported_versions (TLS1.3) + extensions.extend_from_slice(&0x002bu16.to_be_bytes()); + extensions.extend_from_slice(&(2u16).to_be_bytes()); + extensions.extend_from_slice(&0x0304u16.to_be_bytes()); + + let extensions_len = extensions.len() as u16; + + let body_len = 2 + // version + 32 + // random + 1 + session_id.len() + // session id + 2 + // cipher + 1 + // compression + 2 + extensions.len(); // extensions + + let mut message = Vec::with_capacity(4 + body_len); + message.push(0x02); // ServerHello + let len_bytes = (body_len as u32).to_be_bytes(); + message.extend_from_slice(&len_bytes[1..4]); + message.extend_from_slice(&cached.server_hello_template.version); // 0x0303 + message.extend_from_slice(&[0u8; 32]); // random placeholder + message.push(session_id.len() as u8); + message.extend_from_slice(session_id); + let cipher = if cached.server_hello_template.cipher_suite == [0, 0] { + [0x13, 0x01] + } else { + cached.server_hello_template.cipher_suite + }; + message.extend_from_slice(&cipher); + message.push(cached.server_hello_template.compression); + message.extend_from_slice(&extensions_len.to_be_bytes()); + message.extend_from_slice(&extensions); + + let mut server_hello = Vec::with_capacity(5 + message.len()); + server_hello.push(TLS_RECORD_HANDSHAKE); + server_hello.extend_from_slice(&TLS_VERSION); + server_hello.extend_from_slice(&(message.len() as 
u16).to_be_bytes()); + server_hello.extend_from_slice(&message); + + // --- ChangeCipherSpec --- + let change_cipher_spec = [ + TLS_RECORD_CHANGE_CIPHER, + TLS_VERSION[0], + TLS_VERSION[1], + 0x00, + 0x01, + 0x01, + ]; + + // --- ApplicationData (fake encrypted records) --- + // Use the same number and sizes of ApplicationData records as the cached server. + let mut sizes = cached.app_data_records_sizes.clone(); + if sizes.is_empty() { + sizes.push(cached.total_app_data_len.max(1024)); + } + + let mut app_data = Vec::new(); + for size in sizes { + let mut rec = Vec::with_capacity(5 + size); + rec.push(TLS_RECORD_APPLICATION); + rec.extend_from_slice(&TLS_VERSION); + rec.extend_from_slice(&(size as u16).to_be_bytes()); + rec.extend_from_slice(&rng.bytes(size)); + app_data.extend_from_slice(&rec); + } + + // --- Combine --- + let mut response = Vec::with_capacity(server_hello.len() + change_cipher_spec.len() + app_data.len()); + response.extend_from_slice(&server_hello); + response.extend_from_slice(&change_cipher_spec); + response.extend_from_slice(&app_data); + + // --- HMAC --- + let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len()); + hmac_input.extend_from_slice(client_digest); + hmac_input.extend_from_slice(&response); + let digest = sha256_hmac(secret, &hmac_input); + response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + + response +} diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs new file mode 100644 index 0000000..0ce8d6b --- /dev/null +++ b/src/tls_front/fetcher.rs @@ -0,0 +1,391 @@ +use std::sync::Arc; +use std::time::Duration; + +use anyhow::{Context, Result, anyhow}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::time::timeout; +use tokio_rustls::client::TlsStream; +use tokio_rustls::TlsConnector; +use tracing::{debug, warn}; + +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; +use 
rustls::client::ClientConfig; +use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; +use rustls::{DigitallySignedStruct, Error as RustlsError}; + +use crate::crypto::SecureRandom; +use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_HANDSHAKE, TLS_VERSION}; +use crate::tls_front::types::{ParsedServerHello, TlsExtension, TlsFetchResult}; + +/// No-op verifier: accept any certificate (we only need lengths and metadata). +#[derive(Debug)] +struct NoVerify; + +impl ServerCertVerifier for NoVerify { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + use rustls::SignatureScheme::*; + vec![ + RSA_PKCS1_SHA256, + RSA_PSS_SHA256, + ECDSA_NISTP256_SHA256, + ECDSA_NISTP384_SHA384, + ] + } +} + +fn build_client_config() -> Arc { + let root = rustls::RootCertStore::empty(); + + let provider = rustls::crypto::ring::default_provider(); + let mut config = ClientConfig::builder_with_provider(Arc::new(provider)) + .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) + .expect("protocol versions") + .with_root_certificates(root) + .with_no_client_auth(); + + config + .dangerous() + .set_certificate_verifier(Arc::new(NoVerify)); + + Arc::new(config) +} + +fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { + // === ClientHello body === + let mut body = Vec::new(); + + // Legacy version (TLS 1.0) as in real ClientHello headers + 
body.extend_from_slice(&[0x03, 0x03]); + + // Random + body.extend_from_slice(&rng.bytes(32)); + + // Session ID: empty + body.push(0); + + // Cipher suites (common minimal set, TLS1.3 + a few 1.2 fallbacks) + let cipher_suites: [u8; 10] = [ + 0x13, 0x01, // TLS_AES_128_GCM_SHA256 + 0x13, 0x02, // TLS_AES_256_GCM_SHA384 + 0x13, 0x03, // TLS_CHACHA20_POLY1305_SHA256 + 0x00, 0x2f, // TLS_RSA_WITH_AES_128_CBC_SHA (legacy) + 0x00, 0xff, // RENEGOTIATION_INFO_SCSV + ]; + body.extend_from_slice(&(cipher_suites.len() as u16).to_be_bytes()); + body.extend_from_slice(&cipher_suites); + + // Compression methods: null only + body.push(1); + body.push(0); + + // === Extensions === + let mut exts = Vec::new(); + + // server_name (SNI) + let sni_bytes = sni.as_bytes(); + let mut sni_ext = Vec::with_capacity(5 + sni_bytes.len()); + sni_ext.extend_from_slice(&(sni_bytes.len() as u16 + 3).to_be_bytes()); + sni_ext.push(0); // host_name + sni_ext.extend_from_slice(&(sni_bytes.len() as u16).to_be_bytes()); + sni_ext.extend_from_slice(sni_bytes); + exts.extend_from_slice(&0x0000u16.to_be_bytes()); + exts.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); + exts.extend_from_slice(&sni_ext); + + // supported_groups + let groups: [u16; 2] = [0x001d, 0x0017]; // x25519, secp256r1 + exts.extend_from_slice(&0x000au16.to_be_bytes()); + exts.extend_from_slice(&((2 + groups.len() * 2) as u16).to_be_bytes()); + exts.extend_from_slice(&(groups.len() as u16 * 2).to_be_bytes()); + for g in groups { exts.extend_from_slice(&g.to_be_bytes()); } + + // signature_algorithms + let sig_algs: [u16; 4] = [0x0804, 0x0805, 0x0403, 0x0503]; // rsa_pss_rsae_sha256/384, ecdsa_secp256r1_sha256, rsa_pkcs1_sha256 + exts.extend_from_slice(&0x000du16.to_be_bytes()); + exts.extend_from_slice(&((2 + sig_algs.len() * 2) as u16).to_be_bytes()); + exts.extend_from_slice(&(sig_algs.len() as u16 * 2).to_be_bytes()); + for a in sig_algs { exts.extend_from_slice(&a.to_be_bytes()); } + + // supported_versions (TLS1.3 + 
TLS1.2) + let versions: [u16; 2] = [0x0304, 0x0303]; + exts.extend_from_slice(&0x002bu16.to_be_bytes()); + exts.extend_from_slice(&((1 + versions.len() * 2) as u16).to_be_bytes()); + exts.push((versions.len() * 2) as u8); + for v in versions { exts.extend_from_slice(&v.to_be_bytes()); } + + // key_share (x25519) + let key = gen_key_share(rng); + let mut keyshare = Vec::with_capacity(4 + key.len()); + keyshare.extend_from_slice(&0x001du16.to_be_bytes()); // group + keyshare.extend_from_slice(&(key.len() as u16).to_be_bytes()); + keyshare.extend_from_slice(&key); + exts.extend_from_slice(&0x0033u16.to_be_bytes()); + exts.extend_from_slice(&((2 + keyshare.len()) as u16).to_be_bytes()); + exts.extend_from_slice(&(keyshare.len() as u16).to_be_bytes()); + exts.extend_from_slice(&keyshare); + + // ALPN (http/1.1) + let alpn_proto = b"http/1.1"; + exts.extend_from_slice(&0x0010u16.to_be_bytes()); + exts.extend_from_slice(&((2 + 1 + alpn_proto.len()) as u16).to_be_bytes()); + exts.extend_from_slice(&((1 + alpn_proto.len()) as u16).to_be_bytes()); + exts.push(alpn_proto.len() as u8); + exts.extend_from_slice(alpn_proto); + + // padding to reduce recognizability and keep length ~500 bytes + if exts.len() < 180 { + let pad_len = 180 - exts.len(); + exts.extend_from_slice(&0x0015u16.to_be_bytes()); // padding extension + exts.extend_from_slice(&(pad_len as u16 + 2).to_be_bytes()); + exts.extend_from_slice(&(pad_len as u16).to_be_bytes()); + exts.resize(exts.len() + pad_len, 0); + } + + // Extensions length prefix + body.extend_from_slice(&(exts.len() as u16).to_be_bytes()); + body.extend_from_slice(&exts); + + // === Handshake wrapper === + let mut handshake = Vec::new(); + handshake.push(0x01); // ClientHello + let len_bytes = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&len_bytes[1..4]); + handshake.extend_from_slice(&body); + + // === Record === + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); 
// legacy record version + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record +} + +fn gen_key_share(rng: &SecureRandom) -> [u8; 32] { + let mut key = [0u8; 32]; + key.copy_from_slice(&rng.bytes(32)); + key +} + +async fn read_tls_record(stream: &mut TcpStream) -> Result<(u8, Vec)> { + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await?; + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await?; + Ok((header[0], body)) +} + +fn parse_server_hello(body: &[u8]) -> Option { + if body.len() < 4 || body[0] != 0x02 { + return None; + } + + let msg_len = u32::from_be_bytes([0, body[1], body[2], body[3]]) as usize; + if msg_len + 4 > body.len() { + return None; + } + + let mut pos = 4; + let version = [*body.get(pos)?, *body.get(pos + 1)?]; + pos += 2; + + let mut random = [0u8; 32]; + random.copy_from_slice(body.get(pos..pos + 32)?); + pos += 32; + + let session_len = *body.get(pos)? 
as usize; + pos += 1; + let session_id = body.get(pos..pos + session_len)?.to_vec(); + pos += session_len; + + let cipher_suite = [*body.get(pos)?, *body.get(pos + 1)?]; + pos += 2; + + let compression = *body.get(pos)?; + pos += 1; + + let ext_len = u16::from_be_bytes([*body.get(pos)?, *body.get(pos + 1)?]) as usize; + pos += 2; + let ext_end = pos.checked_add(ext_len)?; + if ext_end > body.len() { + return None; + } + + let mut extensions = Vec::new(); + while pos + 4 <= ext_end { + let etype = u16::from_be_bytes([body[pos], body[pos + 1]]); + let elen = u16::from_be_bytes([body[pos + 2], body[pos + 3]]) as usize; + pos += 4; + let data = body.get(pos..pos + elen)?.to_vec(); + pos += elen; + extensions.push(TlsExtension { ext_type: etype, data }); + } + + Some(ParsedServerHello { + version, + random, + session_id, + cipher_suite, + compression, + extensions, + }) +} + +async fn fetch_via_raw_tls( + host: &str, + port: u16, + sni: &str, + connect_timeout: Duration, +) -> Result { + let addr = format!("{host}:{port}"); + let mut stream = timeout(connect_timeout, TcpStream::connect(addr)).await??; + + let rng = SecureRandom::new(); + let client_hello = build_client_hello(sni, &rng); + timeout(connect_timeout, async { + stream.write_all(&client_hello).await?; + stream.flush().await?; + Ok::<(), std::io::Error>(()) + }) + .await??; + + let mut records = Vec::new(); + // Read up to 4 records: ServerHello, CCS, and up to two ApplicationData. 
+ for _ in 0..4 { + match timeout(connect_timeout, read_tls_record(&mut stream)).await { + Ok(Ok(rec)) => records.push(rec), + Ok(Err(e)) => return Err(e.into()), + Err(_) => break, + } + if records.len() >= 3 && records.iter().any(|(t, _)| *t == TLS_RECORD_APPLICATION) { + break; + } + } + + let mut app_sizes = Vec::new(); + let mut server_hello = None; + for (t, body) in &records { + if *t == TLS_RECORD_HANDSHAKE && server_hello.is_none() { + server_hello = parse_server_hello(body); + } else if *t == TLS_RECORD_APPLICATION { + app_sizes.push(body.len()); + } + } + + let parsed = server_hello.ok_or_else(|| anyhow!("ServerHello not received"))?; + let total_app_data_len = app_sizes.iter().sum::().max(1024); + + Ok(TlsFetchResult { + server_hello_parsed: parsed, + app_data_records_sizes: if app_sizes.is_empty() { + vec![total_app_data_len] + } else { + app_sizes + }, + total_app_data_len, + }) +} + +/// Fetch real TLS metadata for the given SNI: negotiated cipher and cert lengths. +pub async fn fetch_real_tls( + host: &str, + port: u16, + sni: &str, + connect_timeout: Duration, +) -> Result { + // Preferred path: raw TLS probe for accurate record sizing + match fetch_via_raw_tls(host, port, sni, connect_timeout).await { + Ok(res) => return Ok(res), + Err(e) => { + warn!(sni = %sni, error = %e, "Raw TLS fetch failed, falling back to rustls"); + } + } + + // Fallback: rustls handshake to at least get certificate sizes + let addr = format!("{host}:{port}"); + let stream = timeout(connect_timeout, TcpStream::connect(addr)).await??; + + let config = build_client_config(); + let connector = TlsConnector::from(config); + + let server_name = ServerName::try_from(sni.to_owned()) + .or_else(|_| ServerName::try_from(host.to_owned())) + .map_err(|_| RustlsError::General("invalid SNI".into()))?; + + let tls_stream: TlsStream = connector.connect(server_name, stream).await?; + + // Extract negotiated parameters and certificates + let (_io, session) = tls_stream.get_ref(); + let 
cipher_suite = session + .negotiated_cipher_suite() + .map(|s| u16::from(s.suite()).to_be_bytes()) + .unwrap_or([0x13, 0x01]); + + let certs: Vec> = session + .peer_certificates() + .map(|slice| slice.to_vec()) + .unwrap_or_default(); + + let total_cert_len: usize = certs.iter().map(|c| c.len()).sum::().max(1024); + + // Heuristic: split across two records if large to mimic real servers a bit. + let app_data_records_sizes = if total_cert_len > 3000 { + vec![total_cert_len / 2, total_cert_len - total_cert_len / 2] + } else { + vec![total_cert_len] + }; + + let parsed = ParsedServerHello { + version: [0x03, 0x03], + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite, + compression: 0, + extensions: Vec::new(), + }; + + debug!( + sni = %sni, + len = total_cert_len, + cipher = format!("0x{:04x}", u16::from_be_bytes(cipher_suite)), + "Fetched TLS metadata via rustls" + ); + + Ok(TlsFetchResult { + server_hello_parsed: parsed, + app_data_records_sizes: app_data_records_sizes.clone(), + total_app_data_len: app_data_records_sizes.iter().sum(), + }) +} diff --git a/src/tls_front/mod.rs b/src/tls_front/mod.rs new file mode 100644 index 0000000..89f3988 --- /dev/null +++ b/src/tls_front/mod.rs @@ -0,0 +1,7 @@ +pub mod types; +pub mod cache; +pub mod fetcher; +pub mod emulator; + +pub use cache::TlsFrontCache; +pub use types::{CachedTlsData, TlsFetchResult}; diff --git a/src/tls_front/types.rs b/src/tls_front/types.rs new file mode 100644 index 0000000..7f346db --- /dev/null +++ b/src/tls_front/types.rs @@ -0,0 +1,48 @@ +use std::time::SystemTime; + +/// Parsed representation of an unencrypted TLS ServerHello. +#[derive(Debug, Clone)] +pub struct ParsedServerHello { + pub version: [u8; 2], + pub random: [u8; 32], + pub session_id: Vec, + pub cipher_suite: [u8; 2], + pub compression: u8, + pub extensions: Vec, +} + +/// Generic TLS extension container. 
+#[derive(Debug, Clone)] +pub struct TlsExtension { + pub ext_type: u16, + pub data: Vec<u8>, +} + +/// Basic certificate metadata (optional, informative). +#[derive(Debug, Clone)] +pub struct ParsedCertificateInfo { + pub not_after_unix: Option<u64>, + pub not_before_unix: Option<u64>, + pub issuer_cn: Option<String>, + pub subject_cn: Option<String>, + pub san_names: Vec<String>, +} + +/// Cached data per SNI used by the emulator. +#[derive(Debug, Clone)] +pub struct CachedTlsData { + pub server_hello_template: ParsedServerHello, + pub cert_info: Option<ParsedCertificateInfo>, + pub app_data_records_sizes: Vec<usize>, + pub total_app_data_len: usize, + pub fetched_at: SystemTime, + pub domain: String, +} + +/// Result of attempting to fetch real TLS artifacts. +#[derive(Debug, Clone)] +pub struct TlsFetchResult { + pub server_hello_parsed: ParsedServerHello, + pub app_data_records_sizes: Vec<usize>, + pub total_app_data_len: usize, +} diff --git a/src/transport/socket.rs b/src/transport/socket.rs index a4a7034..f353c52 100644 --- a/src/transport/socket.rs +++ b/src/transport/socket.rs @@ -122,6 +122,38 @@ pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> { stream.local_addr().ok() } +/// Resolve primary IP address of a network interface by name. +/// Returns the first address matching the requested family (IPv4/IPv6). +#[cfg(unix)] +pub fn resolve_interface_ip(name: &str, want_ipv6: bool) -> Option<IpAddr> { + use nix::ifaddrs::getifaddrs; + + if let Ok(addrs) = getifaddrs() { + for iface in addrs { + if iface.interface_name == name { + if let Some(address) = iface.address { + if let Some(v4) = address.as_sockaddr_in() { + if !want_ipv6 { + return Some(IpAddr::V4(v4.ip())); + } + } else if let Some(v6) = address.as_sockaddr_in6() { + if want_ipv6 { + return Some(IpAddr::V6(v6.ip().clone())); + } + } + } + } + } + } + None +} + +/// Stub for non-Unix platforms: interface name resolution unsupported.
+#[cfg(not(unix))] +pub fn resolve_interface_ip(_name: &str, _want_ipv6: bool) -> Option<IpAddr> { + None +} + /// Get peer address of a socket pub fn get_peer_addr(stream: &TcpStream) -> Option<SocketAddr> { stream.peer_addr().ok() } diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 8fdd437..660043f 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -5,6 +5,7 @@ use std::collections::HashMap; use std::net::{SocketAddr, IpAddr}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; use tokio::net::TcpStream; use tokio::sync::RwLock; @@ -15,7 +16,7 @@ use tracing::{debug, warn, info, trace}; use crate::config::{UpstreamConfig, UpstreamType}; use crate::error::{Result, ProxyError}; use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT}; -use crate::transport::socket::create_outgoing_socket_bound; +use crate::transport::socket::{create_outgoing_socket_bound, resolve_interface_ip}; use crate::transport::socks::{connect_socks4, connect_socks5}; /// Number of Telegram datacenters @@ -84,6 +85,8 @@ struct UpstreamState { dc_latency: [LatencyEma; NUM_DCS], /// Per-DC IP version preference (learned from connectivity tests) dc_ip_pref: [IpPreference; NUM_DCS], + /// Round-robin counter for bind_addresses selection + bind_rr: Arc<AtomicUsize>, } impl UpstreamState { @@ -95,6 +98,7 @@ impl UpstreamState { last_check: std::time::Instant::now(), dc_latency: [LatencyEma::new(0.3); NUM_DCS], dc_ip_pref: [IpPreference::Unknown; NUM_DCS], + bind_rr: Arc::new(AtomicUsize::new(0)), } } @@ -166,6 +170,46 @@ impl UpstreamManager { } } + fn resolve_bind_address( + interface: &Option<String>, + bind_addresses: &Option<Vec<String>>, + target: SocketAddr, + rr: Option<&AtomicUsize>, + ) -> Option<IpAddr> { + let want_ipv6 = target.is_ipv6(); + + if let Some(addrs) = bind_addresses { + let candidates: Vec<IpAddr> = addrs + .iter() + .filter_map(|s| s.parse::<IpAddr>().ok()) + .filter(|ip| ip.is_ipv6() == want_ipv6) + .collect(); + + if
!candidates.is_empty() { + if let Some(counter) = rr { + let idx = counter.fetch_add(1, Ordering::Relaxed) % candidates.len(); + return Some(candidates[idx]); + } + return candidates.first().copied(); + } + } + + if let Some(iface) = interface { + if let Ok(ip) = iface.parse::<IpAddr>() { + if ip.is_ipv6() == want_ipv6 { + return Some(ip); + } + } else { + #[cfg(unix)] + if let Some(ip) = resolve_interface_ip(iface, want_ipv6) { + return Some(ip); + } + } + } + + None + } + /// Select upstream using latency-weighted random selection. async fn select_upstream(&self, dc_idx: Option<usize>, scope: Option<&str>) -> Option<usize> { let upstreams = self.upstreams.read().await; @@ -262,7 +306,12 @@ let start = Instant::now(); - match self.connect_via_upstream(&upstream, target).await { + let bind_rr = { + let guard = self.upstreams.read().await; + guard.get(idx).map(|u| u.bind_rr.clone()) + }; + + match self.connect_via_upstream(&upstream, target, bind_rr).await { Ok(stream) => { let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; let mut guard = self.upstreams.write().await; @@ -294,13 +343,27 @@ } } - async fn connect_via_upstream(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<TcpStream> { + async fn connect_via_upstream( + &self, + config: &UpstreamConfig, + target: SocketAddr, + bind_rr: Option<Arc<AtomicUsize>>, + ) -> Result<TcpStream> { match &config.upstream_type { - UpstreamType::Direct { interface } => { - let bind_ip = interface.as_ref() - .and_then(|s| s.parse::<IpAddr>().ok()); + UpstreamType::Direct { interface, bind_addresses } => { + let bind_ip = Self::resolve_bind_address( + interface, + bind_addresses, + target, + bind_rr.as_deref(), + ); let socket = create_outgoing_socket_bound(target, bind_ip)?; + if let Some(ip) = bind_ip { + debug!(bind = %ip, target = %target, "Bound outgoing socket"); + } else if interface.is_some() || bind_addresses.is_some() { + debug!(target = %target, "No matching bind address for target family"); + } socket.set_nonblocking(true)?;
match socket.connect(&target.into()) { @@ -323,8 +386,12 @@ let proxy_addr: SocketAddr = address.parse() .map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?; - let bind_ip = interface.as_ref() - .and_then(|s| s.parse::<IpAddr>().ok()); + let bind_ip = Self::resolve_bind_address( + interface, + &None, + proxy_addr, + bind_rr.as_deref(), + ); let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; @@ -354,8 +421,12 @@ let proxy_addr: SocketAddr = address.parse() .map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?; - let bind_ip = interface.as_ref() - .and_then(|s| s.parse::<IpAddr>().ok()); + let bind_ip = Self::resolve_bind_address( + interface, + &None, + proxy_addr, + bind_rr.as_deref(), + ); let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; @@ -398,18 +469,18 @@ ipv4_enabled: bool, ipv6_enabled: bool, ) -> Vec { - let upstreams: Vec<(usize, UpstreamConfig)> = { + let upstreams: Vec<(usize, UpstreamConfig, Arc<AtomicUsize>)> = { let guard = self.upstreams.read().await; guard.iter().enumerate() - .map(|(i, u)| (i, u.config.clone())) + .map(|(i, u)| (i, u.config.clone(), u.bind_rr.clone())) .collect() }; let mut all_results = Vec::new(); - for (upstream_idx, upstream_config) in &upstreams { + for (upstream_idx, upstream_config, bind_rr) in &upstreams { let upstream_name = match &upstream_config.upstream_type { - UpstreamType::Direct { interface } => { + UpstreamType::Direct { interface, .. } => { format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default()) } UpstreamType::Socks4 { address, ..
} => format!("socks4://{}", address), @@ -424,7 +495,7 @@ let result = tokio::time::timeout( Duration::from_secs(DC_PING_TIMEOUT_SECS), - self.ping_single_dc(&upstream_config, addr_v6) + self.ping_single_dc(&upstream_config, Some(bind_rr.clone()), addr_v6) ).await; let ping_result = match result { @@ -475,7 +546,7 @@ let result = tokio::time::timeout( Duration::from_secs(DC_PING_TIMEOUT_SECS), - self.ping_single_dc(&upstream_config, addr_v4) + self.ping_single_dc(&upstream_config, Some(bind_rr.clone()), addr_v4) ).await; let ping_result = match result { @@ -538,7 +609,7 @@ } let result = tokio::time::timeout( Duration::from_secs(DC_PING_TIMEOUT_SECS), - self.ping_single_dc(&upstream_config, addr) + self.ping_single_dc(&upstream_config, Some(bind_rr.clone()), addr) ).await; let ping_result = match result { @@ -607,9 +678,14 @@ all_results } - async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> { + async fn ping_single_dc( + &self, + config: &UpstreamConfig, + bind_rr: Option<Arc<AtomicUsize>>, + target: SocketAddr, + ) -> Result<f64> { let start = Instant::now(); - let _stream = self.connect_via_upstream(config, target).await?; + let _stream = self.connect_via_upstream(config, target, bind_rr).await?; Ok(start.elapsed().as_secs_f64() * 1000.0) } @@ -649,15 +725,16 @@ let count = self.upstreams.read().await.len(); for i in 0..count { - let config = { + let (config, bind_rr) = { let guard = self.upstreams.read().await; - guard[i].config.clone() + let u = &guard[i]; + (u.config.clone(), u.bind_rr.clone()) }; let start = Instant::now(); let result = tokio::time::timeout( Duration::from_secs(10), - self.connect_via_upstream(&config, dc_addr) + self.connect_via_upstream(&config, dc_addr, Some(bind_rr.clone())) ).await; match result { @@ -686,7 +763,7 @@ let start2 = Instant::now(); let result2 = tokio::time::timeout(
Duration::from_secs(10), - self.connect_via_upstream(&config, fallback_addr) + self.connect_via_upstream(&config, fallback_addr, Some(bind_rr.clone())) ).await; let mut guard = self.upstreams.write().await; diff --git a/telemt.service b/telemt.service index b08b4c8..4f522a4 100644 --- a/telemt.service +++ b/telemt.service @@ -7,6 +7,7 @@ Type=simple WorkingDirectory=/bin ExecStart=/bin/telemt /etc/telemt.toml Restart=on-failure +LimitNOFILE=65536 [Install] WantedBy=multi-user.target