diff --git a/src/main.rs b/src/main.rs index 1da8123..fe001c3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,7 +4,7 @@ use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use rand::Rng; use tokio::net::TcpListener; use tokio::signal; @@ -173,6 +173,74 @@ async fn write_beobachten_snapshot(path: &str, payload: &str) -> std::io::Result tokio::fs::write(path, payload).await } +fn unit_label(value: u64, singular: &'static str, plural: &'static str) -> &'static str { + if value == 1 { singular } else { plural } +} + +fn format_uptime(total_secs: u64) -> String { + const SECS_PER_MINUTE: u64 = 60; + const SECS_PER_HOUR: u64 = 60 * SECS_PER_MINUTE; + const SECS_PER_DAY: u64 = 24 * SECS_PER_HOUR; + const SECS_PER_MONTH: u64 = 30 * SECS_PER_DAY; + const SECS_PER_YEAR: u64 = 365 * SECS_PER_DAY; + + let mut remaining = total_secs; + let years = remaining / SECS_PER_YEAR; + remaining %= SECS_PER_YEAR; + let months = remaining / SECS_PER_MONTH; + remaining %= SECS_PER_MONTH; + let days = remaining / SECS_PER_DAY; + remaining %= SECS_PER_DAY; + let hours = remaining / SECS_PER_HOUR; + remaining %= SECS_PER_HOUR; + let minutes = remaining / SECS_PER_MINUTE; + let seconds = remaining % SECS_PER_MINUTE; + + let mut parts = Vec::new(); + if years > 0 { + parts.push(format!( + "{} {}", + years, + unit_label(years, "year", "years") + )); + } + if total_secs >= SECS_PER_YEAR { + parts.push(format!( + "{} {}", + months, + unit_label(months, "month", "months") + )); + } + if total_secs >= SECS_PER_MONTH { + parts.push(format!( + "{} {}", + days, + unit_label(days, "day", "days") + )); + } + if total_secs >= SECS_PER_DAY { + parts.push(format!( + "{} {}", + hours, + unit_label(hours, "hour", "hours") + )); + } + if total_secs >= SECS_PER_HOUR { + parts.push(format!( + "{} {}", + minutes, + unit_label(minutes, "minute", "minutes") + )); + } + parts.push(format!( + "{} {}", + seconds, + unit_label(seconds, "second", "seconds") + )); + + format!("{} / {} seconds", parts.join(", "), total_secs) +} + async fn load_startup_proxy_config_snapshot( url: &str, cache_path: Option<&str>, @@ -289,6 +357,7 @@ async fn load_startup_proxy_config_snapshot( #[tokio::main] async fn main() -> std::result::Result<(), Box> { + let process_started_at = Instant::now(); let (config_path, cli_silent, cli_log_level) = parse_cli(); let mut config = match ProxyConfig::load(&config_path) { @@ -961,6 +1030,15 @@ async fn main() -> std::result::Result<(), Box> { } } + let initialized_secs = process_started_at.elapsed().as_secs(); + let second_suffix = if initialized_secs == 1 { "" } else { "s" }; + info!("================= Telegram Startup ================="); + info!( + " DC/ME Initialized in {} second{}", + initialized_secs, second_suffix + ); + info!("============================================================"); + // Background tasks let um_clone = upstream_manager.clone(); let decision_clone = decision.clone(); @@ -1514,7 +1592,29 @@ async fn main() -> std::result::Result<(), Box> { } match signal::ctrl_c().await { - Ok(()) => info!("Shutting down..."), + Ok(()) => { + let uptime_secs = process_started_at.elapsed().as_secs(); + info!("Uptime: {}", format_uptime(uptime_secs)); + info!("Shutting down..."); + if let Some(pool) = &me_pool { + match tokio::time::timeout( + Duration::from_secs(2), + pool.shutdown_send_close_conn_all(), + ) + .await + { + Ok(total) => { + info!( + close_conn_sent = total, + "ME shutdown: RPC_CLOSE_CONN broadcast completed" + ); + } + Err(_) => { + warn!("ME shutdown: RPC_CLOSE_CONN broadcast timed out"); + } + } + } + } Err(e) => error!("Signal error: {}", e), } diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 869030a..e4d0031 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -278,6 +278,11 @@ impl ConnRegistry { Some(ConnWriter { writer_id, tx: writer }) } + pub async fn active_conn_ids(&self) -> Vec { + let inner = self.inner.read().await; + inner.writer_for_conn.keys().copied().collect() + } + pub async fn writer_lost(&self, writer_id: u64) -> Vec { let mut inner = self.inner.write().await; inner.writers.remove(&writer_id); diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 8bd21ee..c6db028 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -11,7 +11,7 @@ use tracing::{debug, warn}; use crate::config::MeRouteNoWriterMode; use crate::error::{ProxyError, Result}; use crate::network::IpFamily; -use crate::protocol::constants::RPC_CLOSE_EXT_U32; +use crate::protocol::constants::{RPC_CLOSE_CONN_U32, RPC_CLOSE_EXT_U32}; use super::MePool; use super::codec::WriterCommand; @@ -476,6 +476,37 @@ impl MePool { Ok(()) } + pub async fn send_close_conn(self: &Arc, conn_id: u64) -> Result<()> { + if let Some(w) = self.registry.get_writer(conn_id).await { + let mut p = Vec::with_capacity(12); + p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); + p.extend_from_slice(&conn_id.to_le_bytes()); + match w.tx.try_send(WriterCommand::DataAndFlush(p)) { + Ok(()) => {} + Err(TrySendError::Full(cmd)) => { + let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await; + } + Err(TrySendError::Closed(_)) => { + debug!(conn_id, "ME close_conn skipped: writer channel closed"); + } + } + } else { + debug!(conn_id, "ME close_conn skipped (writer missing)"); + } + + self.registry.unregister(conn_id).await; + Ok(()) + } + + pub async fn shutdown_send_close_conn_all(self: &Arc) -> usize { + let conn_ids = self.registry.active_conn_ids().await; + let total = conn_ids.len(); + for conn_id in conn_ids { + let _ = self.send_close_conn(conn_id).await; + } + total + } + pub fn connection_count(&self) -> usize { self.conn_count.load(Ordering::Relaxed) }