ME Pool V2

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-02-19 13:35:56 +03:00
parent 433e6c9a20
commit 35ae455e2b
13 changed files with 343 additions and 137 deletions

View File

@@ -1,14 +1,14 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, AtomicUsize, Ordering};
use bytes::BytesMut;
use rand::Rng;
use rand::seq::SliceRandom;
use tokio::sync::{Mutex, RwLock};
use tokio::sync::{Mutex, RwLock, mpsc, Notify};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use std::time::Duration;
use std::time::{Duration, Instant};
use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result};
@@ -18,7 +18,7 @@ use crate::protocol::constants::*;
use super::ConnRegistry;
use super::registry::{BoundConn, ConnMeta};
use super::codec::RpcWriter;
use super::codec::{RpcWriter, WriterCommand};
use super::reader::reader_loop;
use super::MeResponse;
@@ -29,7 +29,7 @@ const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
pub struct MeWriter {
pub id: u64,
pub addr: SocketAddr,
pub writer: Arc<Mutex<RpcWriter>>,
pub tx: mpsc::Sender<WriterCommand>,
pub cancel: CancellationToken,
pub degraded: Arc<AtomicBool>,
pub draining: Arc<AtomicBool>,
@@ -47,9 +47,11 @@ pub struct MePool {
pub(super) nat_ip_detected: Arc<RwLock<Option<IpAddr>>>,
pub(super) nat_probe: bool,
pub(super) nat_stun: Option<String>,
pub(super) nat_stun_servers: Vec<String>,
pub(super) detected_ipv6: Option<Ipv6Addr>,
pub(super) nat_probe_attempts: std::sync::atomic::AtomicU8,
pub(super) nat_probe_disabled: std::sync::atomic::AtomicBool,
pub(super) stun_backoff_until: Arc<RwLock<Option<Instant>>>,
pub(super) me_one_retry: u8,
pub(super) me_one_timeout: Duration,
pub(super) proxy_map_v4: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>,
@@ -59,6 +61,8 @@ pub struct MePool {
pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>,
pub(super) writer_available: Arc<Notify>,
pub(super) conn_count: AtomicUsize,
pool_size: usize,
}
@@ -75,6 +79,7 @@ impl MePool {
nat_ip: Option<IpAddr>,
nat_probe: bool,
nat_stun: Option<String>,
nat_stun_servers: Vec<String>,
detected_ipv6: Option<Ipv6Addr>,
me_one_retry: u8,
me_one_timeout_ms: u64,
@@ -96,9 +101,11 @@ impl MePool {
nat_ip_detected: Arc::new(RwLock::new(None)),
nat_probe,
nat_stun,
nat_stun_servers,
detected_ipv6,
nat_probe_attempts: std::sync::atomic::AtomicU8::new(0),
nat_probe_disabled: std::sync::atomic::AtomicBool::new(false),
stun_backoff_until: Arc::new(RwLock::new(None)),
me_one_retry,
me_one_timeout: Duration::from_millis(me_one_timeout_ms),
pool_size: 2,
@@ -109,6 +116,8 @@ impl MePool {
ping_tracker: Arc::new(Mutex::new(HashMap::new())),
rtt_stats: Arc::new(Mutex::new(HashMap::new())),
nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())),
writer_available: Arc::new(Notify::new()),
conn_count: AtomicUsize::new(0),
})
}
@@ -116,6 +125,11 @@ impl MePool {
self.proxy_tag.is_some()
}
pub fn reset_stun_state(&self) {
self.nat_probe_attempts.store(0, Ordering::Relaxed);
self.nat_probe_disabled.store(false, Ordering::Relaxed);
}
pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr {
let ip = self.translate_ip_for_nat(addr.ip());
SocketAddr::new(ip, addr.port())
@@ -132,7 +146,11 @@ impl MePool {
pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) {
use std::collections::HashSet;
let writers = self.writers.read().await;
let current: HashSet<SocketAddr> = writers.iter().map(|w| w.addr).collect();
let current: HashSet<SocketAddr> = writers
.iter()
.filter(|w| !w.draining.load(Ordering::Relaxed))
.map(|w| w.addr)
.collect();
drop(writers);
for family in self.family_order() {
@@ -175,12 +193,36 @@ impl MePool {
let mut guard = self.proxy_map_v6.write().await;
if !v6.is_empty() && *guard != v6 {
*guard = v6;
changed = true;
}
}
// Ensure negative DC entries mirror positives when absent (Telegram convention).
{
let mut guard = self.proxy_map_v4.write().await;
let keys: Vec<i32> = guard.keys().cloned().collect();
for k in keys.iter().cloned().filter(|k| *k > 0) {
if !guard.contains_key(&-k) {
if let Some(addrs) = guard.get(&k).cloned() {
guard.insert(-k, addrs);
}
}
}
}
{
let mut guard = self.proxy_map_v6.write().await;
let keys: Vec<i32> = guard.keys().cloned().collect();
for k in keys.iter().cloned().filter(|k| *k > 0) {
if !guard.contains_key(&-k) {
if let Some(addrs) = guard.get(&k).cloned() {
guard.insert(-k, addrs);
}
}
}
}
changed
}
pub async fn update_secret(&self, new_secret: Vec<u8>) -> bool {
pub async fn update_secret(self: &Arc<Self>, new_secret: Vec<u8>) -> bool {
if new_secret.len() < 32 {
warn!(len = new_secret.len(), "proxy-secret update ignored (too short)");
return false;
@@ -195,10 +237,14 @@ impl MePool {
false
}
pub async fn reconnect_all(&self) {
// Graceful: do not drop all at once. New connections will use updated secret.
// Existing writers remain until health monitor replaces them.
// No-op here to avoid total outage.
pub async fn reconnect_all(self: &Arc<Self>) {
let ws = self.writers.read().await.clone();
for w in ws {
if let Ok(()) = self.connect_one(w.addr, self.rng.as_ref()).await {
self.mark_writer_draining(w.id).await;
tokio::time::sleep(Duration::from_secs(2)).await;
}
}
}
pub(super) async fn key_selector(&self) -> u32 {
@@ -317,21 +363,43 @@ impl MePool {
let cancel = CancellationToken::new();
let degraded = Arc::new(AtomicBool::new(false));
let draining = Arc::new(AtomicBool::new(false));
let rpc_w = Arc::new(Mutex::new(RpcWriter {
let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096);
let mut rpc_writer = RpcWriter {
writer: hs.wr,
key: hs.write_key,
iv: hs.write_iv,
seq_no: 0,
}));
};
let cancel_wr = cancel.clone();
tokio::spawn(async move {
loop {
tokio::select! {
cmd = rx.recv() => {
match cmd {
Some(WriterCommand::Data(payload)) => {
if rpc_writer.send(&payload).await.is_err() { break; }
}
Some(WriterCommand::DataAndFlush(payload)) => {
if rpc_writer.send_and_flush(&payload).await.is_err() { break; }
}
Some(WriterCommand::Close) | None => break,
}
}
_ = cancel_wr.cancelled() => break,
}
}
});
let writer = MeWriter {
id: writer_id,
addr,
writer: rpc_w.clone(),
tx: tx.clone(),
cancel: cancel.clone(),
degraded: degraded.clone(),
draining: draining.clone(),
};
self.writers.write().await.push(writer.clone());
self.conn_count.fetch_add(1, Ordering::Relaxed);
self.writer_available.notify_waiters();
let reg = self.registry.clone();
let writers_arc = self.writers_arc();
@@ -339,8 +407,11 @@ impl MePool {
let rtt_stats = self.rtt_stats.clone();
let pool = Arc::downgrade(self);
let cancel_ping = cancel.clone();
let rpc_w_ping = rpc_w.clone();
let tx_ping = tx.clone();
let ping_tracker_ping = ping_tracker.clone();
let cleanup_done = Arc::new(AtomicBool::new(false));
let cleanup_for_reader = cleanup_done.clone();
let cleanup_for_ping = cleanup_done.clone();
tokio::spawn(async move {
let cancel_reader = cancel.clone();
@@ -351,7 +422,7 @@ impl MePool {
reg.clone(),
BytesMut::new(),
BytesMut::new(),
rpc_w.clone(),
tx.clone(),
ping_tracker.clone(),
rtt_stats.clone(),
writer_id,
@@ -360,7 +431,12 @@ impl MePool {
)
.await;
if let Some(pool) = pool.upgrade() {
pool.remove_writer_and_close_clients(writer_id).await;
if cleanup_for_reader
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
pool.remove_writer_and_close_clients(writer_id).await;
}
}
if let Err(e) = res {
warn!(error = %e, "ME reader ended");
@@ -389,14 +465,20 @@ impl MePool {
p.extend_from_slice(&sent_id.to_le_bytes());
{
let mut tracker = ping_tracker_ping.lock().await;
tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120));
tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
}
ping_id = ping_id.wrapping_add(1);
if let Err(e) = rpc_w_ping.lock().await.send_and_flush(&p).await {
debug!(error = %e, "Active ME ping failed, removing dead writer");
if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() {
debug!("Active ME ping failed, removing dead writer");
cancel_ping.cancel();
if let Some(pool) = pool_ping.upgrade() {
pool.remove_writer_and_close_clients(writer_id).await;
if cleanup_for_ping
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
pool.remove_writer_and_close_clients(writer_id).await;
}
}
break;
}
@@ -430,7 +512,7 @@ impl MePool {
false
}
pub(crate) async fn remove_writer_and_close_clients(&self, writer_id: u64) {
pub(crate) async fn remove_writer_and_close_clients(self: &Arc<Self>, writer_id: u64) {
let conns = self.remove_writer_only(writer_id).await;
for bound in conns {
let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await;
@@ -444,8 +526,11 @@ impl MePool {
if let Some(pos) = ws.iter().position(|w| w.id == writer_id) {
let w = ws.remove(pos);
w.cancel.cancel();
let _ = w.tx.send(WriterCommand::Close).await;
self.conn_count.fetch_sub(1, Ordering::Relaxed);
}
}
self.rtt_stats.lock().await.remove(&writer_id);
self.registry.writer_lost(writer_id).await
}
@@ -459,8 +544,14 @@ impl MePool {
let pool = Arc::downgrade(self);
tokio::spawn(async move {
let deadline = Instant::now() + Duration::from_secs(300);
loop {
if let Some(p) = pool.upgrade() {
if Instant::now() >= deadline {
warn!(writer_id, "Drain timeout, force-closing");
let _ = p.remove_writer_and_close_clients(writer_id).await;
break;
}
if p.registry.is_writer_empty(writer_id).await {
let _ = p.remove_writer_only(writer_id).await;
break;