From bc432f06e27dd664113cf50db4934a006fb93878 Mon Sep 17 00:00:00 2001 From: sintanial Date: Sun, 1 Mar 2026 13:53:50 +0300 Subject: [PATCH 1/2] Add per-user ad_tag with global fallback and hot-reload - Per-user ad_tag in [access.user_ad_tags], global fallback in general.ad_tag - User tag overrides global; if no user tag, general.ad_tag is used - Both general.ad_tag and user_ad_tags support hot-reload (no restart) --- Cargo.lock | 2 +- README.md | 8 +++--- config.toml | 2 ++ src/config/hot_reload.rs | 42 ++++++++++++++++-------------- src/config/load.rs | 4 +-- src/config/types.rs | 12 ++++++--- src/main.rs | 2 +- src/proxy/middle_relay.rs | 19 +++++++++++++- src/transport/middle_proxy/send.rs | 5 +++- 9 files changed, 65 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 251f0b7..e29b473 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2087,7 +2087,7 @@ dependencies = [ [[package]] name = "telemt" -version = "3.0.13" +version = "3.1.3" dependencies = [ "aes", "anyhow", diff --git a/README.md b/README.md index 093f2cd..8ea25e7 100644 --- a/README.md +++ b/README.md @@ -215,10 +215,12 @@ hello = "00000000000000000000000000000000" ``` ### Advanced -#### Adtag -To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to section `[General]` +#### Adtag (per-user) +To use channel advertising and usage statistics from Telegram, get an Adtag from [@mtproxybot](https://t.me/mtproxybot). 
Set it per user in `[access.user_ad_tags]` (32 hex chars): ```toml -ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot +[access.user_ad_tags] +username1 = "11111111111111111111111111111111" # Replace with your tag from @mtproxybot +username2 = "22222222222222222222222222222222" ``` #### Listening and Announce IPs To specify listening address and/or address in links, add to section `[[server.listeners]]` of config.toml: diff --git a/config.toml b/config.toml index b280234..cb33e3d 100644 --- a/config.toml +++ b/config.toml @@ -5,7 +5,9 @@ # === General Settings === [general] use_middle_proxy = false +# Global ad_tag fallback when user has no per-user tag in [access.user_ad_tags] # ad_tag = "00000000000000000000000000000000" +# Per-user ad_tag in [access.user_ad_tags] (32 hex from @MTProxybot) # === Log Level === # Log level: debug | verbose | normal | silent diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index eec6b8c..e16cff2 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -4,21 +4,22 @@ //! //! # What can be reloaded without restart //! -//! | Section | Field | Effect | -//! |-----------|-------------------------------|-----------------------------------| -//! | `general` | `log_level` | Filter updated via `log_level_tx` | -//! | `general` | `ad_tag` | Passed on next connection | -//! | `general` | `middle_proxy_pool_size` | Passed on next connection | -//! | `general` | `me_keepalive_*` | Passed on next connection | -//! | `general` | `desync_all_full` | Applied immediately | -//! | `general` | `update_every` | Applied to ME updater immediately | -//! | `general` | `hardswap` | Applied on next ME map update | -//! | `general` | `me_pool_drain_ttl_secs` | Applied on next ME map update | -//! | `general` | `me_pool_min_fresh_ratio` | Applied on next ME map update | -//! | `general` | `me_reinit_drain_timeout_secs`| Applied on next ME map update | -//! 
| `general` | `telemetry` / `me_*_policy` | Applied immediately | -//! | `network` | `dns_overrides` | Applied immediately | -//! | `access` | All user/quota fields | Effective immediately | +//! | Section | Field | Effect | +//! |-----------|--------------------------------|------------------------------------------------| +//! | `general` | `log_level` | Filter updated via `log_level_tx` | +//! | `access` | `user_ad_tags` | Passed on next connection | +//! | `general` | `ad_tag` | Passed on next connection (fallback per-user) | +//! | `general` | `middle_proxy_pool_size` | Passed on next connection | +//! | `general` | `me_keepalive_*` | Passed on next connection | +//! | `general` | `desync_all_full` | Applied immediately | +//! | `general` | `update_every` | Applied to ME updater immediately | +//! | `general` | `hardswap` | Applied on next ME map update | +//! | `general` | `me_pool_drain_ttl_secs` | Applied on next ME map update | +//! | `general` | `me_pool_min_fresh_ratio` | Applied on next ME map update | +//! | `general` | `me_reinit_drain_timeout_secs` | Applied on next ME map update | +//! | `general` | `telemetry` / `me_*_policy` | Applied immediately | +//! | `network` | `dns_overrides` | Applied immediately | +//! | `access` | All user/quota fields | Effective immediately | //! //! Fields that require re-binding sockets (`server.port`, `censorship.*`, //! `network.*`, `use_middle_proxy`) are **not** applied; a warning is emitted. 
@@ -207,14 +208,17 @@ fn log_changes( log_tx.send(new_hot.log_level.clone()).ok(); } - if old_hot.ad_tag != new_hot.ad_tag { + if old_hot.access.user_ad_tags != new_hot.access.user_ad_tags { info!( - "config reload: ad_tag: {} → {}", - old_hot.ad_tag.as_deref().unwrap_or("none"), - new_hot.ad_tag.as_deref().unwrap_or("none"), + "config reload: user_ad_tags updated ({} entries)", + new_hot.access.user_ad_tags.len(), ); } + if old_hot.ad_tag != new_hot.ad_tag { + info!("config reload: general.ad_tag updated (applied on next connection)"); + } + if old_hot.dns_overrides != new_hot.dns_overrides { info!( "config reload: network.dns_overrides updated ({} entries)", diff --git a/src/config/load.rs b/src/config/load.rs index 3aafda2..f37791d 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -532,7 +532,7 @@ impl ProxyConfig { ))); } - if let Some(tag) = &self.general.ad_tag { + for (user, tag) in &self.access.user_ad_tags { let zeros = "00000000000000000000000000000000"; if !is_valid_ad_tag(tag) { return Err(ProxyError::Config( @@ -540,7 +540,7 @@ impl ProxyConfig { )); } if tag == zeros { - warn!("ad_tag is all zeros; register a valid proxy tag via @MTProxybot to enable sponsored channel"); + warn!(user = %user, "user ad_tag is all zeros; register a valid proxy tag via @MTProxybot to enable sponsored channel"); } } diff --git a/src/config/types.rs b/src/config/types.rs index 7a3f6e9..716d78f 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -247,14 +247,15 @@ pub struct GeneralConfig { #[serde(default = "default_true")] pub use_middle_proxy: bool, - #[serde(default)] - pub ad_tag: Option, - /// Path to proxy-secret binary file (auto-downloaded if absent). /// Infrastructure secret from https://core.telegram.org/getProxySecret. #[serde(default = "default_proxy_secret_path")] pub proxy_secret_path: Option, + /// Global ad_tag (32 hex chars from @MTProxybot). Fallback when user has no per-user tag in access.user_ad_tags. 
+ #[serde(default)] + pub ad_tag: Option, + /// Public IP override for middle-proxy NAT environments. /// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr". #[serde(default)] @@ -807,6 +808,10 @@ pub struct AccessConfig { #[serde(default = "default_access_users")] pub users: HashMap, + /// Per-user ad_tag (32 hex chars from @MTProxybot). + #[serde(default)] + pub user_ad_tags: HashMap, + #[serde(default)] pub user_max_tcp_conns: HashMap, @@ -833,6 +838,7 @@ impl Default for AccessConfig { fn default() -> Self { Self { users: default_access_users(), + user_ad_tags: HashMap::new(), user_max_tcp_conns: HashMap::new(), user_expirations: HashMap::new(), user_data_quota: HashMap::new(), diff --git a/src/main.rs b/src/main.rs index 2675509..b910c64 100644 --- a/src/main.rs +++ b/src/main.rs @@ -448,7 +448,7 @@ async fn main() -> std::result::Result<(), Box> { info!("Middle-proxy STUN probing disabled by network.stun_use=false"); } - // ad_tag (proxy_tag) for advertising + // Global ad_tag (pool default). Used when user has no per-user tag in access.user_ad_tags. 
let proxy_tag = config .general .ad_tag diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index a4942ba..0690906 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -238,7 +238,22 @@ where stats.increment_user_connects(&user); stats.increment_user_curr_connects(&user); - let proto_flags = proto_flags_for_tag(proto_tag, me_pool.has_proxy_tag()); + // Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable) + let user_tag: Option> = config + .access + .user_ad_tags + .get(&user) + .and_then(|s| hex::decode(s).ok()) + .filter(|v| v.len() == 16); + let global_tag: Option> = config + .general + .ad_tag + .as_ref() + .and_then(|s| hex::decode(s).ok()) + .filter(|v| v.len() == 16); + let effective_tag = user_tag.or(global_tag); + + let proto_flags = proto_flags_for_tag(proto_tag, effective_tag.is_some()); debug!( trace_id = format_args!("0x{:016x}", trace_id), user = %user, @@ -256,6 +271,7 @@ where let (c2me_tx, mut c2me_rx) = mpsc::channel::(C2ME_CHANNEL_CAPACITY); let me_pool_c2me = me_pool.clone(); + let effective_tag = effective_tag; let c2me_sender = tokio::spawn(async move { let mut sent_since_yield = 0usize; while let Some(cmd) = c2me_rx.recv().await { @@ -268,6 +284,7 @@ where translated_local_addr, &payload, flags, + effective_tag.as_deref(), ).await?; sent_since_yield = sent_since_yield.saturating_add(1); if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) { diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index f68b1b9..65bc43a 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -18,6 +18,7 @@ use rand::seq::SliceRandom; use super::registry::ConnMeta; impl MePool { + /// Send RPC_PROXY_REQ. `tag_override`: per-user ad_tag (from access.user_ad_tags); if None, uses pool default. 
pub async fn send_proxy_req( self: &Arc, conn_id: u64, @@ -26,13 +27,15 @@ impl MePool { our_addr: SocketAddr, data: &[u8], proto_flags: u32, + tag_override: Option<&[u8]>, ) -> Result<()> { + let tag = tag_override.or(self.proxy_tag.as_deref()); let payload = build_proxy_req_payload( conn_id, client_addr, our_addr, data, - self.proxy_tag.as_deref(), + tag, proto_flags, ); let meta = ConnMeta { From 6f1980dfd7e117a09b64f4d81392acd8dc1fedcb Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 2 Mar 2026 00:17:58 +0300 Subject: [PATCH 2/2] ME Pool improvements Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/defaults.rs | 36 ++++ src/config/load.rs | 40 +++- src/config/types.rs | 78 ++++++++ src/main.rs | 47 +++-- src/transport/middle_proxy/config_updater.rs | 187 +++++++++++++------ src/transport/middle_proxy/handshake.rs | 14 +- src/transport/middle_proxy/mod.rs | 2 +- src/transport/middle_proxy/pool.rs | 65 ++++++- src/transport/middle_proxy/pool_config.rs | 40 +++- src/transport/middle_proxy/pool_init.rs | 2 +- src/transport/middle_proxy/pool_writer.rs | 27 ++- src/transport/middle_proxy/rotation.rs | 104 ++++++++++- src/transport/middle_proxy/send.rs | 34 +++- 13 files changed, 558 insertions(+), 118 deletions(-) diff --git a/src/config/defaults.rs b/src/config/defaults.rs index ab087fd..0ea6692 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -277,6 +277,18 @@ pub(crate) fn default_me_reinit_every_secs() -> u64 { 15 * 60 } +pub(crate) fn default_me_reinit_singleflight() -> bool { + true +} + +pub(crate) fn default_me_reinit_trigger_channel() -> usize { + 64 +} + +pub(crate) fn default_me_reinit_coalesce_window_ms() -> u64 { + 200 +} + pub(crate) fn default_me_hardswap_warmup_delay_min_ms() -> u64 { 1000 } @@ -301,6 +313,18 @@ pub(crate) fn default_me_config_apply_cooldown_secs() -> u64 { 300 } +pub(crate) fn default_me_snapshot_require_http_2xx() -> bool { + true +} 
+ +pub(crate) fn default_me_snapshot_reject_empty_map() -> bool { + true +} + +pub(crate) fn default_me_snapshot_min_proxy_for_lines() -> u32 { + 1 +} + pub(crate) fn default_proxy_secret_stable_snapshots() -> u8 { 2 } @@ -309,6 +333,10 @@ pub(crate) fn default_proxy_secret_rotate_runtime() -> bool { true } +pub(crate) fn default_me_secret_atomic_snapshot() -> bool { + true +} + pub(crate) fn default_proxy_secret_len_max() -> usize { 256 } @@ -321,10 +349,18 @@ pub(crate) fn default_me_pool_drain_ttl_secs() -> u64 { 90 } +pub(crate) fn default_me_bind_stale_ttl_secs() -> u64 { + default_me_pool_drain_ttl_secs() +} + pub(crate) fn default_me_pool_min_fresh_ratio() -> f32 { 0.8 } +pub(crate) fn default_me_deterministic_writer_sort() -> bool { + true +} + pub(crate) fn default_hardswap() -> bool { true } diff --git a/src/config/load.rs b/src/config/load.rs index f37791d..17545b9 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -305,12 +305,24 @@ impl ProxyConfig { )); } + if config.general.me_snapshot_min_proxy_for_lines == 0 { + return Err(ProxyError::Config( + "general.me_snapshot_min_proxy_for_lines must be > 0".to_string(), + )); + } + if config.general.proxy_secret_stable_snapshots == 0 { return Err(ProxyError::Config( "general.proxy_secret_stable_snapshots must be > 0".to_string(), )); } + if config.general.me_reinit_trigger_channel == 0 { + return Err(ProxyError::Config( + "general.me_reinit_trigger_channel must be > 0".to_string(), + )); + } + if !(32..=4096).contains(&config.general.proxy_secret_len_max) { return Err(ProxyError::Config( "general.proxy_secret_len_max must be within [32, 4096]".to_string(), @@ -535,9 +547,10 @@ impl ProxyConfig { for (user, tag) in &self.access.user_ad_tags { let zeros = "00000000000000000000000000000000"; if !is_valid_ad_tag(tag) { - return Err(ProxyError::Config( - "general.ad_tag must be exactly 32 hex characters".to_string(), - )); + return Err(ProxyError::Config(format!( + "access.user_ad_tags['{}'] must be 
exactly 32 hex characters", + user + ))); } if tag == zeros { warn!(user = %user, "user ad_tag is all zeros; register a valid proxy tag via @MTProxybot to enable sponsored channel"); @@ -1100,6 +1113,27 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn invalid_user_ad_tag_reports_access_user_ad_tags_key() { + let toml = r#" + [censorship] + tls_domain = "example.com" + + [access.users] + alice = "00000000000000000000000000000000" + + [access.user_ad_tags] + alice = "not_hex" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_invalid_user_ad_tag_message_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + let err = cfg.validate().unwrap_err().to_string(); + assert!(err.contains("access.user_ad_tags['alice'] must be exactly 32 hex characters")); + let _ = std::fs::remove_file(path); + } + #[test] fn invalid_dns_override_is_rejected() { let toml = r#" diff --git a/src/config/types.rs b/src/config/types.rs index 716d78f..d57c890 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -130,6 +130,34 @@ impl MeSocksKdfPolicy { } } +/// Stale ME writer bind policy during drain window. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum MeBindStaleMode { + Never, + #[default] + Ttl, + Always, +} + +impl MeBindStaleMode { + pub fn as_u8(self) -> u8 { + match self { + MeBindStaleMode::Never => 0, + MeBindStaleMode::Ttl => 1, + MeBindStaleMode::Always => 2, + } + } + + pub fn from_u8(raw: u8) -> Self { + match raw { + 0 => MeBindStaleMode::Never, + 2 => MeBindStaleMode::Always, + _ => MeBindStaleMode::Ttl, + } + } +} + /// Telemetry controls for hot-path counters and ME diagnostics. 
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TelemetryConfig { @@ -454,6 +482,18 @@ pub struct GeneralConfig { #[serde(default = "default_me_config_apply_cooldown_secs")] pub me_config_apply_cooldown_secs: u64, + /// Ensure getProxyConfig snapshots are applied only for 2xx HTTP responses. + #[serde(default = "default_me_snapshot_require_http_2xx")] + pub me_snapshot_require_http_2xx: bool, + + /// Reject empty getProxyConfig snapshots instead of marking them applied. + #[serde(default = "default_me_snapshot_reject_empty_map")] + pub me_snapshot_reject_empty_map: bool, + + /// Minimum parsed `proxy_for` rows required to accept a snapshot. + #[serde(default = "default_me_snapshot_min_proxy_for_lines")] + pub me_snapshot_min_proxy_for_lines: u32, + /// Number of identical getProxySecret snapshots required before runtime secret rotation. #[serde(default = "default_proxy_secret_stable_snapshots")] pub proxy_secret_stable_snapshots: u8, @@ -462,6 +502,10 @@ pub struct GeneralConfig { #[serde(default = "default_proxy_secret_rotate_runtime")] pub proxy_secret_rotate_runtime: bool, + /// Keep key-selector and secret bytes from one snapshot during ME handshake. + #[serde(default = "default_me_secret_atomic_snapshot")] + pub me_secret_atomic_snapshot: bool, + /// Maximum allowed proxy-secret length in bytes for startup and runtime refresh. #[serde(default = "default_proxy_secret_len_max")] pub proxy_secret_len_max: usize, @@ -471,6 +515,14 @@ pub struct GeneralConfig { #[serde(default = "default_me_pool_drain_ttl_secs")] pub me_pool_drain_ttl_secs: u64, + /// Policy for new binds on stale draining writers. + #[serde(default)] + pub me_bind_stale_mode: MeBindStaleMode, + + /// TTL for stale bind allowance when `me_bind_stale_mode = \"ttl\"`. + #[serde(default = "default_me_bind_stale_ttl_secs")] + pub me_bind_stale_ttl_secs: u64, + /// Minimum desired-DC coverage ratio required before draining stale writers. /// Range: 0.0..=1.0. 
#[serde(default = "default_me_pool_min_fresh_ratio")] @@ -491,6 +543,22 @@ pub struct GeneralConfig { #[serde(default = "default_proxy_config_reload_secs")] pub proxy_config_auto_reload_secs: u64, + /// Serialize ME reinit cycles across all trigger sources. + #[serde(default = "default_me_reinit_singleflight")] + pub me_reinit_singleflight: bool, + + /// Trigger queue capacity for reinit scheduler. + #[serde(default = "default_me_reinit_trigger_channel")] + pub me_reinit_trigger_channel: usize, + + /// Trigger coalescing window before starting a reinit cycle. + #[serde(default = "default_me_reinit_coalesce_window_ms")] + pub me_reinit_coalesce_window_ms: u64, + + /// Deterministic candidate sort for ME writer binding path. + #[serde(default = "default_me_deterministic_writer_sort")] + pub me_deterministic_writer_sort: bool, + /// Enable NTP drift check at startup. #[serde(default = "default_ntp_check")] pub ntp_check: bool, @@ -565,14 +633,24 @@ impl Default for GeneralConfig { me_hardswap_warmup_pass_backoff_base_ms: default_me_hardswap_warmup_pass_backoff_base_ms(), me_config_stable_snapshots: default_me_config_stable_snapshots(), me_config_apply_cooldown_secs: default_me_config_apply_cooldown_secs(), + me_snapshot_require_http_2xx: default_me_snapshot_require_http_2xx(), + me_snapshot_reject_empty_map: default_me_snapshot_reject_empty_map(), + me_snapshot_min_proxy_for_lines: default_me_snapshot_min_proxy_for_lines(), proxy_secret_stable_snapshots: default_proxy_secret_stable_snapshots(), proxy_secret_rotate_runtime: default_proxy_secret_rotate_runtime(), + me_secret_atomic_snapshot: default_me_secret_atomic_snapshot(), proxy_secret_len_max: default_proxy_secret_len_max(), me_pool_drain_ttl_secs: default_me_pool_drain_ttl_secs(), + me_bind_stale_mode: MeBindStaleMode::default(), + me_bind_stale_ttl_secs: default_me_bind_stale_ttl_secs(), me_pool_min_fresh_ratio: default_me_pool_min_fresh_ratio(), me_reinit_drain_timeout_secs: 
default_me_reinit_drain_timeout_secs(), proxy_secret_auto_reload_secs: default_proxy_secret_reload_secs(), proxy_config_auto_reload_secs: default_proxy_config_reload_secs(), + me_reinit_singleflight: default_me_reinit_singleflight(), + me_reinit_trigger_channel: default_me_reinit_trigger_channel(), + me_reinit_coalesce_window_ms: default_me_reinit_coalesce_window_ms(), + me_deterministic_writer_sort: default_me_deterministic_writer_sort(), ntp_check: default_ntp_check(), ntp_servers: default_ntp_servers(), auto_degradation_enabled: default_true(), diff --git a/src/main.rs b/src/main.rs index b910c64..03998cd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ use std::time::Duration; use rand::Rng; use tokio::net::TcpListener; use tokio::signal; -use tokio::sync::Semaphore; +use tokio::sync::{Semaphore, mpsc}; use tracing::{debug, error, info, warn}; use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload}; #[cfg(unix)] @@ -40,7 +40,7 @@ use crate::stats::telemetry::TelemetryPolicy; use crate::stats::{ReplayChecker, Stats}; use crate::stream::BufferPool; use crate::transport::middle_proxy::{ - MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, format_sample_line, + MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, MeReinitTrigger, format_sample_line, format_me_route, }; use crate::transport::{ListenOptions, UpstreamManager, create_listener, find_listener_processes}; @@ -546,6 +546,10 @@ async fn main() -> std::result::Result<(), Box> { config.general.me_hardswap_warmup_delay_max_ms, config.general.me_hardswap_warmup_extra_passes, config.general.me_hardswap_warmup_pass_backoff_base_ms, + config.general.me_bind_stale_mode, + config.general.me_bind_stale_ttl_secs, + config.general.me_secret_atomic_snapshot, + config.general.me_deterministic_writer_sort, config.general.me_socks_kdf_policy, config.general.me_route_backpressure_base_timeout_ms, config.general.me_route_backpressure_high_timeout_ms, @@ -849,26 +853,43 @@ 
async fn main() -> std::result::Result<(), Box> { }); if let Some(ref pool) = me_pool { - let pool_clone = pool.clone(); - let rng_clone = rng.clone(); - let config_rx_clone = config_rx.clone(); + let reinit_trigger_capacity = config + .general + .me_reinit_trigger_channel + .max(1); + let (reinit_tx, reinit_rx) = mpsc::channel::(reinit_trigger_capacity); + + let pool_clone_sched = pool.clone(); + let rng_clone_sched = rng.clone(); + let config_rx_clone_sched = config_rx.clone(); tokio::spawn(async move { - crate::transport::middle_proxy::me_config_updater( - pool_clone, - rng_clone, - config_rx_clone, + crate::transport::middle_proxy::me_reinit_scheduler( + pool_clone_sched, + rng_clone_sched, + config_rx_clone_sched, + reinit_rx, + ) + .await; + }); + + let pool_clone = pool.clone(); + let config_rx_clone = config_rx.clone(); + let reinit_tx_updater = reinit_tx.clone(); + tokio::spawn(async move { + crate::transport::middle_proxy::me_config_updater( + pool_clone, + config_rx_clone, + reinit_tx_updater, ) .await; }); - let pool_clone_rot = pool.clone(); - let rng_clone_rot = rng.clone(); let config_rx_clone_rot = config_rx.clone(); + let reinit_tx_rotation = reinit_tx.clone(); tokio::spawn(async move { crate::transport::middle_proxy::me_rotation_task( - pool_clone_rot, - rng_clone_rot, config_rx_clone_rot, + reinit_tx_rotation, ) .await; }); diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index 4e8e63f..2772b27 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -5,15 +5,15 @@ use std::sync::Arc; use std::time::Duration; use httpdate; -use tokio::sync::watch; +use tokio::sync::{mpsc, watch}; use tracing::{debug, info, warn}; use crate::config::ProxyConfig; use crate::error::Result; use super::MePool; +use super::rotation::{MeReinitTrigger, enqueue_reinit_trigger}; use super::secret::download_proxy_secret_with_max_len; -use 
crate::crypto::SecureRandom; use std::time::SystemTime; async fn retry_fetch(url: &str) -> Option { @@ -38,6 +38,8 @@ async fn retry_fetch(url: &str) -> Option { pub struct ProxyConfigData { pub map: HashMap>, pub default_dc: Option, + pub http_status: u16, + pub proxy_for_lines: u32, } #[derive(Debug, Default)] @@ -172,6 +174,7 @@ pub async fn fetch_proxy_config(url: &str) -> Result { .await .map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}")))? ; + let http_status = resp.status().as_u16(); if let Some(date) = resp.headers().get(reqwest::header::DATE) && let Ok(date_str) = date.to_str() @@ -194,9 +197,11 @@ pub async fn fetch_proxy_config(url: &str) -> Result { .map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")))?; let mut map: HashMap> = HashMap::new(); + let mut proxy_for_lines: u32 = 0; for line in text.lines() { if let Some((dc, ip, port)) = parse_proxy_line(line) { map.entry(dc).or_default().push((ip, port)); + proxy_for_lines = proxy_for_lines.saturating_add(1); } } @@ -214,14 +219,49 @@ pub async fn fetch_proxy_config(url: &str) -> Result { None }); - Ok(ProxyConfigData { map, default_dc }) + Ok(ProxyConfigData { + map, + default_dc, + http_status, + proxy_for_lines, + }) +} + +fn snapshot_passes_guards( + cfg: &ProxyConfig, + snapshot: &ProxyConfigData, + snapshot_name: &'static str, +) -> bool { + if cfg.general.me_snapshot_require_http_2xx + && !(200..=299).contains(&snapshot.http_status) + { + warn!( + snapshot = snapshot_name, + http_status = snapshot.http_status, + "ME snapshot rejected by non-2xx HTTP status" + ); + return false; + } + + let min_proxy_for = cfg.general.me_snapshot_min_proxy_for_lines; + if snapshot.proxy_for_lines < min_proxy_for { + warn!( + snapshot = snapshot_name, + parsed_proxy_for_lines = snapshot.proxy_for_lines, + min_proxy_for_lines = min_proxy_for, + "ME snapshot rejected by proxy_for line floor" + ); + return false; + } + + true } async fn 
run_update_cycle( pool: &Arc, - rng: &Arc, cfg: &ProxyConfig, state: &mut UpdaterState, + reinit_tx: &mpsc::Sender, ) { pool.update_runtime_reinit_policy( cfg.general.hardswap, @@ -232,6 +272,10 @@ async fn run_update_cycle( cfg.general.me_hardswap_warmup_delay_max_ms, cfg.general.me_hardswap_warmup_extra_passes, cfg.general.me_hardswap_warmup_pass_backoff_base_ms, + cfg.general.me_bind_stale_mode, + cfg.general.me_bind_stale_ttl_secs, + cfg.general.me_secret_atomic_snapshot, + cfg.general.me_deterministic_writer_sort, ); let required_cfg_snapshots = cfg.general.me_config_stable_snapshots.max(1); @@ -242,44 +286,48 @@ async fn run_update_cycle( let mut ready_v4: Option<(ProxyConfigData, u64)> = None; let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig").await; if let Some(cfg_v4) = cfg_v4 { - let cfg_v4_hash = hash_proxy_config(&cfg_v4); - let stable_hits = state.config_v4.observe(cfg_v4_hash); - if stable_hits < required_cfg_snapshots { - debug!( - stable_hits, - required_cfg_snapshots, - snapshot = format_args!("0x{cfg_v4_hash:016x}"), - "ME config v4 candidate observed" - ); - } else if state.config_v4.is_applied(cfg_v4_hash) { - debug!( - snapshot = format_args!("0x{cfg_v4_hash:016x}"), - "ME config v4 stable snapshot already applied" - ); - } else { - ready_v4 = Some((cfg_v4, cfg_v4_hash)); + if snapshot_passes_guards(cfg, &cfg_v4, "getProxyConfig") { + let cfg_v4_hash = hash_proxy_config(&cfg_v4); + let stable_hits = state.config_v4.observe(cfg_v4_hash); + if stable_hits < required_cfg_snapshots { + debug!( + stable_hits, + required_cfg_snapshots, + snapshot = format_args!("0x{cfg_v4_hash:016x}"), + "ME config v4 candidate observed" + ); + } else if state.config_v4.is_applied(cfg_v4_hash) { + debug!( + snapshot = format_args!("0x{cfg_v4_hash:016x}"), + "ME config v4 stable snapshot already applied" + ); + } else { + ready_v4 = Some((cfg_v4, cfg_v4_hash)); + } } } let mut ready_v6: Option<(ProxyConfigData, u64)> = None; let cfg_v6 = 
retry_fetch("https://core.telegram.org/getProxyConfigV6").await; if let Some(cfg_v6) = cfg_v6 { - let cfg_v6_hash = hash_proxy_config(&cfg_v6); - let stable_hits = state.config_v6.observe(cfg_v6_hash); - if stable_hits < required_cfg_snapshots { - debug!( - stable_hits, - required_cfg_snapshots, - snapshot = format_args!("0x{cfg_v6_hash:016x}"), - "ME config v6 candidate observed" - ); - } else if state.config_v6.is_applied(cfg_v6_hash) { - debug!( - snapshot = format_args!("0x{cfg_v6_hash:016x}"), - "ME config v6 stable snapshot already applied" - ); - } else { - ready_v6 = Some((cfg_v6, cfg_v6_hash)); + if snapshot_passes_guards(cfg, &cfg_v6, "getProxyConfigV6") { + let cfg_v6_hash = hash_proxy_config(&cfg_v6); + let stable_hits = state.config_v6.observe(cfg_v6_hash); + if stable_hits < required_cfg_snapshots { + debug!( + stable_hits, + required_cfg_snapshots, + snapshot = format_args!("0x{cfg_v6_hash:016x}"), + "ME config v6 candidate observed" + ); + } else if state.config_v6.is_applied(cfg_v6_hash) { + debug!( + snapshot = format_args!("0x{cfg_v6_hash:016x}"), + "ME config v6 stable snapshot already applied" + ); + } else { + ready_v6 = Some((cfg_v6, cfg_v6_hash)); + } } } @@ -292,28 +340,40 @@ async fn run_update_cycle( let update_v6 = ready_v6 .as_ref() .map(|(snapshot, _)| snapshot.map.clone()); - - let changed = pool.update_proxy_maps(update_v4, update_v6).await; - - if let Some((snapshot, hash)) = ready_v4 { - if let Some(dc) = snapshot.default_dc { - pool.default_dc - .store(dc, std::sync::atomic::Ordering::Relaxed); - } - state.config_v4.mark_applied(hash); - } - - if let Some((_snapshot, hash)) = ready_v6 { - state.config_v6.mark_applied(hash); - } - - state.last_map_apply_at = Some(tokio::time::Instant::now()); - - if changed { - maps_changed = true; - info!("ME config update applied after stable-gate"); + let update_is_empty = + update_v4.is_empty() && update_v6.as_ref().is_none_or(|v| v.is_empty()); + let apply_outcome = if update_is_empty && 
!cfg.general.me_snapshot_reject_empty_map { + super::pool_config::SnapshotApplyOutcome::AppliedNoDelta } else { - debug!("ME config stable-gate applied with no map delta"); + pool.update_proxy_maps(update_v4, update_v6).await + }; + + if matches!( + apply_outcome, + super::pool_config::SnapshotApplyOutcome::RejectedEmpty + ) { + warn!("ME config stable snapshot rejected (empty endpoint map)"); + } else { + if let Some((snapshot, hash)) = ready_v4 { + if let Some(dc) = snapshot.default_dc { + pool.default_dc + .store(dc, std::sync::atomic::Ordering::Relaxed); + } + state.config_v4.mark_applied(hash); + } + + if let Some((_snapshot, hash)) = ready_v6 { + state.config_v6.mark_applied(hash); + } + + state.last_map_apply_at = Some(tokio::time::Instant::now()); + + if apply_outcome.changed() { + maps_changed = true; + info!("ME config update applied after stable-gate"); + } else { + debug!("ME config stable-gate applied with no map delta"); + } } } else if let Some(last) = state.last_map_apply_at { let wait_secs = map_apply_cooldown_remaining_secs(last, apply_cooldown); @@ -325,8 +385,7 @@ async fn run_update_cycle( } if maps_changed { - pool.zero_downtime_reinit_after_map_change(rng.as_ref()) - .await; + enqueue_reinit_trigger(reinit_tx, MeReinitTrigger::MapChanged); } pool.reset_stun_state(); @@ -367,8 +426,8 @@ async fn run_update_cycle( pub async fn me_config_updater( pool: Arc, - rng: Arc, mut config_rx: watch::Receiver>, + reinit_tx: mpsc::Sender, ) { let mut state = UpdaterState::default(); let mut update_every_secs = config_rx @@ -387,7 +446,7 @@ pub async fn me_config_updater( tokio::select! 
{ _ = &mut sleep => { let cfg = config_rx.borrow().clone(); - run_update_cycle(&pool, &rng, cfg.as_ref(), &mut state).await; + run_update_cycle(&pool, cfg.as_ref(), &mut state, &reinit_tx).await; let refreshed_secs = cfg.general.effective_update_every_secs().max(1); if refreshed_secs != update_every_secs { info!( @@ -415,6 +474,10 @@ pub async fn me_config_updater( cfg.general.me_hardswap_warmup_delay_max_ms, cfg.general.me_hardswap_warmup_extra_passes, cfg.general.me_hardswap_warmup_pass_backoff_base_ms, + cfg.general.me_bind_stale_mode, + cfg.general.me_bind_stale_ttl_secs, + cfg.general.me_secret_atomic_snapshot, + cfg.general.me_deterministic_writer_sort, ); let new_secs = cfg.general.effective_update_every_secs().max(1); if new_secs == update_every_secs { @@ -429,7 +492,7 @@ pub async fn me_config_updater( ); update_every_secs = new_secs; update_every = Duration::from_secs(update_every_secs); - run_update_cycle(&pool, &rng, cfg.as_ref(), &mut state).await; + run_update_cycle(&pool, cfg.as_ref(), &mut state, &reinit_tx).await; next_tick = tokio::time::Instant::now() + update_every; } else { info!( diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 384ecc9..5daa460 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -1,4 +1,5 @@ use std::net::{IpAddr, SocketAddr}; +use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; use socket2::{SockRef, TcpKeepalive}; #[cfg(target_os = "linux")] @@ -267,7 +268,16 @@ impl MePool { .unwrap_or_default() .as_secs() as u32; - let ks = self.key_selector().await; + let secret_atomic_snapshot = self.secret_atomic_snapshot.load(Ordering::Relaxed); + let (ks, secret) = if secret_atomic_snapshot { + let snapshot = self.secret_snapshot().await; + (snapshot.key_selector, snapshot.secret) + } else { + // Backward-compatible mode: key selector and secret may come from different updates. 
+ let key_selector = self.key_selector().await; + let secret = self.secret_snapshot().await.secret; + (key_selector, secret) + }; let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); let nonce_frame = build_rpc_frame(-2, &nonce_payload, RpcChecksumMode::Crc32); let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); @@ -357,8 +367,6 @@ impl MePool { let diag_level: u8 = std::env::var("ME_DIAG").ok().and_then(|v| v.parse().ok()).unwrap_or(0); - let secret: Vec = self.proxy_secret.read().await.clone(); - let prekey_client = build_middleproxy_prekey( &srv_nonce, &my_nonce, diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 1072ec8..26c58a6 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -30,7 +30,7 @@ pub use pool_nat::{stun_probe, detect_public_ip}; pub use registry::ConnRegistry; pub use secret::fetch_proxy_secret; pub use config_updater::{fetch_proxy_config, me_config_updater}; -pub use rotation::me_rotation_task; +pub use rotation::{MeReinitTrigger, me_reinit_scheduler, me_rotation_task}; pub use wire::proto_flags_for_tag; #[derive(Debug)] diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index f67b2a8..d87430a 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -7,7 +7,7 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use tokio::sync::{Mutex, Notify, RwLock, mpsc}; use tokio_util::sync::CancellationToken; -use crate::config::MeSocksKdfPolicy; +use crate::config::{MeBindStaleMode, MeSocksKdfPolicy}; use crate::crypto::SecureRandom; use crate::network::IpFamily; use crate::network::probe::NetworkDecision; @@ -29,6 +29,13 @@ pub struct MeWriter { pub allow_drain_fallback: Arc, } +#[derive(Debug, Clone)] +pub struct SecretSnapshot { + pub epoch: u64, + pub key_selector: u32, + pub secret: Vec, +} + #[allow(dead_code)] pub struct MePool { pub(super) registry: Arc, @@ -38,7 +45,7 
@@ pub struct MePool { pub(super) upstream: Option>, pub(super) rng: Arc, pub(super) proxy_tag: Option>, - pub(super) proxy_secret: Arc>>, + pub(super) proxy_secret: Arc>, pub(super) nat_ip_cfg: Option, pub(super) nat_ip_detected: Arc>>, pub(super) nat_probe: bool, @@ -83,6 +90,10 @@ pub struct MePool { pub(super) me_hardswap_warmup_delay_max_ms: AtomicU64, pub(super) me_hardswap_warmup_extra_passes: AtomicU32, pub(super) me_hardswap_warmup_pass_backoff_base_ms: AtomicU64, + pub(super) me_bind_stale_mode: AtomicU8, + pub(super) me_bind_stale_ttl_secs: AtomicU64, + pub(super) secret_atomic_snapshot: AtomicBool, + pub(super) me_deterministic_writer_sort: AtomicBool, pub(super) me_socks_kdf_policy: AtomicU8, pool_size: usize, } @@ -147,6 +158,10 @@ impl MePool { me_hardswap_warmup_delay_max_ms: u64, me_hardswap_warmup_extra_passes: u8, me_hardswap_warmup_pass_backoff_base_ms: u64, + me_bind_stale_mode: MeBindStaleMode, + me_bind_stale_ttl_secs: u64, + me_secret_atomic_snapshot: bool, + me_deterministic_writer_sort: bool, me_socks_kdf_policy: MeSocksKdfPolicy, me_route_backpressure_base_timeout_ms: u64, me_route_backpressure_high_timeout_ms: u64, @@ -166,7 +181,20 @@ impl MePool { upstream, rng, proxy_tag, - proxy_secret: Arc::new(RwLock::new(proxy_secret)), + proxy_secret: Arc::new(RwLock::new(SecretSnapshot { + epoch: 1, + key_selector: if proxy_secret.len() >= 4 { + u32::from_le_bytes([ + proxy_secret[0], + proxy_secret[1], + proxy_secret[2], + proxy_secret[3], + ]) + } else { + 0 + }, + secret: proxy_secret, + })), nat_ip_cfg: nat_ip, nat_ip_detected: Arc::new(RwLock::new(None)), nat_probe, @@ -216,6 +244,10 @@ impl MePool { me_hardswap_warmup_pass_backoff_base_ms: AtomicU64::new( me_hardswap_warmup_pass_backoff_base_ms, ), + me_bind_stale_mode: AtomicU8::new(me_bind_stale_mode.as_u8()), + me_bind_stale_ttl_secs: AtomicU64::new(me_bind_stale_ttl_secs), + secret_atomic_snapshot: AtomicBool::new(me_secret_atomic_snapshot), + me_deterministic_writer_sort: 
AtomicBool::new(me_deterministic_writer_sort), me_socks_kdf_policy: AtomicU8::new(me_socks_kdf_policy.as_u8()), }) } @@ -238,6 +270,10 @@ impl MePool { hardswap_warmup_delay_max_ms: u64, hardswap_warmup_extra_passes: u8, hardswap_warmup_pass_backoff_base_ms: u64, + bind_stale_mode: MeBindStaleMode, + bind_stale_ttl_secs: u64, + secret_atomic_snapshot: bool, + deterministic_writer_sort: bool, ) { self.hardswap.store(hardswap, Ordering::Relaxed); self.me_pool_drain_ttl_secs @@ -254,6 +290,14 @@ impl MePool { .store(hardswap_warmup_extra_passes as u32, Ordering::Relaxed); self.me_hardswap_warmup_pass_backoff_base_ms .store(hardswap_warmup_pass_backoff_base_ms, Ordering::Relaxed); + self.me_bind_stale_mode + .store(bind_stale_mode.as_u8(), Ordering::Relaxed); + self.me_bind_stale_ttl_secs + .store(bind_stale_ttl_secs, Ordering::Relaxed); + self.secret_atomic_snapshot + .store(secret_atomic_snapshot, Ordering::Relaxed); + self.me_deterministic_writer_sort + .store(deterministic_writer_sort, Ordering::Relaxed); } pub fn reset_stun_state(&self) { @@ -307,12 +351,15 @@ impl MePool { } pub(super) async fn key_selector(&self) -> u32 { - let secret = self.proxy_secret.read().await; - if secret.len() >= 4 { - u32::from_le_bytes([secret[0], secret[1], secret[2], secret[3]]) - } else { - 0 - } + self.proxy_secret.read().await.key_selector + } + + pub(super) async fn secret_snapshot(&self) -> SecretSnapshot { + self.proxy_secret.read().await.clone() + } + + pub(super) fn bind_stale_mode(&self) -> MeBindStaleMode { + MeBindStaleMode::from_u8(self.me_bind_stale_mode.load(Ordering::Relaxed)) } pub(super) fn family_order(&self) -> Vec { diff --git a/src/transport/middle_proxy/pool_config.rs b/src/transport/middle_proxy/pool_config.rs index fe2aad8..04e3bb5 100644 --- a/src/transport/middle_proxy/pool_config.rs +++ b/src/transport/middle_proxy/pool_config.rs @@ -7,12 +7,29 @@ use tracing::warn; use super::pool::MePool; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum 
SnapshotApplyOutcome { + AppliedChanged, + AppliedNoDelta, + RejectedEmpty, +} + +impl SnapshotApplyOutcome { + pub fn changed(self) -> bool { + matches!(self, SnapshotApplyOutcome::AppliedChanged) + } +} + impl MePool { pub async fn update_proxy_maps( &self, new_v4: HashMap>, new_v6: Option>>, - ) -> bool { + ) -> SnapshotApplyOutcome { + if new_v4.is_empty() && new_v6.as_ref().is_none_or(|v| v.is_empty()) { + return SnapshotApplyOutcome::RejectedEmpty; + } + let mut changed = false; { let mut guard = self.proxy_map_v4.write().await; @@ -51,7 +68,11 @@ impl MePool { } } } - changed + if changed { + SnapshotApplyOutcome::AppliedChanged + } else { + SnapshotApplyOutcome::AppliedNoDelta + } } pub async fn update_secret(self: &Arc, new_secret: Vec) -> bool { @@ -60,8 +81,19 @@ impl MePool { return false; } let mut guard = self.proxy_secret.write().await; - if *guard != new_secret { - *guard = new_secret; + if guard.secret != new_secret { + guard.secret = new_secret; + guard.key_selector = if guard.secret.len() >= 4 { + u32::from_le_bytes([ + guard.secret[0], + guard.secret[1], + guard.secret[2], + guard.secret[3], + ]) + } else { + 0 + }; + guard.epoch = guard.epoch.saturating_add(1); drop(guard); self.reconnect_all().await; return true; diff --git a/src/transport/middle_proxy/pool_init.rs b/src/transport/middle_proxy/pool_init.rs index 623be7f..fef1553 100644 --- a/src/transport/middle_proxy/pool_init.rs +++ b/src/transport/middle_proxy/pool_init.rs @@ -19,7 +19,7 @@ impl MePool { me_servers = self.proxy_map_v4.read().await.len(), pool_size, key_selector = format_args!("0x{ks:08x}"), - secret_len = self.proxy_secret.read().await.len(), + secret_len = self.proxy_secret.read().await.secret.len(), "Initializing ME pool" ); diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 28f5538..a8cc5a5 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -9,6 +9,7 @@ use 
tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; +use crate::config::MeBindStaleMode; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::protocol::constants::RPC_PING_U32; @@ -42,7 +43,7 @@ impl MePool { } pub(crate) async fn connect_one(self: &Arc, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { - let secret_len = self.proxy_secret.read().await.len(); + let secret_len = self.proxy_secret.read().await.secret.len(); if secret_len < 32 { return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); } @@ -351,16 +352,22 @@ impl MePool { return false; } - let ttl_secs = self.me_pool_drain_ttl_secs.load(Ordering::Relaxed); - if ttl_secs == 0 { - return true; - } + match self.bind_stale_mode() { + MeBindStaleMode::Never => false, + MeBindStaleMode::Always => true, + MeBindStaleMode::Ttl => { + let ttl_secs = self.me_bind_stale_ttl_secs.load(Ordering::Relaxed); + if ttl_secs == 0 { + return true; + } - let started = writer.draining_started_at_epoch_secs.load(Ordering::Relaxed); - if started == 0 { - return false; - } + let started = writer.draining_started_at_epoch_secs.load(Ordering::Relaxed); + if started == 0 { + return false; + } - Self::now_epoch_secs().saturating_sub(started) <= ttl_secs + Self::now_epoch_secs().saturating_sub(started) <= ttl_secs + } + } } } diff --git a/src/transport/middle_proxy/rotation.rs b/src/transport/middle_proxy/rotation.rs index cf5f70d..16232c9 100644 --- a/src/transport/middle_proxy/rotation.rs +++ b/src/transport/middle_proxy/rotation.rs @@ -1,19 +1,111 @@ use std::sync::Arc; use std::time::Duration; -use tokio::sync::watch; -use tracing::{info, warn}; +use tokio::sync::{mpsc, watch}; +use tracing::{debug, info, warn}; use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use super::MePool; -/// Periodically reinitialize ME generations and swap them after full warmup. 
-pub async fn me_rotation_task( +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MeReinitTrigger { + Periodic, + MapChanged, +} + +impl MeReinitTrigger { + fn as_str(self) -> &'static str { + match self { + MeReinitTrigger::Periodic => "periodic", + MeReinitTrigger::MapChanged => "map-change", + } + } +} + +pub fn enqueue_reinit_trigger( + tx: &mpsc::Sender, + trigger: MeReinitTrigger, +) { + match tx.try_send(trigger) { + Ok(()) => {} + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + debug!(trigger = trigger.as_str(), "ME reinit trigger dropped (queue full)"); + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + warn!(trigger = trigger.as_str(), "ME reinit trigger dropped (scheduler closed)"); + } + } +} + +pub async fn me_reinit_scheduler( pool: Arc, rng: Arc, + config_rx: watch::Receiver>, + mut trigger_rx: mpsc::Receiver, +) { + info!("ME reinit scheduler started"); + loop { + let Some(first_trigger) = trigger_rx.recv().await else { + warn!("ME reinit scheduler stopped: trigger channel closed"); + break; + }; + + let mut map_change_seen = matches!(first_trigger, MeReinitTrigger::MapChanged); + let mut periodic_seen = matches!(first_trigger, MeReinitTrigger::Periodic); + let cfg = config_rx.borrow().clone(); + let coalesce_window = Duration::from_millis(cfg.general.me_reinit_coalesce_window_ms); + if !coalesce_window.is_zero() { + let deadline = tokio::time::Instant::now() + coalesce_window; + loop { + let now = tokio::time::Instant::now(); + if now >= deadline { + break; + } + match tokio::time::timeout(deadline - now, trigger_rx.recv()).await { + Ok(Some(next)) => { + if next == MeReinitTrigger::MapChanged { + map_change_seen = true; + } else { + periodic_seen = true; + } + } + Ok(None) => break, + Err(_) => break, + } + } + } + + let reason = if map_change_seen && periodic_seen { + "map-change+periodic" + } else if map_change_seen { + "map-change" + } else { + "periodic" + }; + + if cfg.general.me_reinit_singleflight { + 
debug!(reason, "ME reinit scheduled (single-flight)"); + pool.zero_downtime_reinit_periodic(rng.as_ref()).await; + } else { + debug!(reason, "ME reinit scheduled (concurrent mode)"); + let pool_clone = pool.clone(); + let rng_clone = rng.clone(); + tokio::spawn(async move { + pool_clone + .zero_downtime_reinit_periodic(rng_clone.as_ref()) + .await; + }); + } + + } +} + +/// Periodically enqueue reinitialization triggers for ME generations. +pub async fn me_rotation_task( mut config_rx: watch::Receiver>, + reinit_tx: mpsc::Sender, ) { let mut interval_secs = config_rx .borrow() @@ -31,7 +123,7 @@ pub async fn me_rotation_task( tokio::select! { _ = &mut sleep => { - pool.zero_downtime_reinit_periodic(rng.as_ref()).await; + enqueue_reinit_trigger(&reinit_tx, MeReinitTrigger::Periodic); let refreshed_secs = config_rx .borrow() .general @@ -70,7 +162,7 @@ pub async fn me_rotation_task( ); interval_secs = new_secs; interval = Duration::from_secs(interval_secs); - pool.zero_downtime_reinit_periodic(rng.as_ref()).await; + enqueue_reinit_trigger(&reinit_tx, MeReinitTrigger::Periodic); next_tick = tokio::time::Instant::now() + interval; } else { info!( diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 65bc43a..25b8852 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -138,12 +138,34 @@ impl MePool { } } - candidate_indices.sort_by_key(|idx| { - let w = &writers_snapshot[*idx]; - let degraded = w.degraded.load(Ordering::Relaxed); - let stale = (w.generation < self.current_generation()) as usize; - (stale, degraded as usize, Reverse(w.tx.capacity())) - }); + if self.me_deterministic_writer_sort.load(Ordering::Relaxed) { + candidate_indices.sort_by(|lhs, rhs| { + let left = &writers_snapshot[*lhs]; + let right = &writers_snapshot[*rhs]; + let left_key = ( + (left.generation < self.current_generation()) as usize, + left.degraded.load(Ordering::Relaxed) as usize, + Reverse(left.tx.capacity()), + 
left.addr, + left.id, + ); + let right_key = ( + (right.generation < self.current_generation()) as usize, + right.degraded.load(Ordering::Relaxed) as usize, + Reverse(right.tx.capacity()), + right.addr, + right.id, + ); + left_key.cmp(&right_key) + }); + } else { + candidate_indices.sort_by_key(|idx| { + let w = &writers_snapshot[*idx]; + let degraded = w.degraded.load(Ordering::Relaxed); + let stale = (w.generation < self.current_generation()) as usize; + (stale, degraded as usize, Reverse(w.tx.capacity())) + }); + } let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len(); let mut fallback_blocking_idx: Option<usize> = None;