From 35ae455e2b6fd56f4312d7d968df38c9a37663a5 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Thu, 19 Feb 2026 13:35:56 +0300 Subject: [PATCH] ME Pool V2 Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/defaults.rs | 4 + src/config/load.rs | 38 +++++- src/config/types.rs | 15 +++ src/crypto/random.rs | 5 +- src/main.rs | 21 ++- src/transport/middle_proxy/codec.rs | 7 + src/transport/middle_proxy/config_updater.rs | 35 ++++- src/transport/middle_proxy/pool.rs | 133 ++++++++++++++++--- src/transport/middle_proxy/pool_nat.rs | 86 ++++++------ src/transport/middle_proxy/reader.rs | 30 +++-- src/transport/middle_proxy/registry.rs | 12 +- src/transport/middle_proxy/rotation.rs | 13 +- src/transport/middle_proxy/send.rs | 81 ++++++----- 13 files changed, 343 insertions(+), 137 deletions(-) diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 3f8254c..c11e738 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -74,6 +74,10 @@ pub(crate) fn default_unknown_dc_log_path() -> Option { Some("unknown-dc.txt".to_string()) } +pub(crate) fn default_pool_size() -> usize { + 2 +} + // Custom deserializer helpers #[derive(Deserialize)] diff --git a/src/config/load.rs b/src/config/load.rs index 4f00f77..a2fc19b 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -11,6 +11,32 @@ use crate::error::{ProxyError, Result}; use super::defaults::*; use super::types::*; +fn preprocess_includes(content: &str, base_dir: &Path, depth: u8) -> Result { + if depth > 10 { + return Err(ProxyError::Config("Include depth > 10".into())); + } + let mut output = String::with_capacity(content.len()); + for line in content.lines() { + let trimmed = line.trim(); + if let Some(rest) = trimmed.strip_prefix("include") { + let rest = rest.trim(); + if let Some(rest) = rest.strip_prefix('=') { + let path_str = rest.trim().trim_matches('"'); + let resolved = base_dir.join(path_str); + let included = std::fs::read_to_string(&resolved) + .map_err(|e| ProxyError::Config(e.to_string()))?; + let included_dir = resolved.parent().unwrap_or(base_dir); + output.push_str(&preprocess_includes(&included, included_dir, depth + 1)?); + output.push('\n'); + continue; + } + } + output.push_str(line); + output.push('\n'); + } + Ok(output) +} + fn validate_network_cfg(net: &mut NetworkConfig) -> Result<()> { if !net.ipv4 && matches!(net.ipv6, Some(false)) { return Err(ProxyError::Config( @@ -84,10 +110,12 @@ pub struct ProxyConfig { impl ProxyConfig { pub fn load>(path: P) -> Result { let content = - std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?; + std::fs::read_to_string(&path).map_err(|e| ProxyError::Config(e.to_string()))?; + let base_dir = path.as_ref().parent().unwrap_or(Path::new(".")); + let processed = preprocess_includes(&content, base_dir, 0)?; let mut config: ProxyConfig = - toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?; + toml::from_str(&processed).map_err(|e| ProxyError::Config(e.to_string()))?; // Validate secrets. for (user, secret) in &config.access.users { @@ -151,8 +179,10 @@ impl ProxyConfig { validate_network_cfg(&mut config.network)?; - // Random fake_cert_len. - config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); + // Random fake_cert_len only when default is in use. + if config.censorship.fake_cert_len == default_fake_cert_len() { + config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); + } // Resolve listen_tcp: explicit value wins, otherwise auto-detect. // If unix socket is set → TCP only when listen_addr_ipv4 or listeners are explicitly provided. diff --git a/src/config/types.rs b/src/config/types.rs index 766851b..1cdb1bf 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -143,6 +143,18 @@ pub struct GeneralConfig { #[serde(default)] pub middle_proxy_nat_stun: Option, + /// Optional list of STUN servers for NAT probing fallback. + #[serde(default)] + pub middle_proxy_nat_stun_servers: Vec, + + /// Desired size of active Middle-Proxy writer pool. + #[serde(default = "default_pool_size")] + pub middle_proxy_pool_size: usize, + + /// Number of warm standby ME connections kept pre-initialized. + #[serde(default)] + pub middle_proxy_warm_standby: usize, + /// Ignore STUN/interface IP mismatch (keep using Middle Proxy even if NAT detected). #[serde(default)] pub stun_iface_mismatch_ignore: bool, @@ -175,6 +187,9 @@ impl Default for GeneralConfig { middle_proxy_nat_ip: None, middle_proxy_nat_probe: false, middle_proxy_nat_stun: None, + middle_proxy_nat_stun_servers: Vec::new(), + middle_proxy_pool_size: default_pool_size(), + middle_proxy_warm_standby: 0, stun_iface_mismatch_ignore: false, unknown_dc_log_path: default_unknown_dc_log_path(), log_level: LogLevel::Normal, diff --git a/src/crypto/random.rs b/src/crypto/random.rs index 18862ab..99aa5f3 100644 --- a/src/crypto/random.rs +++ b/src/crypto/random.rs @@ -11,6 +11,9 @@ pub struct SecureRandom { inner: Mutex, } +unsafe impl Send for SecureRandom {} +unsafe impl Sync for SecureRandom {} + struct SecureRandomInner { rng: StdRng, cipher: AesCtr, @@ -211,4 +214,4 @@ mod tests { assert_ne!(shuffled, original); } -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 7663fab..33aefcb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -74,7 +74,6 @@ fn parse_cli() -> (String, bool, Option) { eprintln!("Options:"); eprintln!(" --silent, -s Suppress info logs"); eprintln!(" --log-level debug|verbose|normal|silent"); - eprintln!(" --version, -V Print version information"); eprintln!(" --help, -h Show this help"); eprintln!(); eprintln!("Setup (fire-and-forget):"); @@ -111,18 +110,20 @@ fn parse_cli() -> (String, bool, Option) { } fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { - info!("--- Proxy Links ({}) ---", host); + info!(target: "telemt::links", "--- Proxy Links ({}) ---", host); for user_name in config.general.links.show.resolve_users(&config.access.users) { if let Some(secret) = config.access.users.get(user_name) { - info!("User: {}", user_name); + info!(target: "telemt::links", "User: {}", user_name); if config.general.modes.classic { info!( + target: "telemt::links", " Classic: tg://proxy?server={}&port={}&secret={}", host, port, secret ); } if config.general.modes.secure { info!( + target: "telemt::links", " DD: tg://proxy?server={}&port={}&secret=dd{}", host, port, secret ); @@ -130,15 +131,16 @@ fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { if config.general.modes.tls { let domain_hex = hex::encode(&config.censorship.tls_domain); info!( + target: "telemt::links", " EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", host, port, secret, domain_hex ); } } else { - warn!("User '{}' in show_link not found", user_name); + warn!(target: "telemt::links", "User '{}' in show_link not found", user_name); } } - info!("------------------------"); + info!(target: "telemt::links", "------------------------"); } #[tokio::main] @@ -322,6 +324,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai config.general.middle_proxy_nat_ip, config.general.middle_proxy_nat_probe, config.general.middle_proxy_nat_stun.clone(), + config.general.middle_proxy_nat_stun_servers.clone(), probe.detected_ipv6, config.timeouts.me_one_retry, config.timeouts.me_one_timeout_ms, @@ -332,16 +335,18 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai rng.clone(), ); - match pool.init(2, &rng).await { + let pool_size = config.general.middle_proxy_pool_size.max(1); + match pool.init(pool_size, &rng).await { Ok(()) => { info!("Middle-End pool initialized successfully"); // Phase 4: Start health monitor let pool_clone = pool.clone(); let rng_clone = rng.clone(); + let min_conns = pool_size; tokio::spawn(async move { crate::transport::middle_proxy::me_health_monitor( - pool_clone, rng_clone, 2, + pool_clone, rng_clone, min_conns, ) .await; }); @@ -745,6 +750,8 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai // Switch to user-configured log level after startup let runtime_filter = if has_rust_log { EnvFilter::from_default_env() + } else if matches!(effective_log_level, LogLevel::Silent) { + EnvFilter::new("warn,telemt::links=info") } else { EnvFilter::new(effective_log_level.to_filter_str()) }; diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index cc33ae8..12efc45 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -4,6 +4,13 @@ use crate::crypto::{AesCbc, crc32}; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; +/// Commands sent to dedicated writer tasks to avoid mutex contention on TCP writes. +pub(crate) enum WriterCommand { + Data(Vec), + DataAndFlush(Vec), + Close, +} + pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec { let total_len = (4 + 4 + payload.len() + 4) as u32; let mut frame = Vec::with_capacity(total_len as usize); diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index 3c36820..d2bb550 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -13,6 +13,24 @@ use super::secret::download_proxy_secret; use crate::crypto::SecureRandom; use std::time::SystemTime; +async fn retry_fetch(url: &str) -> Option { + let delays = [1u64, 5, 15]; + for (i, d) in delays.iter().enumerate() { + match fetch_proxy_config(url).await { + Ok(cfg) => return Some(cfg), + Err(e) => { + if i == delays.len() - 1 { + warn!(error = %e, url, "fetch_proxy_config failed"); + } else { + debug!(error = %e, url, "fetch_proxy_config retrying"); + tokio::time::sleep(Duration::from_secs(*d)).await; + } + } + } + } + None +} + #[derive(Debug, Clone, Default)] pub struct ProxyConfigData { pub map: HashMap>, @@ -118,7 +136,8 @@ pub async fn me_config_updater(pool: Arc, rng: Arc, interv tick.tick().await; // Update proxy config v4 - if let Ok(cfg) = fetch_proxy_config("https://core.telegram.org/getProxyConfig").await { + let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig").await; + if let Some(cfg) = cfg_v4 { let changed = pool.update_proxy_maps(cfg.map.clone(), None).await; if let Some(dc) = cfg.default_dc { pool.default_dc.store(dc, std::sync::atomic::Ordering::Relaxed); @@ -129,14 +148,20 @@ pub async fn me_config_updater(pool: Arc, rng: Arc, interv } else { debug!("ME config v4 unchanged"); } - } else { - warn!("getProxyConfig update failed"); } // Update proxy config v6 (optional) - if let Ok(cfg_v6) = fetch_proxy_config("https://core.telegram.org/getProxyConfigV6").await { - let _ = pool.update_proxy_maps(HashMap::new(), Some(cfg_v6.map)).await; + let cfg_v6 = retry_fetch("https://core.telegram.org/getProxyConfigV6").await; + if let Some(cfg_v6) = cfg_v6 { + let changed = pool.update_proxy_maps(HashMap::new(), Some(cfg_v6.map)).await; + if changed { + info!("ME config updated (v6), reconciling connections"); + pool.reconcile_connections(&rng).await; + } else { + debug!("ME config v6 unchanged"); + } } + pool.reset_stun_state(); // Update proxy-secret match download_proxy_secret().await { diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 00041fd..e860dc8 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -1,14 +1,14 @@ use std::collections::HashMap; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, AtomicUsize, Ordering}; use bytes::BytesMut; use rand::Rng; use rand::seq::SliceRandom; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::{Mutex, RwLock, mpsc, Notify}; use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; -use std::time::Duration; +use std::time::{Duration, Instant}; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; @@ -18,7 +18,7 @@ use crate::protocol::constants::*; use super::ConnRegistry; use super::registry::{BoundConn, ConnMeta}; -use super::codec::RpcWriter; +use super::codec::{RpcWriter, WriterCommand}; use super::reader::reader_loop; use super::MeResponse; @@ -29,7 +29,7 @@ const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; pub struct MeWriter { pub id: u64, pub addr: SocketAddr, - pub writer: Arc>, + pub tx: mpsc::Sender, pub cancel: CancellationToken, pub degraded: Arc, pub draining: Arc, @@ -47,9 +47,11 @@ pub struct MePool { pub(super) nat_ip_detected: Arc>>, pub(super) nat_probe: bool, pub(super) nat_stun: Option, + pub(super) nat_stun_servers: Vec, pub(super) detected_ipv6: Option, pub(super) nat_probe_attempts: std::sync::atomic::AtomicU8, pub(super) nat_probe_disabled: std::sync::atomic::AtomicBool, + pub(super) stun_backoff_until: Arc>>, pub(super) me_one_retry: u8, pub(super) me_one_timeout: Duration, pub(super) proxy_map_v4: Arc>>>, @@ -59,6 +61,8 @@ pub struct MePool { pub(super) ping_tracker: Arc>>, pub(super) rtt_stats: Arc>>, pub(super) nat_reflection_cache: Arc>, + pub(super) writer_available: Arc, + pub(super) conn_count: AtomicUsize, pool_size: usize, } @@ -75,6 +79,7 @@ impl MePool { nat_ip: Option, nat_probe: bool, nat_stun: Option, + nat_stun_servers: Vec, detected_ipv6: Option, me_one_retry: u8, me_one_timeout_ms: u64, @@ -96,9 +101,11 @@ impl MePool { nat_ip_detected: Arc::new(RwLock::new(None)), nat_probe, nat_stun, + nat_stun_servers, detected_ipv6, nat_probe_attempts: std::sync::atomic::AtomicU8::new(0), nat_probe_disabled: std::sync::atomic::AtomicBool::new(false), + stun_backoff_until: Arc::new(RwLock::new(None)), me_one_retry, me_one_timeout: Duration::from_millis(me_one_timeout_ms), pool_size: 2, @@ -109,6 +116,8 @@ impl MePool { ping_tracker: Arc::new(Mutex::new(HashMap::new())), rtt_stats: Arc::new(Mutex::new(HashMap::new())), nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), + writer_available: Arc::new(Notify::new()), + conn_count: AtomicUsize::new(0), }) } @@ -116,6 +125,11 @@ impl MePool { self.proxy_tag.is_some() } + pub fn reset_stun_state(&self) { + self.nat_probe_attempts.store(0, Ordering::Relaxed); + self.nat_probe_disabled.store(false, Ordering::Relaxed); + } + pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr { let ip = self.translate_ip_for_nat(addr.ip()); SocketAddr::new(ip, addr.port()) @@ -132,7 +146,11 @@ impl MePool { pub async fn reconcile_connections(self: &Arc, rng: &SecureRandom) { use std::collections::HashSet; let writers = self.writers.read().await; - let current: HashSet = writers.iter().map(|w| w.addr).collect(); + let current: HashSet = writers + .iter() + .filter(|w| !w.draining.load(Ordering::Relaxed)) + .map(|w| w.addr) + .collect(); drop(writers); for family in self.family_order() { @@ -175,12 +193,36 @@ impl MePool { let mut guard = self.proxy_map_v6.write().await; if !v6.is_empty() && *guard != v6 { *guard = v6; + changed = true; + } + } + // Ensure negative DC entries mirror positives when absent (Telegram convention). + { + let mut guard = self.proxy_map_v4.write().await; + let keys: Vec = guard.keys().cloned().collect(); + for k in keys.iter().cloned().filter(|k| *k > 0) { + if !guard.contains_key(&-k) { + if let Some(addrs) = guard.get(&k).cloned() { + guard.insert(-k, addrs); + } + } + } + } + { + let mut guard = self.proxy_map_v6.write().await; + let keys: Vec = guard.keys().cloned().collect(); + for k in keys.iter().cloned().filter(|k| *k > 0) { + if !guard.contains_key(&-k) { + if let Some(addrs) = guard.get(&k).cloned() { + guard.insert(-k, addrs); + } + } } } changed } - pub async fn update_secret(&self, new_secret: Vec) -> bool { + pub async fn update_secret(self: &Arc, new_secret: Vec) -> bool { if new_secret.len() < 32 { warn!(len = new_secret.len(), "proxy-secret update ignored (too short)"); return false; @@ -195,10 +237,14 @@ impl MePool { false } - pub async fn reconnect_all(&self) { - // Graceful: do not drop all at once. New connections will use updated secret. - // Existing writers remain until health monitor replaces them. - // No-op here to avoid total outage. + pub async fn reconnect_all(self: &Arc) { + let ws = self.writers.read().await.clone(); + for w in ws { + if let Ok(()) = self.connect_one(w.addr, self.rng.as_ref()).await { + self.mark_writer_draining(w.id).await; + tokio::time::sleep(Duration::from_secs(2)).await; + } + } } pub(super) async fn key_selector(&self) -> u32 { @@ -317,21 +363,43 @@ impl MePool { let cancel = CancellationToken::new(); let degraded = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false)); - let rpc_w = Arc::new(Mutex::new(RpcWriter { + let (tx, mut rx) = mpsc::channel::(4096); + let mut rpc_writer = RpcWriter { writer: hs.wr, key: hs.write_key, iv: hs.write_iv, seq_no: 0, - })); + }; + let cancel_wr = cancel.clone(); + tokio::spawn(async move { + loop { + tokio::select! { + cmd = rx.recv() => { + match cmd { + Some(WriterCommand::Data(payload)) => { + if rpc_writer.send(&payload).await.is_err() { break; } + } + Some(WriterCommand::DataAndFlush(payload)) => { + if rpc_writer.send_and_flush(&payload).await.is_err() { break; } + } + Some(WriterCommand::Close) | None => break, + } + } + _ = cancel_wr.cancelled() => break, + } + } + }); let writer = MeWriter { id: writer_id, addr, - writer: rpc_w.clone(), + tx: tx.clone(), cancel: cancel.clone(), degraded: degraded.clone(), draining: draining.clone(), }; self.writers.write().await.push(writer.clone()); + self.conn_count.fetch_add(1, Ordering::Relaxed); + self.writer_available.notify_waiters(); let reg = self.registry.clone(); let writers_arc = self.writers_arc(); @@ -339,8 +407,11 @@ impl MePool { let rtt_stats = self.rtt_stats.clone(); let pool = Arc::downgrade(self); let cancel_ping = cancel.clone(); - let rpc_w_ping = rpc_w.clone(); + let tx_ping = tx.clone(); let ping_tracker_ping = ping_tracker.clone(); + let cleanup_done = Arc::new(AtomicBool::new(false)); + let cleanup_for_reader = cleanup_done.clone(); + let cleanup_for_ping = cleanup_done.clone(); tokio::spawn(async move { let cancel_reader = cancel.clone(); @@ -351,7 +422,7 @@ impl MePool { reg.clone(), BytesMut::new(), BytesMut::new(), - rpc_w.clone(), + tx.clone(), ping_tracker.clone(), rtt_stats.clone(), writer_id, @@ -360,7 +431,12 @@ impl MePool { ) .await; if let Some(pool) = pool.upgrade() { - pool.remove_writer_and_close_clients(writer_id).await; + if cleanup_for_reader + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + { + pool.remove_writer_and_close_clients(writer_id).await; + } } if let Err(e) = res { warn!(error = %e, "ME reader ended"); @@ -389,14 +465,20 @@ impl MePool { p.extend_from_slice(&sent_id.to_le_bytes()); { let mut tracker = ping_tracker_ping.lock().await; + tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } ping_id = ping_id.wrapping_add(1); - if let Err(e) = rpc_w_ping.lock().await.send_and_flush(&p).await { - debug!(error = %e, "Active ME ping failed, removing dead writer"); + if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() { + debug!("Active ME ping failed, removing dead writer"); cancel_ping.cancel(); if let Some(pool) = pool_ping.upgrade() { - pool.remove_writer_and_close_clients(writer_id).await; + if cleanup_for_ping + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + { + pool.remove_writer_and_close_clients(writer_id).await; + } } break; } @@ -430,7 +512,7 @@ impl MePool { false } - pub(crate) async fn remove_writer_and_close_clients(&self, writer_id: u64) { + pub(crate) async fn remove_writer_and_close_clients(self: &Arc, writer_id: u64) { let conns = self.remove_writer_only(writer_id).await; for bound in conns { let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await; @@ -444,8 +526,11 @@ impl MePool { if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { let w = ws.remove(pos); w.cancel.cancel(); + let _ = w.tx.send(WriterCommand::Close).await; + self.conn_count.fetch_sub(1, Ordering::Relaxed); } } + self.rtt_stats.lock().await.remove(&writer_id); self.registry.writer_lost(writer_id).await } @@ -459,8 +544,14 @@ impl MePool { let pool = Arc::downgrade(self); tokio::spawn(async move { + let deadline = Instant::now() + Duration::from_secs(300); loop { if let Some(p) = pool.upgrade() { + if Instant::now() >= deadline { + warn!(writer_id, "Drain timeout, force-closing"); + let _ = p.remove_writer_and_close_clients(writer_id).await; + break; + } if p.registry.is_writer_empty(writer_id).await { let _ = p.remove_writer_only(writer_id).await; break; diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index 1b4f7c7..d3dec16 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -98,16 +98,18 @@ impl MePool { family: IpFamily, ) -> Option { const STUN_CACHE_TTL: Duration = Duration::from_secs(600); - // If STUN probing was disabled after attempts, reuse cached (even stale) or skip. - if self.nat_probe_disabled.load(std::sync::atomic::Ordering::Relaxed) { - if let Ok(cache) = self.nat_reflection_cache.try_lock() { - let slot = match family { - IpFamily::V4 => cache.v4, - IpFamily::V6 => cache.v6, - }; - return slot.map(|(_, addr)| addr); + // Backoff window + if let Some(until) = *self.stun_backoff_until.read().await { + if Instant::now() < until { + if let Ok(cache) = self.nat_reflection_cache.try_lock() { + let slot = match family { + IpFamily::V4 => cache.v4, + IpFamily::V6 => cache.v6, + }; + return slot.map(|(_, addr)| addr); + } + return None; } - return None; } if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { @@ -123,48 +125,42 @@ impl MePool { } let attempt = self.nat_probe_attempts.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - if attempt >= 2 { - self.nat_probe_disabled.store(true, std::sync::atomic::Ordering::Relaxed); - return None; - } + let servers = if !self.nat_stun_servers.is_empty() { + self.nat_stun_servers.clone() + } else if let Some(s) = &self.nat_stun { + vec![s.clone()] + } else { + vec!["stun.l.google.com:19302".to_string()] + }; - let stun_addr = self - .nat_stun - .clone() - .unwrap_or_else(|| "stun.l.google.com:19302".to_string()); - match stun_probe_dual(&stun_addr).await { - Ok(res) => { - let picked: Option = match family { - IpFamily::V4 => res.v4, - IpFamily::V6 => res.v6, - }; - if let Some(result) = picked { - info!(local = %result.local_addr, reflected = %result.reflected_addr, family = ?family, "NAT probe: reflected address"); - if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { - let slot = match family { - IpFamily::V4 => &mut cache.v4, - IpFamily::V6 => &mut cache.v6, - }; - *slot = Some((Instant::now(), result.reflected_addr)); + for stun_addr in servers { + match stun_probe_dual(&stun_addr).await { + Ok(res) => { + let picked: Option = match family { + IpFamily::V4 => res.v4, + IpFamily::V6 => res.v6, + }; + if let Some(result) = picked { + info!(local = %result.local_addr, reflected = %result.reflected_addr, family = ?family, stun = %stun_addr, "NAT probe: reflected address"); + self.nat_probe_attempts.store(0, std::sync::atomic::Ordering::Relaxed); + if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { + let slot = match family { + IpFamily::V4 => &mut cache.v4, + IpFamily::V6 => &mut cache.v6, + }; + *slot = Some((Instant::now(), result.reflected_addr)); + } + return Some(result.reflected_addr); } - Some(result.reflected_addr) - } else { - None } - } - Err(e) => { - let attempts = attempt + 1; - if attempts <= 2 { - warn!(error = %e, attempt = attempts, "NAT probe failed"); - } else { - debug!(error = %e, attempt = attempts, "NAT probe suppressed after max attempts"); + Err(e) => { + warn!(error = %e, stun = %stun_addr, attempt = attempt + 1, "NAT probe failed, trying next server"); } - if attempts >= 2 { - self.nat_probe_disabled.store(true, std::sync::atomic::Ordering::Relaxed); - } - None } } + let backoff = Duration::from_secs(60 * 2u64.pow((attempt as u32).min(6))); + *self.stun_backoff_until.write().await = Some(Instant::now() + backoff); + None } } diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index fb40fdb..c22ed68 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -6,7 +6,7 @@ use std::time::Instant; use bytes::{Bytes, BytesMut}; use tokio::io::AsyncReadExt; use tokio::net::TcpStream; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, mpsc}; use tokio_util::sync::CancellationToken; use tracing::{debug, trace, warn}; @@ -14,7 +14,7 @@ use crate::crypto::{AesCbc, crc32}; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; -use super::codec::RpcWriter; +use super::codec::WriterCommand; use super::{ConnRegistry, MeResponse}; pub(crate) async fn reader_loop( @@ -24,7 +24,7 @@ pub(crate) async fn reader_loop( reg: Arc, enc_leftover: BytesMut, mut dec: BytesMut, - writer: Arc>, + tx: mpsc::Sender, ping_tracker: Arc>>, rtt_stats: Arc>>, _writer_id: u64, @@ -33,6 +33,8 @@ pub(crate) async fn reader_loop( ) -> Result<()> { let mut raw = enc_leftover; let mut expected_seq: i32 = 0; + let mut crc_errors = 0u32; + let mut seq_mismatch = 0u32; loop { let mut tmp = [0u8; 16_384]; @@ -80,12 +82,20 @@ pub(crate) async fn reader_loop( let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); if crc32(&frame[..pe]) != ec { warn!("CRC mismatch in data frame"); + crc_errors += 1; + if crc_errors > 3 { + return Err(ProxyError::Proxy("Too many CRC mismatches".into())); + } continue; } let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap()); if seq_no != expected_seq { warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch"); + seq_mismatch += 1; + if seq_mismatch > 10 { + return Err(ProxyError::Proxy("Too many seq mismatches".into())); + } expected_seq = seq_no.wrapping_add(1); } else { expected_seq = expected_seq.wrapping_add(1); @@ -108,7 +118,7 @@ pub(crate) async fn reader_loop( let routed = reg.route(cid, MeResponse::Data { flags, data }).await; if !routed { reg.unregister(cid).await; - send_close_conn(&writer, cid).await; + send_close_conn(&tx, cid).await; } } else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); @@ -118,7 +128,7 @@ pub(crate) async fn reader_loop( let routed = reg.route(cid, MeResponse::Ack(cfm)).await; if !routed { reg.unregister(cid).await; - send_close_conn(&writer, cid).await; + send_close_conn(&tx, cid).await; } } else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); @@ -136,8 +146,8 @@ pub(crate) async fn reader_loop( let mut pong = Vec::with_capacity(12); pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes()); pong.extend_from_slice(&ping_id.to_le_bytes()); - if let Err(e) = writer.lock().await.send_and_flush(&pong).await { - warn!(error = %e, "PONG send failed"); + if tx.send(WriterCommand::DataAndFlush(pong)).await.is_err() { + warn!("PONG send failed"); break; } } else if pt == RPC_PONG_U32 && body.len() >= 8 { @@ -171,12 +181,10 @@ pub(crate) async fn reader_loop( } } -async fn send_close_conn(writer: &Arc>, conn_id: u64) { +async fn send_close_conn(tx: &mpsc::Sender, conn_id: u64) { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - if let Err(e) = writer.lock().await.send_and_flush(&p).await { - debug!(conn_id, error = %e, "Failed to send RPC_CLOSE_CONN"); - } + let _ = tx.send(WriterCommand::DataAndFlush(p)).await; } diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 04f8baa..ab4f280 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use tokio::sync::{mpsc, Mutex, RwLock}; -use super::codec::RpcWriter; +use super::codec::WriterCommand; use super::MeResponse; #[derive(Clone)] @@ -25,12 +25,12 @@ pub struct BoundConn { #[derive(Clone)] pub struct ConnWriter { pub writer_id: u64, - pub writer: Arc>, + pub tx: mpsc::Sender, } struct RegistryInner { map: HashMap>, - writers: HashMap>>, + writers: HashMap>, writer_for_conn: HashMap, conns_for_writer: HashMap>, meta: HashMap, @@ -96,13 +96,13 @@ impl ConnRegistry { &self, conn_id: u64, writer_id: u64, - writer: Arc>, + tx: mpsc::Sender, meta: ConnMeta, ) { let mut inner = self.inner.write().await; inner.meta.entry(conn_id).or_insert(meta); inner.writer_for_conn.insert(conn_id, writer_id); - inner.writers.entry(writer_id).or_insert_with(|| writer.clone()); + inner.writers.entry(writer_id).or_insert_with(|| tx.clone()); inner .conns_for_writer .entry(writer_id) @@ -114,7 +114,7 @@ impl ConnRegistry { let inner = self.inner.read().await; let writer_id = inner.writer_for_conn.get(&conn_id).cloned()?; let writer = inner.writers.get(&writer_id).cloned()?; - Some(ConnWriter { writer_id, writer }) + Some(ConnWriter { writer_id, tx: writer }) } pub async fn writer_lost(&self, writer_id: u64) -> Vec { diff --git a/src/transport/middle_proxy/rotation.rs b/src/transport/middle_proxy/rotation.rs index 6d94f3e..e141fc4 100644 --- a/src/transport/middle_proxy/rotation.rs +++ b/src/transport/middle_proxy/rotation.rs @@ -31,8 +31,17 @@ pub async fn me_rotation_task(pool: Arc, rng: Arc, interva info!(addr = %w.addr, writer_id = w.id, "Rotating ME connection"); match pool.connect_one(w.addr, rng.as_ref()).await { Ok(()) => { - // Mark old writer for graceful drain; removal happens when sessions finish. - pool.mark_writer_draining(w.id).await; + tokio::time::sleep(Duration::from_secs(2)).await; + let ws = pool.writers.read().await; + let new_alive = ws.iter().any(|nw| + nw.id != w.id && nw.addr == w.addr && !nw.degraded.load(Ordering::Relaxed) && !nw.draining.load(Ordering::Relaxed) + ); + drop(ws); + if new_alive { + pool.mark_writer_draining(w.id).await; + } else { + warn!(addr = %w.addr, writer_id = w.id, "New writer died, keeping old"); + } } Err(e) => { warn!(addr = %w.addr, writer_id = w.id, error = %e, "ME rotation connect failed"); diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 2b0c42e..627906d 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -10,6 +10,7 @@ use crate::network::IpFamily; use crate::protocol::constants::RPC_CLOSE_EXT_U32; use super::MePool; +use super::codec::WriterCommand; use super::wire::build_proxy_req_payload; use rand::seq::SliceRandom; use super::registry::ConnMeta; @@ -43,18 +44,15 @@ impl MePool { loop { if let Some(current) = self.registry.get_writer(conn_id).await { let send_res = { - if let Ok(mut guard) = current.writer.try_lock() { - let r = guard.send(&payload).await; - drop(guard); - r - } else { - current.writer.lock().await.send(&payload).await - } + current + .tx + .send(WriterCommand::Data(payload.clone())) + .await }; match send_res { Ok(()) => return Ok(()), - Err(e) => { - warn!(error = %e, writer_id = current.writer_id, "ME write failed"); + Err(_) => { + warn!(writer_id = current.writer_id, "ME writer channel closed"); self.remove_writer_and_close_clients(current.writer_id).await; continue; } @@ -64,7 +62,26 @@ impl MePool { let mut writers_snapshot = { let ws = self.writers.read().await; if ws.is_empty() { - return Err(ProxyError::Proxy("All ME connections dead".into())); + drop(ws); + for family in self.family_order() { + let map = match family { + IpFamily::V4 => self.proxy_map_v4.read().await.clone(), + IpFamily::V6 => self.proxy_map_v6.read().await.clone(), + }; + for (_dc, addrs) in map.iter() { + for (ip, port) in addrs { + let addr = SocketAddr::new(*ip, *port); + if self.connect_one(addr, self.rng.as_ref()).await.is_ok() { + self.writer_available.notify_waiters(); + break; + } + } + } + } + if tokio::time::timeout(Duration::from_secs(3), self.writer_available.notified()).await.is_err() { + return Err(ProxyError::Proxy("All ME connections dead (waited 3s)".into())); + } + continue; } ws.clone() }; @@ -96,9 +113,10 @@ impl MePool { writers_snapshot = ws2.clone(); drop(ws2); candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await; - break; + if !candidate_indices.is_empty() { + break; + } } - drop(map_guard); } if candidate_indices.is_empty() { return Err(ProxyError::Proxy("No ME writers available for target DC".into())); @@ -120,22 +138,15 @@ impl MePool { if w.draining.load(Ordering::Relaxed) { continue; } - if let Ok(mut guard) = w.writer.try_lock() { - let send_res = guard.send(&payload).await; - drop(guard); - match send_res { - Ok(()) => { - self.registry - .bind_writer(conn_id, w.id, w.writer.clone(), meta.clone()) - .await; - return Ok(()); - } - Err(e) => { - warn!(error = %e, writer_id = w.id, "ME write failed"); - self.remove_writer_and_close_clients(w.id).await; - continue; - } - } + if w.tx.send(WriterCommand::Data(payload.clone())).await.is_ok() { + self.registry + .bind_writer(conn_id, w.id, w.tx.clone(), meta.clone()) + .await; + return Ok(()); + } else { + warn!(writer_id = w.id, "ME writer channel closed"); + self.remove_writer_and_close_clients(w.id).await; + continue; } } @@ -143,15 +154,15 @@ impl MePool { if w.draining.load(Ordering::Relaxed) { continue; } - match w.writer.lock().await.send(&payload).await { + match w.tx.send(WriterCommand::Data(payload.clone())).await { Ok(()) => { self.registry - .bind_writer(conn_id, w.id, w.writer.clone(), meta.clone()) + .bind_writer(conn_id, w.id, w.tx.clone(), meta.clone()) .await; return Ok(()); } - Err(e) => { - warn!(error = %e, writer_id = w.id, "ME write failed (blocking)"); + Err(_) => { + warn!(writer_id = w.id, "ME writer channel closed (blocking)"); self.remove_writer_and_close_clients(w.id).await; } } @@ -163,8 +174,8 @@ impl MePool { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - if let Err(e) = w.writer.lock().await.send_and_flush(&p).await { - debug!(error = %e, "ME close write failed"); + if w.tx.send(WriterCommand::DataAndFlush(p)).await.is_err() { + debug!("ME close write failed"); self.remove_writer_and_close_clients(w.writer_id).await; } } else { @@ -176,7 +187,7 @@ impl MePool { } pub fn connection_count(&self) -> usize { - self.writers.try_read().map(|w| w.len()).unwrap_or(0) + self.conn_count.load(Ordering::Relaxed) } pub(super) async fn candidate_indices_for_dc(