Merge pull request #162 from telemt/flow

ME Pool V2
Alexey 2026-02-19 13:42:03 +03:00 committed by GitHub
commit 0768fee06a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 343 additions and 137 deletions

View File

@@ -74,6 +74,10 @@ pub(crate) fn default_unknown_dc_log_path() -> Option<String> {
     Some("unknown-dc.txt".to_string())
 }
 
+pub(crate) fn default_pool_size() -> usize {
+    2
+}
+
 // Custom deserializer helpers
 #[derive(Deserialize)]

View File

@@ -11,6 +11,32 @@ use crate::error::{ProxyError, Result};
 use super::defaults::*;
 use super::types::*;
 
+fn preprocess_includes(content: &str, base_dir: &Path, depth: u8) -> Result<String> {
+    if depth > 10 {
+        return Err(ProxyError::Config("Include depth > 10".into()));
+    }
+    let mut output = String::with_capacity(content.len());
+    for line in content.lines() {
+        let trimmed = line.trim();
+        if let Some(rest) = trimmed.strip_prefix("include") {
+            let rest = rest.trim();
+            if let Some(rest) = rest.strip_prefix('=') {
+                let path_str = rest.trim().trim_matches('"');
+                let resolved = base_dir.join(path_str);
+                let included = std::fs::read_to_string(&resolved)
+                    .map_err(|e| ProxyError::Config(e.to_string()))?;
+                let included_dir = resolved.parent().unwrap_or(base_dir);
+                output.push_str(&preprocess_includes(&included, included_dir, depth + 1)?);
+                output.push('\n');
+                continue;
+            }
+        }
+        output.push_str(line);
+        output.push('\n');
+    }
+    Ok(output)
+}
+
 fn validate_network_cfg(net: &mut NetworkConfig) -> Result<()> {
     if !net.ipv4 && matches!(net.ipv6, Some(false)) {
         return Err(ProxyError::Config(
@@ -84,10 +110,12 @@ pub struct ProxyConfig {
 impl ProxyConfig {
     pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
         let content =
-            std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?;
+            std::fs::read_to_string(&path).map_err(|e| ProxyError::Config(e.to_string()))?;
+        let base_dir = path.as_ref().parent().unwrap_or(Path::new("."));
+        let processed = preprocess_includes(&content, base_dir, 0)?;
         let mut config: ProxyConfig =
-            toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?;
+            toml::from_str(&processed).map_err(|e| ProxyError::Config(e.to_string()))?;
 
         // Validate secrets.
         for (user, secret) in &config.access.users {
@@ -151,8 +179,10 @@ impl ProxyConfig {
         validate_network_cfg(&mut config.network)?;
 
-        // Random fake_cert_len.
-        config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
+        // Random fake_cert_len only when default is in use.
+        if config.censorship.fake_cert_len == default_fake_cert_len() {
+            config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
+        }
 
         // Resolve listen_tcp: explicit value wins, otherwise auto-detect.
         // If unix socket is set → TCP only when listen_addr_ipv4 or listeners are explicitly provided.
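
Note on the new include mechanism: the directive is a plain line of the form include = "path.toml", spliced textually before TOML parsing; nested includes resolve relative to the including file, capped at depth 10. A minimal sketch of a split config (the paths and the user secret here are hypothetical):

use std::fs;

fn main() -> std::io::Result<()> {
    let dir = std::env::temp_dir().join("telemt-include-demo");
    fs::create_dir_all(&dir)?;
    // Included fragment: spliced verbatim into the parent config.
    fs::write(
        dir.join("users.toml"),
        "[access.users]\nalice = \"00112233445566778899aabbccddeeff\"\n",
    )?;
    // Parent config: the include line is replaced by the file's contents.
    fs::write(
        dir.join("proxy.toml"),
        "include = \"users.toml\"\n\n[general]\nmiddle_proxy_pool_size = 4\n",
    )?;
    // ProxyConfig::load(dir.join("proxy.toml")) would parse the spliced result.
    Ok(())
}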

View File

@@ -143,6 +143,18 @@ pub struct GeneralConfig {
     #[serde(default)]
     pub middle_proxy_nat_stun: Option<String>,
 
+    /// Optional list of STUN servers for NAT probing fallback.
+    #[serde(default)]
+    pub middle_proxy_nat_stun_servers: Vec<String>,
+
+    /// Desired size of active Middle-Proxy writer pool.
+    #[serde(default = "default_pool_size")]
+    pub middle_proxy_pool_size: usize,
+
+    /// Number of warm standby ME connections kept pre-initialized.
+    #[serde(default)]
+    pub middle_proxy_warm_standby: usize,
+
     /// Ignore STUN/interface IP mismatch (keep using Middle Proxy even if NAT detected).
     #[serde(default)]
     pub stun_iface_mismatch_ignore: bool,
@@ -175,6 +187,9 @@ impl Default for GeneralConfig {
             middle_proxy_nat_ip: None,
             middle_proxy_nat_probe: false,
             middle_proxy_nat_stun: None,
+            middle_proxy_nat_stun_servers: Vec::new(),
+            middle_proxy_pool_size: default_pool_size(),
+            middle_proxy_warm_standby: 0,
             stun_iface_mismatch_ignore: false,
             unknown_dc_log_path: default_unknown_dc_log_path(),
             log_level: LogLevel::Normal,
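
All three new keys are optional and fall back to safe defaults (empty server list, pool size 2, no warm standby). A stand-alone sketch of how they deserialize; this mirrors the field definitions above but is not the crate's real GeneralConfig (serde and toml are existing dependencies, the STUN address is a placeholder):

use serde::Deserialize;

#[derive(Deserialize)]
struct PoolKnobs {
    #[serde(default)]
    middle_proxy_nat_stun_servers: Vec<String>,
    #[serde(default = "default_pool_size")]
    middle_proxy_pool_size: usize,
    #[serde(default)]
    middle_proxy_warm_standby: usize,
}

fn default_pool_size() -> usize {
    2
}

fn main() {
    let toml_src = r#"
        middle_proxy_pool_size = 4
        middle_proxy_nat_stun_servers = ["stun.example.org:3478"]
    "#;
    let k: PoolKnobs = toml::from_str(toml_src).unwrap();
    assert_eq!(k.middle_proxy_pool_size, 4);
    assert_eq!(k.middle_proxy_warm_standby, 0); // falls back to the default
}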

View File

@@ -11,6 +11,9 @@ pub struct SecureRandom {
     inner: Mutex<SecureRandomInner>,
 }
 
+unsafe impl Send for SecureRandom {}
+unsafe impl Sync for SecureRandom {}
+
 struct SecureRandomInner {
     rng: StdRng,
     cipher: AesCtr,

View File

@@ -74,7 +74,6 @@ fn parse_cli() -> (String, bool, Option<String>) {
     eprintln!("Options:");
     eprintln!("  --silent, -s          Suppress info logs");
     eprintln!("  --log-level <LEVEL>   debug|verbose|normal|silent");
-    eprintln!("  --version, -V         Print version information");
     eprintln!("  --help, -h            Show this help");
     eprintln!();
     eprintln!("Setup (fire-and-forget):");
@@ -111,18 +110,20 @@ fn parse_cli() -> (String, bool, Option<String>) {
 }
 
 fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
-    info!("--- Proxy Links ({}) ---", host);
+    info!(target: "telemt::links", "--- Proxy Links ({}) ---", host);
     for user_name in config.general.links.show.resolve_users(&config.access.users) {
         if let Some(secret) = config.access.users.get(user_name) {
-            info!("User: {}", user_name);
+            info!(target: "telemt::links", "User: {}", user_name);
             if config.general.modes.classic {
                 info!(
+                    target: "telemt::links",
                     "  Classic: tg://proxy?server={}&port={}&secret={}",
                     host, port, secret
                 );
             }
             if config.general.modes.secure {
                 info!(
+                    target: "telemt::links",
                     "  DD: tg://proxy?server={}&port={}&secret=dd{}",
                     host, port, secret
                 );
@@ -130,15 +131,16 @@ fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
             if config.general.modes.tls {
                 let domain_hex = hex::encode(&config.censorship.tls_domain);
                 info!(
+                    target: "telemt::links",
                     "  EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
                     host, port, secret, domain_hex
                 );
             }
         } else {
-            warn!("User '{}' in show_link not found", user_name);
+            warn!(target: "telemt::links", "User '{}' in show_link not found", user_name);
         }
     }
-    info!("------------------------");
+    info!(target: "telemt::links", "------------------------");
 }
 
 #[tokio::main]
@@ -322,6 +324,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
                 config.general.middle_proxy_nat_ip,
                 config.general.middle_proxy_nat_probe,
                 config.general.middle_proxy_nat_stun.clone(),
+                config.general.middle_proxy_nat_stun_servers.clone(),
                 probe.detected_ipv6,
                 config.timeouts.me_one_retry,
                 config.timeouts.me_one_timeout_ms,
@@ -332,16 +335,18 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
                 rng.clone(),
             );
 
-            match pool.init(2, &rng).await {
+            let pool_size = config.general.middle_proxy_pool_size.max(1);
+            match pool.init(pool_size, &rng).await {
                 Ok(()) => {
                     info!("Middle-End pool initialized successfully");
                     // Phase 4: Start health monitor
                     let pool_clone = pool.clone();
                     let rng_clone = rng.clone();
+                    let min_conns = pool_size;
                     tokio::spawn(async move {
                         crate::transport::middle_proxy::me_health_monitor(
-                            pool_clone, rng_clone, 2,
+                            pool_clone, rng_clone, min_conns,
                         )
                         .await;
                     });
@@ -745,6 +750,8 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
     // Switch to user-configured log level after startup
     let runtime_filter = if has_rust_log {
         EnvFilter::from_default_env()
+    } else if matches!(effective_log_level, LogLevel::Silent) {
+        EnvFilter::new("warn,telemt::links=info")
+    } else {
         EnvFilter::new(effective_log_level.to_filter_str())
     };
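
The telemt::links target plus the Silent-mode filter above is what keeps proxy links printable under --silent: the global threshold drops to warn while the links target stays at info. A self-contained sketch of the directive's effect (assumes tracing-subscriber with the env-filter feature enabled):

use tracing::{info, warn};
use tracing_subscriber::EnvFilter;

fn main() {
    // Global threshold is warn, but the telemt::links target stays at info.
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::new("warn,telemt::links=info"))
        .init();

    info!("suppressed: below the global warn threshold");
    info!(target: "telemt::links", "printed: matches the telemt::links=info directive");
    warn!("printed: at or above warn");
}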

View File

@@ -4,6 +4,13 @@ use crate::crypto::{AesCbc, crc32};
 use crate::error::{ProxyError, Result};
 use crate::protocol::constants::*;
 
+/// Commands sent to dedicated writer tasks to avoid mutex contention on TCP writes.
+pub(crate) enum WriterCommand {
+    Data(Vec<u8>),
+    DataAndFlush(Vec<u8>),
+    Close,
+}
+
 pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec<u8> {
     let total_len = (4 + 4 + payload.len() + 4) as u32;
     let mut frame = Vec::with_capacity(total_len as usize);
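
This enum is the heart of the V2 write path: one task per ME connection owns the write half and drains a bounded queue, so callers enqueue instead of locking a shared RpcWriter. A generic stand-alone sketch of the pattern (simplified; the real task also honors a CancellationToken and wraps RpcWriter rather than a raw stream):

use tokio::io::{AsyncWrite, AsyncWriteExt, BufWriter};
use tokio::sync::mpsc;

enum Cmd {
    Data(Vec<u8>),
    DataAndFlush(Vec<u8>),
    Close,
}

fn spawn_writer<W>(mut w: BufWriter<W>) -> mpsc::Sender<Cmd>
where
    W: AsyncWrite + Unpin + Send + 'static,
{
    let (tx, mut rx) = mpsc::channel::<Cmd>(4096);
    tokio::spawn(async move {
        while let Some(cmd) = rx.recv().await {
            let res = match cmd {
                Cmd::Data(p) => w.write_all(&p).await,
                Cmd::DataAndFlush(p) => match w.write_all(&p).await {
                    Ok(()) => w.flush().await,
                    err => err,
                },
                Cmd::Close => break,
            };
            if res.is_err() {
                break; // socket error: dropping rx makes senders see a closed channel
            }
        }
    });
    tx
}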

View File

@@ -13,6 +13,24 @@ use super::secret::download_proxy_secret;
 use crate::crypto::SecureRandom;
 use std::time::SystemTime;
 
+async fn retry_fetch(url: &str) -> Option<ProxyConfigData> {
+    let delays = [1u64, 5, 15];
+    for (i, d) in delays.iter().enumerate() {
+        match fetch_proxy_config(url).await {
+            Ok(cfg) => return Some(cfg),
+            Err(e) => {
+                if i == delays.len() - 1 {
+                    warn!(error = %e, url, "fetch_proxy_config failed");
+                } else {
+                    debug!(error = %e, url, "fetch_proxy_config retrying");
+                    tokio::time::sleep(Duration::from_secs(*d)).await;
+                }
+            }
+        }
+    }
+    None
+}
+
 #[derive(Debug, Clone, Default)]
 pub struct ProxyConfigData {
     pub map: HashMap<i32, Vec<(IpAddr, u16)>>,
@@ -118,7 +136,8 @@ pub async fn me_config_updater(pool: Arc<MePool>, rng: Arc<SecureRandom>, interv
         tick.tick().await;
 
         // Update proxy config v4
-        if let Ok(cfg) = fetch_proxy_config("https://core.telegram.org/getProxyConfig").await {
+        let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig").await;
+        if let Some(cfg) = cfg_v4 {
             let changed = pool.update_proxy_maps(cfg.map.clone(), None).await;
             if let Some(dc) = cfg.default_dc {
                 pool.default_dc.store(dc, std::sync::atomic::Ordering::Relaxed);
@@ -129,14 +148,20 @@ pub async fn me_config_updater(pool: Arc<MePool>, rng: Arc<SecureRandom>, interv
             } else {
                 debug!("ME config v4 unchanged");
             }
+        } else {
+            warn!("getProxyConfig update failed");
         }
 
         // Update proxy config v6 (optional)
-        if let Ok(cfg_v6) = fetch_proxy_config("https://core.telegram.org/getProxyConfigV6").await {
-            let _ = pool.update_proxy_maps(HashMap::new(), Some(cfg_v6.map)).await;
+        let cfg_v6 = retry_fetch("https://core.telegram.org/getProxyConfigV6").await;
+        if let Some(cfg_v6) = cfg_v6 {
+            let changed = pool.update_proxy_maps(HashMap::new(), Some(cfg_v6.map)).await;
+            if changed {
+                info!("ME config updated (v6), reconciling connections");
+                pool.reconcile_connections(&rng).await;
+            } else {
+                debug!("ME config v6 unchanged");
+            }
         }
+
+        pool.reset_stun_state();
 
         // Update proxy-secret
         match download_proxy_secret().await {
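
One subtlety in retry_fetch: it makes delays.len() attempts but only sleeps between attempts, so the final entry (15) is never used as a delay; the schedule is three tries with 1s and 5s gaps. The same pattern written generically (a hypothetical helper, not part of this PR):

use std::future::Future;
use std::time::Duration;

// N attempts; sleep delays[i] seconds after failed attempt i, never after the last.
async fn with_retries<T, E, F, Fut>(delays: &[u64], mut op: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
    let mut last = None;
    for (i, d) in delays.iter().enumerate() {
        match op().await {
            Ok(v) => return Ok(v),
            Err(e) => {
                last = Some(e);
                if i < delays.len() - 1 {
                    tokio::time::sleep(Duration::from_secs(*d)).await;
                }
            }
        }
    }
    Err(last.expect("delays must be non-empty"))
}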

View File

@@ -1,14 +1,14 @@
 use std::collections::HashMap;
 use std::net::{IpAddr, Ipv6Addr, SocketAddr};
 use std::sync::Arc;
-use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, AtomicUsize, Ordering};
 
 use bytes::BytesMut;
 use rand::Rng;
 use rand::seq::SliceRandom;
-use tokio::sync::{Mutex, RwLock};
+use tokio::sync::{Mutex, RwLock, mpsc, Notify};
 use tokio_util::sync::CancellationToken;
 use tracing::{debug, info, warn};
-use std::time::Duration;
+use std::time::{Duration, Instant};
 
 use crate::crypto::SecureRandom;
 use crate::error::{ProxyError, Result};
@@ -18,7 +18,7 @@ use crate::protocol::constants::*;
 use super::ConnRegistry;
 use super::registry::{BoundConn, ConnMeta};
-use super::codec::RpcWriter;
+use super::codec::{RpcWriter, WriterCommand};
 use super::reader::reader_loop;
 use super::MeResponse;
@@ -29,7 +29,7 @@ const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
 pub struct MeWriter {
     pub id: u64,
     pub addr: SocketAddr,
-    pub writer: Arc<Mutex<RpcWriter>>,
+    pub tx: mpsc::Sender<WriterCommand>,
     pub cancel: CancellationToken,
     pub degraded: Arc<AtomicBool>,
     pub draining: Arc<AtomicBool>,
@@ -47,9 +47,11 @@ pub struct MePool {
     pub(super) nat_ip_detected: Arc<RwLock<Option<IpAddr>>>,
     pub(super) nat_probe: bool,
     pub(super) nat_stun: Option<String>,
+    pub(super) nat_stun_servers: Vec<String>,
     pub(super) detected_ipv6: Option<Ipv6Addr>,
     pub(super) nat_probe_attempts: std::sync::atomic::AtomicU8,
     pub(super) nat_probe_disabled: std::sync::atomic::AtomicBool,
+    pub(super) stun_backoff_until: Arc<RwLock<Option<Instant>>>,
     pub(super) me_one_retry: u8,
     pub(super) me_one_timeout: Duration,
     pub(super) proxy_map_v4: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>,
@@ -59,6 +61,8 @@ pub struct MePool {
     pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
     pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
     pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>,
+    pub(super) writer_available: Arc<Notify>,
+    pub(super) conn_count: AtomicUsize,
     pool_size: usize,
 }
@@ -75,6 +79,7 @@ impl MePool {
         nat_ip: Option<IpAddr>,
         nat_probe: bool,
         nat_stun: Option<String>,
+        nat_stun_servers: Vec<String>,
         detected_ipv6: Option<Ipv6Addr>,
         me_one_retry: u8,
         me_one_timeout_ms: u64,
@@ -96,9 +101,11 @@ impl MePool {
             nat_ip_detected: Arc::new(RwLock::new(None)),
             nat_probe,
             nat_stun,
+            nat_stun_servers,
             detected_ipv6,
             nat_probe_attempts: std::sync::atomic::AtomicU8::new(0),
             nat_probe_disabled: std::sync::atomic::AtomicBool::new(false),
+            stun_backoff_until: Arc::new(RwLock::new(None)),
             me_one_retry,
             me_one_timeout: Duration::from_millis(me_one_timeout_ms),
             pool_size: 2,
@@ -109,6 +116,8 @@ impl MePool {
             ping_tracker: Arc::new(Mutex::new(HashMap::new())),
             rtt_stats: Arc::new(Mutex::new(HashMap::new())),
             nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())),
+            writer_available: Arc::new(Notify::new()),
+            conn_count: AtomicUsize::new(0),
         })
     }
@@ -116,6 +125,11 @@ impl MePool {
         self.proxy_tag.is_some()
     }
 
+    pub fn reset_stun_state(&self) {
+        self.nat_probe_attempts.store(0, Ordering::Relaxed);
+        self.nat_probe_disabled.store(false, Ordering::Relaxed);
+    }
+
     pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr {
         let ip = self.translate_ip_for_nat(addr.ip());
         SocketAddr::new(ip, addr.port())
@@ -132,7 +146,11 @@ impl MePool {
     pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) {
         use std::collections::HashSet;
         let writers = self.writers.read().await;
-        let current: HashSet<SocketAddr> = writers.iter().map(|w| w.addr).collect();
+        let current: HashSet<SocketAddr> = writers
+            .iter()
+            .filter(|w| !w.draining.load(Ordering::Relaxed))
+            .map(|w| w.addr)
+            .collect();
         drop(writers);
 
         for family in self.family_order() {
@@ -175,12 +193,36 @@ impl MePool {
             let mut guard = self.proxy_map_v6.write().await;
             if !v6.is_empty() && *guard != v6 {
                 *guard = v6;
+                changed = true;
+            }
+        }
+
+        // Ensure negative DC entries mirror positives when absent (Telegram convention).
+        {
+            let mut guard = self.proxy_map_v4.write().await;
+            let keys: Vec<i32> = guard.keys().cloned().collect();
+            for k in keys.iter().cloned().filter(|k| *k > 0) {
+                if !guard.contains_key(&-k) {
+                    if let Some(addrs) = guard.get(&k).cloned() {
+                        guard.insert(-k, addrs);
+                    }
+                }
+            }
+        }
+        {
+            let mut guard = self.proxy_map_v6.write().await;
+            let keys: Vec<i32> = guard.keys().cloned().collect();
+            for k in keys.iter().cloned().filter(|k| *k > 0) {
+                if !guard.contains_key(&-k) {
+                    if let Some(addrs) = guard.get(&k).cloned() {
+                        guard.insert(-k, addrs);
+                    }
+                }
             }
         }
         changed
     }
 
-    pub async fn update_secret(&self, new_secret: Vec<u8>) -> bool {
+    pub async fn update_secret(self: &Arc<Self>, new_secret: Vec<u8>) -> bool {
         if new_secret.len() < 32 {
             warn!(len = new_secret.len(), "proxy-secret update ignored (too short)");
             return false;
@@ -195,10 +237,14 @@ impl MePool {
         false
     }
 
-    pub async fn reconnect_all(&self) {
-        // Graceful: do not drop all at once. New connections will use updated secret.
-        // Existing writers remain until health monitor replaces them.
-        // No-op here to avoid total outage.
+    pub async fn reconnect_all(self: &Arc<Self>) {
+        let ws = self.writers.read().await.clone();
+        for w in ws {
+            if let Ok(()) = self.connect_one(w.addr, self.rng.as_ref()).await {
+                self.mark_writer_draining(w.id).await;
+                tokio::time::sleep(Duration::from_secs(2)).await;
+            }
+        }
     }
 
     pub(super) async fn key_selector(&self) -> u32 {
@@ -317,21 +363,43 @@ impl MePool {
         let cancel = CancellationToken::new();
         let degraded = Arc::new(AtomicBool::new(false));
         let draining = Arc::new(AtomicBool::new(false));
-        let rpc_w = Arc::new(Mutex::new(RpcWriter {
+
+        let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096);
+        let mut rpc_writer = RpcWriter {
             writer: hs.wr,
             key: hs.write_key,
             iv: hs.write_iv,
             seq_no: 0,
-        }));
+        };
+        let cancel_wr = cancel.clone();
+        tokio::spawn(async move {
+            loop {
+                tokio::select! {
+                    cmd = rx.recv() => {
+                        match cmd {
+                            Some(WriterCommand::Data(payload)) => {
+                                if rpc_writer.send(&payload).await.is_err() { break; }
+                            }
+                            Some(WriterCommand::DataAndFlush(payload)) => {
+                                if rpc_writer.send_and_flush(&payload).await.is_err() { break; }
+                            }
+                            Some(WriterCommand::Close) | None => break,
+                        }
+                    }
+                    _ = cancel_wr.cancelled() => break,
+                }
+            }
+        });
+
         let writer = MeWriter {
             id: writer_id,
             addr,
-            writer: rpc_w.clone(),
+            tx: tx.clone(),
             cancel: cancel.clone(),
             degraded: degraded.clone(),
             draining: draining.clone(),
         };
         self.writers.write().await.push(writer.clone());
+        self.conn_count.fetch_add(1, Ordering::Relaxed);
+        self.writer_available.notify_waiters();
 
         let reg = self.registry.clone();
         let writers_arc = self.writers_arc();
@@ -339,8 +407,11 @@ impl MePool {
         let rtt_stats = self.rtt_stats.clone();
         let pool = Arc::downgrade(self);
         let cancel_ping = cancel.clone();
-        let rpc_w_ping = rpc_w.clone();
+        let tx_ping = tx.clone();
         let ping_tracker_ping = ping_tracker.clone();
+        let cleanup_done = Arc::new(AtomicBool::new(false));
+        let cleanup_for_reader = cleanup_done.clone();
+        let cleanup_for_ping = cleanup_done.clone();
 
         tokio::spawn(async move {
             let cancel_reader = cancel.clone();
@@ -351,7 +422,7 @@ impl MePool {
                 reg.clone(),
                 BytesMut::new(),
                 BytesMut::new(),
-                rpc_w.clone(),
+                tx.clone(),
                 ping_tracker.clone(),
                 rtt_stats.clone(),
                 writer_id,
@@ -360,8 +431,13 @@ impl MePool {
             )
             .await;
             if let Some(pool) = pool.upgrade() {
+                if cleanup_for_reader
+                    .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
+                    .is_ok()
+                {
                     pool.remove_writer_and_close_clients(writer_id).await;
                 }
+            }
             if let Err(e) = res {
                 warn!(error = %e, "ME reader ended");
             }
@@ -389,15 +465,21 @@ impl MePool {
                 p.extend_from_slice(&sent_id.to_le_bytes());
                 {
                     let mut tracker = ping_tracker_ping.lock().await;
+                    tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120));
                     tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
                 }
                 ping_id = ping_id.wrapping_add(1);
-                if let Err(e) = rpc_w_ping.lock().await.send_and_flush(&p).await {
-                    debug!(error = %e, "Active ME ping failed, removing dead writer");
+                if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() {
+                    debug!("Active ME ping failed, removing dead writer");
                     cancel_ping.cancel();
                     if let Some(pool) = pool_ping.upgrade() {
+                        if cleanup_for_ping
+                            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
+                            .is_ok()
+                        {
                             pool.remove_writer_and_close_clients(writer_id).await;
                         }
+                    }
                     break;
                 }
             }
@@ -430,7 +512,7 @@ impl MePool {
         false
     }
 
-    pub(crate) async fn remove_writer_and_close_clients(&self, writer_id: u64) {
+    pub(crate) async fn remove_writer_and_close_clients(self: &Arc<Self>, writer_id: u64) {
         let conns = self.remove_writer_only(writer_id).await;
         for bound in conns {
             let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await;
@@ -444,8 +526,11 @@ impl MePool {
             if let Some(pos) = ws.iter().position(|w| w.id == writer_id) {
                 let w = ws.remove(pos);
                 w.cancel.cancel();
+                let _ = w.tx.send(WriterCommand::Close).await;
+                self.conn_count.fetch_sub(1, Ordering::Relaxed);
             }
         }
+        self.rtt_stats.lock().await.remove(&writer_id);
         self.registry.writer_lost(writer_id).await
     }
@@ -459,8 +544,14 @@ impl MePool {
         let pool = Arc::downgrade(self);
 
         tokio::spawn(async move {
+            let deadline = Instant::now() + Duration::from_secs(300);
             loop {
                 if let Some(p) = pool.upgrade() {
+                    if Instant::now() >= deadline {
+                        warn!(writer_id, "Drain timeout, force-closing");
+                        let _ = p.remove_writer_and_close_clients(writer_id).await;
+                        break;
+                    }
                     if p.registry.is_writer_empty(writer_id).await {
                         let _ = p.remove_writer_only(writer_id).await;
                         break;
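
Both the reader task and the active-ping task can observe the same dead connection, so teardown is now guarded by the shared cleanup_done flag: compare_exchange(false, true) succeeds for exactly one caller. A minimal stand-alone illustration:

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;

fn main() {
    let done = Arc::new(AtomicBool::new(false));
    let mut handles = Vec::new();
    for name in ["reader", "pinger"] {
        let done = done.clone();
        handles.push(thread::spawn(move || {
            // Only the first task to flip false -> true runs the cleanup.
            if done
                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                println!("{name} performs the one-time cleanup");
            }
        }));
    }
    for h in handles {
        h.join().unwrap();
    }
}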

View File

@@ -98,8 +98,9 @@ impl MePool {
         family: IpFamily,
     ) -> Option<std::net::SocketAddr> {
         const STUN_CACHE_TTL: Duration = Duration::from_secs(600);
-        // If STUN probing was disabled after attempts, reuse cached (even stale) or skip.
-        if self.nat_probe_disabled.load(std::sync::atomic::Ordering::Relaxed) {
+        // Backoff window
+        if let Some(until) = *self.stun_backoff_until.read().await {
+            if Instant::now() < until {
                 if let Ok(cache) = self.nat_reflection_cache.try_lock() {
                     let slot = match family {
                         IpFamily::V4 => cache.v4,
@@ -109,6 +110,7 @@ impl MePool {
                 }
                 return None;
             }
+        }
 
         if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
             let slot = match family {
@@ -123,15 +125,15 @@ impl MePool {
         }
 
         let attempt = self.nat_probe_attempts.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
-        if attempt >= 2 {
-            self.nat_probe_disabled.store(true, std::sync::atomic::Ordering::Relaxed);
-            return None;
-        }
+        let servers = if !self.nat_stun_servers.is_empty() {
+            self.nat_stun_servers.clone()
+        } else if let Some(s) = &self.nat_stun {
+            vec![s.clone()]
+        } else {
+            vec!["stun.l.google.com:19302".to_string()]
+        };
 
-        let stun_addr = self
-            .nat_stun
-            .clone()
-            .unwrap_or_else(|| "stun.l.google.com:19302".to_string());
-        match stun_probe_dual(&stun_addr).await {
+        for stun_addr in servers {
+            match stun_probe_dual(&stun_addr).await {
                 Ok(res) => {
                     let picked: Option<StunProbeResult> = match family {
@@ -139,7 +141,8 @@ impl MePool {
                         IpFamily::V6 => res.v6,
                     };
                     if let Some(result) = picked {
-                        info!(local = %result.local_addr, reflected = %result.reflected_addr, family = ?family, "NAT probe: reflected address");
+                        info!(local = %result.local_addr, reflected = %result.reflected_addr, family = ?family, stun = %stun_addr, "NAT probe: reflected address");
+                        self.nat_probe_attempts.store(0, std::sync::atomic::Ordering::Relaxed);
                         if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
                             let slot = match family {
                                 IpFamily::V4 => &mut cache.v4,
@@ -147,25 +150,18 @@ impl MePool {
                             };
                             *slot = Some((Instant::now(), result.reflected_addr));
                         }
-                        Some(result.reflected_addr)
-                    } else {
-                        None
+                        return Some(result.reflected_addr);
                     }
                 }
                 Err(e) => {
-                    let attempts = attempt + 1;
-                    if attempts <= 2 {
-                        warn!(error = %e, attempt = attempts, "NAT probe failed");
-                    } else {
-                        debug!(error = %e, attempt = attempts, "NAT probe suppressed after max attempts");
-                    }
-                    if attempts >= 2 {
-                        self.nat_probe_disabled.store(true, std::sync::atomic::Ordering::Relaxed);
-                    }
+                    warn!(error = %e, stun = %stun_addr, attempt = attempt + 1, "NAT probe failed, trying next server");
                 }
+            }
+        }
+
+        let backoff = Duration::from_secs(60 * 2u64.pow((attempt as u32).min(6)));
+        *self.stun_backoff_until.write().await = Some(Instant::now() + backoff);
         None
     }
-    }
-    }
 }
 
 async fn fetch_public_ipv4_with_retry() -> Result<Option<Ipv4Addr>> {
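
The backoff window doubles from one minute per failed probe round and is capped at 2^6, i.e. 64 minutes; reset_stun_state(), called from the config updater, clears the attempt counter so probing resumes after the next refresh. Worked values as a small sketch:

use std::time::Duration;

// Mirrors the formula above: 60 * 2^min(attempt, 6) seconds.
fn stun_backoff(attempt: u32) -> Duration {
    Duration::from_secs(60 * 2u64.pow(attempt.min(6)))
}

fn main() {
    assert_eq!(stun_backoff(0).as_secs(), 60);   // 1 min
    assert_eq!(stun_backoff(3).as_secs(), 480);  // 8 min
    assert_eq!(stun_backoff(10).as_secs(), 3840); // capped at 64 min
}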

View File

@@ -6,7 +6,7 @@ use std::time::Instant;
 use bytes::{Bytes, BytesMut};
 use tokio::io::AsyncReadExt;
 use tokio::net::TcpStream;
-use tokio::sync::Mutex;
+use tokio::sync::{Mutex, mpsc};
 use tokio_util::sync::CancellationToken;
 use tracing::{debug, trace, warn};
@@ -14,7 +14,7 @@ use crate::crypto::{AesCbc, crc32};
 use crate::error::{ProxyError, Result};
 use crate::protocol::constants::*;
-use super::codec::RpcWriter;
+use super::codec::WriterCommand;
 use super::{ConnRegistry, MeResponse};
 
 pub(crate) async fn reader_loop(
@@ -24,7 +24,7 @@ pub(crate) async fn reader_loop(
     reg: Arc<ConnRegistry>,
     enc_leftover: BytesMut,
     mut dec: BytesMut,
-    writer: Arc<Mutex<RpcWriter>>,
+    tx: mpsc::Sender<WriterCommand>,
     ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>,
     rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
     _writer_id: u64,
@@ -33,6 +33,8 @@ pub(crate) async fn reader_loop(
 ) -> Result<()> {
     let mut raw = enc_leftover;
     let mut expected_seq: i32 = 0;
+    let mut crc_errors = 0u32;
+    let mut seq_mismatch = 0u32;
 
     loop {
         let mut tmp = [0u8; 16_384];
@@ -80,12 +82,20 @@ pub(crate) async fn reader_loop(
         let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
         if crc32(&frame[..pe]) != ec {
             warn!("CRC mismatch in data frame");
+            crc_errors += 1;
+            if crc_errors > 3 {
+                return Err(ProxyError::Proxy("Too many CRC mismatches".into()));
+            }
             continue;
         }
 
         let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());
         if seq_no != expected_seq {
             warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch");
+            seq_mismatch += 1;
+            if seq_mismatch > 10 {
+                return Err(ProxyError::Proxy("Too many seq mismatches".into()));
+            }
             expected_seq = seq_no.wrapping_add(1);
         } else {
             expected_seq = expected_seq.wrapping_add(1);
@@ -108,7 +118,7 @@ pub(crate) async fn reader_loop(
             let routed = reg.route(cid, MeResponse::Data { flags, data }).await;
             if !routed {
                 reg.unregister(cid).await;
-                send_close_conn(&writer, cid).await;
+                send_close_conn(&tx, cid).await;
             }
         } else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
             let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
@@ -118,7 +128,7 @@ pub(crate) async fn reader_loop(
             let routed = reg.route(cid, MeResponse::Ack(cfm)).await;
             if !routed {
                 reg.unregister(cid).await;
-                send_close_conn(&writer, cid).await;
+                send_close_conn(&tx, cid).await;
             }
         } else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
             let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
@@ -136,8 +146,8 @@ pub(crate) async fn reader_loop(
             let mut pong = Vec::with_capacity(12);
             pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
             pong.extend_from_slice(&ping_id.to_le_bytes());
-            if let Err(e) = writer.lock().await.send_and_flush(&pong).await {
-                warn!(error = %e, "PONG send failed");
+            if tx.send(WriterCommand::DataAndFlush(pong)).await.is_err() {
+                warn!("PONG send failed");
                 break;
             }
         } else if pt == RPC_PONG_U32 && body.len() >= 8 {
@@ -171,12 +181,10 @@ pub(crate) async fn reader_loop(
     }
 }
 
-async fn send_close_conn(writer: &Arc<Mutex<RpcWriter>>, conn_id: u64) {
+async fn send_close_conn(tx: &mpsc::Sender<WriterCommand>, conn_id: u64) {
     let mut p = Vec::with_capacity(12);
     p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());
     p.extend_from_slice(&conn_id.to_le_bytes());
-    if let Err(e) = writer.lock().await.send_and_flush(&p).await {
-        debug!(conn_id, error = %e, "Failed to send RPC_CLOSE_CONN");
-    }
+    let _ = tx.send(WriterCommand::DataAndFlush(p)).await;
 }

View File

@@ -5,7 +5,7 @@ use std::sync::Arc;
 use tokio::sync::{mpsc, Mutex, RwLock};
 
-use super::codec::RpcWriter;
+use super::codec::WriterCommand;
 use super::MeResponse;
 
 #[derive(Clone)]
@@ -25,12 +25,12 @@ pub struct BoundConn {
 #[derive(Clone)]
 pub struct ConnWriter {
     pub writer_id: u64,
-    pub writer: Arc<Mutex<RpcWriter>>,
+    pub tx: mpsc::Sender<WriterCommand>,
 }
 
 struct RegistryInner {
     map: HashMap<u64, mpsc::Sender<MeResponse>>,
-    writers: HashMap<u64, Arc<Mutex<RpcWriter>>>,
+    writers: HashMap<u64, mpsc::Sender<WriterCommand>>,
     writer_for_conn: HashMap<u64, u64>,
     conns_for_writer: HashMap<u64, HashSet<u64>>,
     meta: HashMap<u64, ConnMeta>,
@@ -96,13 +96,13 @@ impl ConnRegistry {
         &self,
         conn_id: u64,
         writer_id: u64,
-        writer: Arc<Mutex<RpcWriter>>,
+        tx: mpsc::Sender<WriterCommand>,
         meta: ConnMeta,
     ) {
         let mut inner = self.inner.write().await;
         inner.meta.entry(conn_id).or_insert(meta);
         inner.writer_for_conn.insert(conn_id, writer_id);
-        inner.writers.entry(writer_id).or_insert_with(|| writer.clone());
+        inner.writers.entry(writer_id).or_insert_with(|| tx.clone());
         inner
             .conns_for_writer
             .entry(writer_id)
@@ -114,7 +114,7 @@ impl ConnRegistry {
         let inner = self.inner.read().await;
         let writer_id = inner.writer_for_conn.get(&conn_id).cloned()?;
         let writer = inner.writers.get(&writer_id).cloned()?;
-        Some(ConnWriter { writer_id, writer })
+        Some(ConnWriter { writer_id, tx: writer })
     }
 
     pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> {
pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> { pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> {

View File

@@ -31,8 +31,17 @@ pub async fn me_rotation_task(pool: Arc<MePool>, rng: Arc<SecureRandom>, interva
             info!(addr = %w.addr, writer_id = w.id, "Rotating ME connection");
             match pool.connect_one(w.addr, rng.as_ref()).await {
                 Ok(()) => {
-                    // Mark old writer for graceful drain; removal happens when sessions finish.
-                    pool.mark_writer_draining(w.id).await;
+                    tokio::time::sleep(Duration::from_secs(2)).await;
+                    let ws = pool.writers.read().await;
+                    let new_alive = ws.iter().any(|nw|
+                        nw.id != w.id && nw.addr == w.addr && !nw.degraded.load(Ordering::Relaxed) && !nw.draining.load(Ordering::Relaxed)
+                    );
+                    drop(ws);
+                    if new_alive {
+                        pool.mark_writer_draining(w.id).await;
+                    } else {
+                        warn!(addr = %w.addr, writer_id = w.id, "New writer died, keeping old");
+                    }
                 }
                 Err(e) => {
                     warn!(addr = %w.addr, writer_id = w.id, error = %e, "ME rotation connect failed");

View File

@@ -10,6 +10,7 @@ use crate::network::IpFamily;
 use crate::protocol::constants::RPC_CLOSE_EXT_U32;
 
 use super::MePool;
+use super::codec::WriterCommand;
 use super::wire::build_proxy_req_payload;
 use rand::seq::SliceRandom;
 use super::registry::ConnMeta;
@@ -43,18 +44,15 @@ impl MePool {
         loop {
             if let Some(current) = self.registry.get_writer(conn_id).await {
                 let send_res = {
-                    if let Ok(mut guard) = current.writer.try_lock() {
-                        let r = guard.send(&payload).await;
-                        drop(guard);
-                        r
-                    } else {
-                        current.writer.lock().await.send(&payload).await
-                    }
+                    current
+                        .tx
+                        .send(WriterCommand::Data(payload.clone()))
+                        .await
                 };
                 match send_res {
                     Ok(()) => return Ok(()),
-                    Err(e) => {
-                        warn!(error = %e, writer_id = current.writer_id, "ME write failed");
+                    Err(_) => {
+                        warn!(writer_id = current.writer_id, "ME writer channel closed");
                         self.remove_writer_and_close_clients(current.writer_id).await;
                         continue;
                     }
@@ -64,7 +62,26 @@ impl MePool {
             let mut writers_snapshot = {
                 let ws = self.writers.read().await;
                 if ws.is_empty() {
-                    return Err(ProxyError::Proxy("All ME connections dead".into()));
+                    drop(ws);
+                    for family in self.family_order() {
+                        let map = match family {
+                            IpFamily::V4 => self.proxy_map_v4.read().await.clone(),
+                            IpFamily::V6 => self.proxy_map_v6.read().await.clone(),
+                        };
+                        for (_dc, addrs) in map.iter() {
+                            for (ip, port) in addrs {
+                                let addr = SocketAddr::new(*ip, *port);
+                                if self.connect_one(addr, self.rng.as_ref()).await.is_ok() {
+                                    self.writer_available.notify_waiters();
+                                    break;
+                                }
+                            }
+                        }
+                    }
+                    if tokio::time::timeout(Duration::from_secs(3), self.writer_available.notified()).await.is_err() {
+                        return Err(ProxyError::Proxy("All ME connections dead (waited 3s)".into()));
+                    }
+                    continue;
                 }
                 ws.clone()
             };
@@ -96,9 +113,10 @@ impl MePool {
                 writers_snapshot = ws2.clone();
                 drop(ws2);
                 candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await;
+                if !candidate_indices.is_empty() {
                     break;
                 }
-                drop(map_guard);
+            }
             }
             if candidate_indices.is_empty() {
                 return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
@@ -120,38 +138,31 @@ impl MePool {
             if w.draining.load(Ordering::Relaxed) {
                 continue;
             }
-            if let Ok(mut guard) = w.writer.try_lock() {
-                let send_res = guard.send(&payload).await;
-                drop(guard);
-                match send_res {
-                    Ok(()) => {
-                        self.registry
-                            .bind_writer(conn_id, w.id, w.writer.clone(), meta.clone())
-                            .await;
-                        return Ok(());
-                    }
-                    Err(e) => {
-                        warn!(error = %e, writer_id = w.id, "ME write failed");
-                        self.remove_writer_and_close_clients(w.id).await;
-                        continue;
-                    }
-                }
-            }
+            if w.tx.send(WriterCommand::Data(payload.clone())).await.is_ok() {
+                self.registry
+                    .bind_writer(conn_id, w.id, w.tx.clone(), meta.clone())
+                    .await;
+                return Ok(());
+            } else {
+                warn!(writer_id = w.id, "ME writer channel closed");
+                self.remove_writer_and_close_clients(w.id).await;
+                continue;
+            }
         }
 
         let w = writers_snapshot[candidate_indices[start]].clone();
         if w.draining.load(Ordering::Relaxed) {
             continue;
         }
-        match w.writer.lock().await.send(&payload).await {
+        match w.tx.send(WriterCommand::Data(payload.clone())).await {
             Ok(()) => {
                 self.registry
-                    .bind_writer(conn_id, w.id, w.writer.clone(), meta.clone())
+                    .bind_writer(conn_id, w.id, w.tx.clone(), meta.clone())
                     .await;
                 return Ok(());
             }
-            Err(e) => {
-                warn!(error = %e, writer_id = w.id, "ME write failed (blocking)");
+            Err(_) => {
                warn!(writer_id = w.id, "ME writer channel closed (blocking)");
                 self.remove_writer_and_close_clients(w.id).await;
             }
         }
@@ -163,8 +174,8 @@ impl MePool {
         let mut p = Vec::with_capacity(12);
         p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
         p.extend_from_slice(&conn_id.to_le_bytes());
-        if let Err(e) = w.writer.lock().await.send_and_flush(&p).await {
-            debug!(error = %e, "ME close write failed");
+        if w.tx.send(WriterCommand::DataAndFlush(p)).await.is_err() {
+            debug!("ME close write failed");
             self.remove_writer_and_close_clients(w.writer_id).await;
         }
     } else {
@@ -176,7 +187,7 @@ impl MePool {
     }
 
     pub fn connection_count(&self) -> usize {
-        self.writers.try_read().map(|w| w.len()).unwrap_or(0)
+        self.conn_count.load(Ordering::Relaxed)
     }
 
     pub(super) async fn candidate_indices_for_dc(
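
When the pool is empty, the send path now re-dials from the proxy maps and then parks on writer_available for up to three seconds instead of failing outright. A stand-alone sketch of that wait (note that Notify::notified() only observes notify_waiters() calls made while it is pending, which is why the caller loops and re-checks the pool via continue):

use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;

// Wait up to 3s for a "writer became available" signal.
async fn wait_for_writer(notify: &Notify) -> bool {
    tokio::time::timeout(Duration::from_secs(3), notify.notified())
        .await
        .is_ok()
}

#[tokio::main]
async fn main() {
    let notify = Arc::new(Notify::new());
    let n2 = notify.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(100)).await;
        n2.notify_waiters(); // wakes tasks already parked in notified()
    });
    assert!(wait_for_writer(&notify).await);
}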