diff --git a/README.md b/README.md index 1df8bb0..b771a55 100644 --- a/README.md +++ b/README.md @@ -143,10 +143,6 @@ then Ctrl+X -> Y -> Enter to save ## Configuration ### Minimal Configuration for First Start ```toml -# === UI === -# Users to show in the startup log (tg:// links) -show_link = ["hello"] - # === General Settings === [general] prefer_ipv6 = false @@ -164,9 +160,17 @@ tls = true port = 443 listen_addr_ipv4 = "0.0.0.0" listen_addr_ipv6 = "::" +# listen_unix_sock = "/var/run/telemt.sock" # Unix socket +# listen_unix_sock_perm = "0666" # Socket file permissions # metrics_port = 9090 # metrics_whitelist = ["127.0.0.1", "::1"] +# Users to show in the startup log (tg:// links) +[general.links] +show = ["hello"] +# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links +# public_port = 443 # Port for tg:// links (default: server.port) + # Listen on multiple interfaces/IPs (overrides listen_addr_*) [[server.listeners]] ip = "0.0.0.0" diff --git a/config.toml b/config.toml index 7344dc5..a0fd7b6 100644 --- a/config.toml +++ b/config.toml @@ -1,7 +1,3 @@ -# === UI === -# Users to show in the startup log (tg:// links) -show_link = ["hello"] - # === General Settings === [general] prefer_ipv6 = false @@ -24,9 +20,17 @@ tls = true port = 443 listen_addr_ipv4 = "0.0.0.0" listen_addr_ipv6 = "::" +# listen_unix_sock = "/var/run/telemt.sock" # Unix socket +# listen_unix_sock_perm = "0666" # Socket file permissions # metrics_port = 9090 # metrics_whitelist = ["127.0.0.1", "::1"] +# Users to show in the startup log (tg:// links) +[general.links] +show = ["hello"] +# public_host = "proxy.example.com" # Host (IP or domain) for tg:// links +# public_port = 443 # Port for tg:// links (default: server.port) + # Listen on multiple interfaces/IPs (overrides listen_addr_*) [[server.listeners]] ip = "0.0.0.0" diff --git a/src/cli.rs b/src/cli.rs index 1440a63..7c44d6d 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -186,8 +186,6 @@ fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> Str r#"# Telemt MTProxy — auto-generated config # Re-run `telemt --init` to regenerate -show_link = ["{username}"] - [general] prefer_ipv6 = false fast_mode = true @@ -199,10 +197,17 @@ classic = false secure = false tls = true +[general.links] +show = ["{username}"] +# public_host = "proxy.example.com" +# public_port = 443 + [server] port = {port} listen_addr_ipv4 = "0.0.0.0" listen_addr_ipv6 = "::" +# listen_unix_sock = "/var/run/telemt.sock" +# listen_unix_sock_perm = "0666" [[server.listeners]] ip = "0.0.0.0" @@ -220,6 +225,8 @@ client_ack = 300 tls_domain = "{domain}" mask = true mask_port = 443 +# mask_host = "{domain}" +# mask_unix_sock = "/var/run/nginx.sock" fake_cert_len = 2048 [access] diff --git a/src/config/mod.rs b/src/config/mod.rs index dbf8afa..d4e1c9a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -4,7 +4,7 @@ use crate::error::{ProxyError, Result}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use std::path::Path; // ============= Helper Defaults ============= @@ -39,9 +39,6 @@ fn default_keepalive() -> u64 { fn default_ack_timeout() -> u64 { 300 } -fn default_listen_addr() -> String { - "0.0.0.0".to_string() -} fn default_fake_cert_len() -> usize { 2048 } @@ -156,6 +153,26 @@ pub struct GeneralConfig { #[serde(default)] pub log_level: LogLevel, + + #[serde(default)] + pub links: LinksConfig, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct LinksConfig { + /// Users whose tg:// links to show at startup. + #[serde(default)] + pub show: Vec, + + /// Public host (IP or domain) for tg:// link generation. + /// Overrides announce_ip / detected IP in links. + #[serde(default)] + pub public_host: Option, + + /// Public port for tg:// link generation. + /// Overrides server.port in links. + #[serde(default)] + pub public_port: Option, } impl Default for GeneralConfig { @@ -169,6 +186,7 @@ impl Default for GeneralConfig { proxy_secret_path: None, middle_proxy_nat_ip: None, log_level: LogLevel::Normal, + links: LinksConfig::default(), } } } @@ -178,8 +196,8 @@ pub struct ServerConfig { #[serde(default = "default_port")] pub port: u16, - #[serde(default = "default_listen_addr")] - pub listen_addr_ipv4: String, + #[serde(default)] + pub listen_addr_ipv4: Option, #[serde(default)] pub listen_addr_ipv6: Option, @@ -187,6 +205,11 @@ pub struct ServerConfig { #[serde(default)] pub listen_unix_sock: Option, + /// Unix socket file permissions (octal string, e.g. "0666"). + /// Applied after bind. If not set, inherits from process umask. + #[serde(default)] + pub listen_unix_sock_perm: Option, + #[serde(default)] pub metrics_port: Option, @@ -201,9 +224,10 @@ impl Default for ServerConfig { fn default() -> Self { Self { port: default_port(), - listen_addr_ipv4: default_listen_addr(), + listen_addr_ipv4: None, listen_addr_ipv6: Some("::".to_string()), listen_unix_sock: None, + listen_unix_sock_perm: None, metrics_port: None, metrics_whitelist: default_metrics_whitelist(), listeners: Vec::new(), @@ -380,7 +404,7 @@ pub struct ProxyConfig { #[serde(default)] pub upstreams: Vec, - #[serde(default)] + #[serde(default, skip_serializing_if = "Vec::is_empty")] pub show_link: Vec, /// DC address overrides for non-standard DCs (CDN, media, test, etc.) @@ -397,15 +421,26 @@ pub struct ProxyConfig { /// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf). #[serde(default)] pub default_dc: Option, + + /// Non-fatal warnings collected during config loading. + #[serde(skip)] + pub warnings: Vec, } impl ProxyConfig { pub fn load>(path: P) -> Result { - let content = - std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?; + let content = std::fs::read_to_string(path) + .map_err(|e| ProxyError::Config(e.to_string()))?; - let mut config: ProxyConfig = - toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?; + // Pre-parse raw TOML to detect defaulted fields + let raw: toml::Value = toml::from_str(&content) + .map_err(|e| ProxyError::Config(e.to_string()))?; + let port_explicit = raw.get("server") + .and_then(|s| s.get("port")) + .is_some(); + + let mut config: ProxyConfig = toml::from_str(&content) + .map_err(|e| ProxyError::Config(e.to_string()))?; // Validate secrets for (user, secret) in &config.access.users { @@ -457,15 +492,51 @@ impl ProxyConfig { use rand::Rng; config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); - // Migration: Populate listeners if empty - if config.server.listeners.is_empty() { - if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::() { - config.server.listeners.push(ListenerConfig { - ip: ipv4, - announce_ip: None, - }); + // Validate listen_unix_sock + if let Some(ref sock_path) = config.server.listen_unix_sock { + if sock_path.is_empty() { + return Err(ProxyError::Config( + "listen_unix_sock cannot be empty".to_string() + )); } - if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { + #[cfg(unix)] + if sock_path.len() > 107 { + return Err(ProxyError::Config( + format!("listen_unix_sock path too long: {} bytes (max 107)", sock_path.len()) + )); + } + #[cfg(not(unix))] + return Err(ProxyError::Config( + "listen_unix_sock is only supported on Unix platforms".to_string() + )); + } + + // Validate listen_unix_sock_perm + if let Some(ref perm_str) = config.server.listen_unix_sock_perm { + if config.server.listen_unix_sock.is_none() { + return Err(ProxyError::Config( + "listen_unix_sock_perm requires listen_unix_sock to be set".to_string() + )); + } + u32::from_str_radix(perm_str, 8).map_err(|_| { + ProxyError::Config(format!( + "listen_unix_sock_perm must be an octal string (e.g. \"0666\"), got \"{}\"", + perm_str + )) + })?; + } + + // Migration: Populate listeners from legacy listen_addr_* fields. + if config.server.listeners.is_empty() { + if let Some(ref ipv4_str) = config.server.listen_addr_ipv4 { + if let Ok(ipv4) = ipv4_str.parse::() { + config.server.listeners.push(ListenerConfig { + ip: ipv4, + announce_ip: None, + }); + } + } + if let Some(ref ipv6_str) = config.server.listen_addr_ipv6 { if let Ok(ipv6) = ipv6_str.parse::() { config.server.listeners.push(ListenerConfig { ip: ipv6, @@ -475,6 +546,18 @@ impl ProxyConfig { } } + // Validate: at least one listen endpoint must be configured. + if config.server.listeners.is_empty() && config.server.listen_unix_sock.is_none() { + return Err(ProxyError::Config( + "No listen address configured. Set [[server.listeners]], listen_addr_ipv4, or listen_unix_sock".to_string() + )); + } + + // Migration: show_link → general.links.show + if !config.show_link.is_empty() && config.general.links.show.is_empty() { + config.general.links.show = std::mem::take(&mut config.show_link); + } + // Migration: Populate upstreams if empty (Default Direct) if config.upstreams.is_empty() { config.upstreams.push(UpstreamConfig { @@ -484,6 +567,20 @@ impl ProxyConfig { }); } + // Warnings for defaulted fields + if !config.server.listeners.is_empty() && !port_explicit { + config.warnings.push(format!( + "[server] port is not set; defaulting to {}", + config.server.port + )); + } + if config.server.listen_unix_sock.is_some() && config.general.links.public_port.is_none() { + config.warnings.push(format!( + "[general.links] public_port is not set; using [server] port {} for tg:// links", + config.server.port + )); + } + Ok(config) } diff --git a/src/main.rs b/src/main.rs index e03f600..98c1171 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,8 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; +#[cfg(unix)] +use tokio::net::UnixListener; use tokio::signal; use tokio::sync::Semaphore; use tracing::{debug, error, info, warn}; @@ -20,9 +22,11 @@ mod stream; mod transport; mod util; -use crate::config::{LogLevel, ProxyConfig}; +use crate::config::{ProxyConfig, LogLevel}; +use crate::proxy::{ClientHandler, handle_client_stream}; +#[cfg(unix)] +use crate::transport::{create_unix_listener, cleanup_unix_socket}; use crate::crypto::SecureRandom; -use crate::proxy::ClientHandler; use crate::stats::{ReplayChecker, Stats}; use crate::stream::BufferPool; use crate::transport::middle_proxy::MePool; @@ -97,6 +101,31 @@ fn parse_cli() -> (String, bool, Option) { (config_path, silent, log_level) } +fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { + info!("--- Proxy Links ({}) ---", host); + for user_name in &config.general.links.show { + if let Some(secret) = config.access.users.get(user_name) { + info!("User: {}", user_name); + if config.general.modes.classic { + info!(" Classic: tg://proxy?server={}&port={}&secret={}", + host, port, secret); + } + if config.general.modes.secure { + info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", + host, port, secret); + } + if config.general.modes.tls { + let domain_hex = hex::encode(&config.censorship.tls_domain); + info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", + host, port, secret, domain_hex); + } + } else { + warn!("User '{}' listed in [general.links] show not found in [access.users]", user_name); + } + } + info!("------------------------"); +} + #[tokio::main] async fn main() -> std::result::Result<(), Box> { let (config_path, cli_silent, cli_log_level) = parse_cli(); @@ -168,6 +197,10 @@ async fn main() -> std::result::Result<(), Box> { warn!("Using default tls_domain. Consider setting a custom domain."); } + for w in &config.warnings { + warn!("{}", w); + } + let prefer_ipv6 = config.general.prefer_ipv6; let use_middle_proxy = config.general.use_middle_proxy; let config = Arc::new(config); @@ -390,35 +423,12 @@ async fn main() -> std::result::Result<(), Box> { listener_conf.ip }; - if !config.show_link.is_empty() { - info!("--- Proxy Links ({}) ---", public_ip); - for user_name in &config.show_link { - if let Some(secret) = config.access.users.get(user_name) { - info!("User: {}", user_name); - if config.general.modes.classic { - info!( - " Classic: tg://proxy?server={}&port={}&secret={}", - public_ip, config.server.port, secret - ); - } - if config.general.modes.secure { - info!( - " DD: tg://proxy?server={}&port={}&secret=dd{}", - public_ip, config.server.port, secret - ); - } - if config.general.modes.tls { - let domain_hex = hex::encode(&config.censorship.tls_domain); - info!( - " EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", - public_ip, config.server.port, secret, domain_hex - ); - } - } else { - warn!("User '{}' in show_link not found", user_name); - } - } - info!("------------------------"); + // Per-listener links (only when public_host is NOT set) + let links = &config.general.links; + if links.public_host.is_none() && !links.show.is_empty() { + let link_host = public_ip.to_string(); + let link_port = links.public_port.unwrap_or(config.server.port); + print_proxy_links(&link_host, link_port, &config); } listeners.push(listener); @@ -429,9 +439,109 @@ async fn main() -> std::result::Result<(), Box> { } } + // Unix socket listener + #[cfg(unix)] + let unix_sock_path = if let Some(ref unix_path) = config.server.listen_unix_sock { + match create_unix_listener(unix_path) { + Ok(std_listener) => { + // Set socket file permissions if configured + if let Some(ref perm_str) = config.server.listen_unix_sock_perm { + if let Ok(mode) = u32::from_str_radix(perm_str, 8) { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions( + unix_path, + std::fs::Permissions::from_mode(mode), + )?; + } + } + + let unix_listener = UnixListener::from_std(std_listener)?; + info!("Listening on unix:{}", unix_path); + + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let me_pool = me_pool.clone(); + + let unix_conn_counter = std::sync::Arc::new( + std::sync::atomic::AtomicU64::new(1) + ); + + tokio::spawn(async move { + loop { + match unix_listener.accept().await { + Ok((stream, _unix_addr)) => { + let conn_id = unix_conn_counter.fetch_add( + 1, std::sync::atomic::Ordering::Relaxed + ); + let fake_peer = SocketAddr::from(([127, 0, 0, 1], conn_id as u16)); + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let me_pool = me_pool.clone(); + + tokio::spawn(async move { + if let Err(e) = handle_client_stream( + stream, fake_peer, config, stats, + upstream_manager, replay_checker, buffer_pool, rng, + me_pool, + ).await { + debug!(error = %e, "Unix socket connection error"); + } + }); + } + Err(e) => { + error!("Unix socket accept error: {}", e); + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + } + }); + + Some(unix_path.clone()) + } + Err(e) => { + error!("Failed to bind to unix:{}: {}", unix_path, e); + None + } + } + } else { + None + }; + + // Links with explicit public_host (independent of TCP listeners) + let links = &config.general.links; + if let Some(ref public_host) = links.public_host { + if !links.show.is_empty() { + let link_port = links.public_port.unwrap_or(config.server.port); + print_proxy_links(public_host, link_port, &config); + } + } + + // Warn if links were configured but couldn't be shown + // (no TCP listeners succeeded and no public_host set) + let links = &config.general.links; + if listeners.is_empty() && links.public_host.is_none() && !links.show.is_empty() { + warn!("Proxy links not shown: no TCP listeners bound. Set [general.links] public_host or fix listener errors above."); + } + if listeners.is_empty() { - error!("No listeners. Exiting."); - std::process::exit(1); + #[cfg(unix)] + if unix_sock_path.is_none() { + error!("No listeners. Exiting."); + std::process::exit(1); + } + #[cfg(not(unix))] + { + error!("No listeners. Exiting."); + std::process::exit(1); + } } // Switch to user-configured log level after startup @@ -494,7 +604,13 @@ async fn main() -> std::result::Result<(), Box> { } match signal::ctrl_c().await { - Ok(()) => info!("Shutting down..."), + Ok(()) => { + info!("Shutting down..."); + #[cfg(unix)] + if let Some(ref path) = unix_sock_path { + cleanup_unix_socket(path); + } + } Err(e) => error!("Signal error: {}", e), } diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 726d238..4ea86d3 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -23,6 +23,149 @@ use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle use crate::proxy::masking::handle_bad_client; use crate::proxy::middle_relay::handle_via_middle_proxy; +/// Handle a client connection from any stream type (TCP, Unix socket) +/// +/// This is the generic entry point for client handling. Unlike `ClientHandler::new().run()`, +/// it skips TCP-specific socket configuration (TCP_NODELAY, keepalive, TCP_USER_TIMEOUT) +/// which is appropriate for non-TCP streams like Unix sockets. +pub async fn handle_client_stream( + mut stream: S, + peer: SocketAddr, + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + me_pool: Option>, +) -> Result<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + stats.increment_connects_all(); + debug!(peer = %peer, "New connection (generic stream)"); + + let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake); + let stats_for_timeout = stats.clone(); + + // For non-TCP streams, use a synthetic local address + let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) + .parse() + .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); + + let result = timeout(handshake_timeout, async { + let mut first_bytes = [0u8; 5]; + stream.read_exact(&mut first_bytes).await?; + + let is_tls = tls::is_tls_handshake(&first_bytes[..3]); + debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); + + if is_tls { + let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; + + if tls_len < 512 { + debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + stats.increment_connects_bad(); + let (reader, writer) = tokio::io::split(stream); + handle_bad_client(reader, writer, &first_bytes, &config).await; + return Ok(()); + } + + let mut handshake = vec![0u8; 5 + tls_len]; + handshake[..5].copy_from_slice(&first_bytes); + stream.read_exact(&mut handshake[5..]).await?; + + let (read_half, write_half) = tokio::io::split(stream); + + let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( + &handshake, read_half, write_half, peer, + &config, &replay_checker, &rng, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader, writer } => { + stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; + return Ok(()); + } + HandshakeResult::Error(e) => return Err(e), + }; + + debug!(peer = %peer, "Reading MTProto handshake through TLS"); + let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; + let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() + .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; + + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( + &mtproto_handshake, tls_reader, tls_writer, peer, + &config, &replay_checker, true, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader: _, writer: _ } => { + stats.increment_connects_bad(); + debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); + return Ok(()); + } + HandshakeResult::Error(e) => return Err(e), + }; + + RunningClientHandler::handle_authenticated_static( + crypto_reader, crypto_writer, success, + upstream_manager, stats, config, buffer_pool, rng, me_pool, + local_addr, + ).await + } else { + if !config.general.modes.classic && !config.general.modes.secure { + debug!(peer = %peer, "Non-TLS modes disabled"); + stats.increment_connects_bad(); + let (reader, writer) = tokio::io::split(stream); + handle_bad_client(reader, writer, &first_bytes, &config).await; + return Ok(()); + } + + let mut handshake = [0u8; HANDSHAKE_LEN]; + handshake[..5].copy_from_slice(&first_bytes); + stream.read_exact(&mut handshake[5..]).await?; + + let (read_half, write_half) = tokio::io::split(stream); + + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( + &handshake, read_half, write_half, peer, + &config, &replay_checker, false, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader, writer } => { + stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; + return Ok(()); + } + HandshakeResult::Error(e) => return Err(e), + }; + + RunningClientHandler::handle_authenticated_static( + crypto_reader, crypto_writer, success, + upstream_manager, stats, config, buffer_pool, rng, me_pool, + local_addr, + ).await + } + }).await; + + match result { + Ok(Ok(())) => { + debug!(peer = %peer, "Connection handled successfully"); + Ok(()) + } + Ok(Err(e)) => { + debug!(peer = %peer, error = %e, "Handshake failed"); + Err(e) + } + Err(_) => { + stats_for_timeout.increment_handshake_timeouts(); + debug!(peer = %peer, "Handshake timeout"); + Err(ProxyError::TgHandshakeTimeout) + } + } +} + pub struct ClientHandler; pub struct RunningClientHandler { @@ -267,9 +410,9 @@ impl RunningClientHandler { /// Main dispatch after successful handshake. /// Two modes: - /// - Direct: TCP relay to TG DC (existing behavior) + /// - Direct: TCP relay to TG DC (existing behavior) /// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs) - async fn handle_authenticated_static( + pub(crate) async fn handle_authenticated_static( client_reader: CryptoReader, client_writer: CryptoWriter, success: HandshakeSuccess, diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 92dd373..7468448 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -2,10 +2,12 @@ pub mod handshake; pub mod client; +pub(crate) mod direct_relay; +pub(crate) mod middle_relay; pub mod relay; pub mod masking; pub use handshake::*; -pub use client::ClientHandler; +pub use client::{ClientHandler, handle_client_stream}; pub use relay::*; -pub use masking::*; \ No newline at end of file +pub use masking::*; diff --git a/src/transport/middle_proxy.rs b/src/transport/middle_proxy.rs deleted file mode 100644 index 3b08c3c..0000000 --- a/src/transport/middle_proxy.rs +++ /dev/null @@ -1,925 +0,0 @@ -//! Middle Proxy RPC Transport -//! -//! Implements Telegram Middle-End RPC protocol for routing to ALL DCs (including CDN). -//! -//! ## Phase 3 fixes: -//! - ROOT CAUSE: Use Telegram proxy-secret (binary file) not user secret -//! - Streaming handshake response (no fixed-size read deadlock) -//! - Health monitoring + reconnection -//! - Hex diagnostics for debugging - -use std::collections::HashMap; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::time::Duration; -use bytes::{Bytes, BytesMut}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; -use tokio::sync::{mpsc, Mutex, RwLock}; -use tokio::time::{timeout, Instant}; -use tracing::{debug, info, trace, warn, error}; - -use crate::crypto::{crc32, derive_middleproxy_keys, AesCbc, SecureRandom}; -use crate::error::{ProxyError, Result}; -use crate::protocol::constants::*; - -// ========== Proxy Secret Fetching ========== - -/// Fetch the Telegram proxy-secret binary file. -/// -/// This is NOT the user secret (-S flag, 16 bytes hex for clients). -/// This is the infrastructure secret (--aes-pwd in C MTProxy), -/// a binary file of 32-512 bytes used for ME RPC key derivation. -/// -/// Strategy: try local cache, then download from Telegram. -pub async fn fetch_proxy_secret(cache_path: Option<&str>) -> Result> { - let cache = cache_path.unwrap_or("proxy-secret"); - - // 1. Try local cache (< 24h old) - if let Ok(metadata) = tokio::fs::metadata(cache).await { - if let Ok(modified) = metadata.modified() { - let age = std::time::SystemTime::now() - .duration_since(modified) - .unwrap_or(Duration::from_secs(u64::MAX)); - if age < Duration::from_secs(86400) { - if let Ok(data) = tokio::fs::read(cache).await { - if data.len() >= 32 { - info!( - path = cache, - len = data.len(), - age_hours = age.as_secs() / 3600, - "Loaded proxy-secret from cache" - ); - return Ok(data); - } - warn!(path = cache, len = data.len(), "Cached proxy-secret too short"); - } - } - } - } - - // 2. Download from Telegram - info!("Downloading proxy-secret from core.telegram.org..."); - let data = download_proxy_secret().await?; - - // 3. Cache locally (best-effort) - if let Err(e) = tokio::fs::write(cache, &data).await { - warn!(error = %e, "Failed to cache proxy-secret (non-fatal)"); - } else { - debug!(path = cache, len = data.len(), "Cached proxy-secret"); - } - - Ok(data) -} - -async fn download_proxy_secret() -> Result> { - let url = "https://core.telegram.org/getProxySecret"; - let resp = reqwest::get(url) - .await - .map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {}", e)))?; - - if !resp.status().is_success() { - return Err(ProxyError::Proxy(format!( - "proxy-secret download HTTP {}", resp.status() - ))); - } - - let data = resp.bytes().await - .map_err(|e| ProxyError::Proxy(format!("Read proxy-secret body: {}", e)))? - .to_vec(); - - if data.len() < 32 { - return Err(ProxyError::Proxy(format!( - "proxy-secret too short: {} bytes (need >= 32)", data.len() - ))); - } - - info!(len = data.len(), "Downloaded proxy-secret OK"); - Ok(data) -} - -// ========== RPC Frame helpers ========== - -/// Build an RPC frame: [len(4) | seq_no(4) | payload | crc32(4)] -fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec { - let total_len = (4 + 4 + payload.len() + 4) as u32; - let mut f = Vec::with_capacity(total_len as usize); - f.extend_from_slice(&total_len.to_le_bytes()); - f.extend_from_slice(&seq_no.to_le_bytes()); - f.extend_from_slice(payload); - let c = crc32(&f); - f.extend_from_slice(&c.to_le_bytes()); - f -} - -/// Read one plaintext RPC frame. Returns (seq_no, payload). -async fn read_rpc_frame_plaintext( - rd: &mut (impl AsyncReadExt + Unpin), -) -> Result<(i32, Vec)> { - let mut len_buf = [0u8; 4]; - rd.read_exact(&mut len_buf).await.map_err(ProxyError::Io)?; - let total_len = u32::from_le_bytes(len_buf) as usize; - - if total_len < 12 || total_len > (1 << 24) { - return Err(ProxyError::InvalidHandshake( - format!("Bad RPC frame length: {}", total_len), - )); - } - - let mut rest = vec![0u8; total_len - 4]; - rd.read_exact(&mut rest).await.map_err(ProxyError::Io)?; - - let mut full = Vec::with_capacity(total_len); - full.extend_from_slice(&len_buf); - full.extend_from_slice(&rest); - - let crc_offset = total_len - 4; - let expected_crc = u32::from_le_bytes([ - full[crc_offset], full[crc_offset + 1], - full[crc_offset + 2], full[crc_offset + 3], - ]); - let actual_crc = crc32(&full[..crc_offset]); - if expected_crc != actual_crc { - return Err(ProxyError::InvalidHandshake( - format!("CRC mismatch: 0x{:08x} vs 0x{:08x}", expected_crc, actual_crc), - )); - } - - let seq_no = i32::from_le_bytes([full[4], full[5], full[6], full[7]]); - let payload = full[8..crc_offset].to_vec(); - Ok((seq_no, payload)) -} - -// ========== RPC Nonce (32 bytes payload) ========== - -fn build_nonce_payload(key_selector: u32, crypto_ts: u32, nonce: &[u8; 16]) -> [u8; 32] { - let mut p = [0u8; 32]; - p[0..4].copy_from_slice(&RPC_NONCE_U32.to_le_bytes()); - p[4..8].copy_from_slice(&key_selector.to_le_bytes()); - p[8..12].copy_from_slice(&RPC_CRYPTO_AES_U32.to_le_bytes()); - p[12..16].copy_from_slice(&crypto_ts.to_le_bytes()); - p[16..32].copy_from_slice(nonce); - p -} - -fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, [u8; 16])> { - if d.len() < 32 { - return Err(ProxyError::InvalidHandshake( - format!("Nonce payload too short: {} bytes", d.len()), - )); - } - let t = u32::from_le_bytes([d[0], d[1], d[2], d[3]]); - if t != RPC_NONCE_U32 { - return Err(ProxyError::InvalidHandshake( - format!("Expected RPC_NONCE 0x{:08x}, got 0x{:08x}", RPC_NONCE_U32, t), - )); - } - let schema = u32::from_le_bytes([d[8], d[9], d[10], d[11]]); - let ts = u32::from_le_bytes([d[12], d[13], d[14], d[15]]); - let mut nonce = [0u8; 16]; - nonce.copy_from_slice(&d[16..32]); - Ok((schema, ts, nonce)) -} - -// ========== RPC Handshake (32 bytes payload) ========== - -fn build_handshake_payload(our_ip: u32, our_port: u16, peer_ip: u32, peer_port: u16) -> [u8; 32] { - let mut p = [0u8; 32]; - p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes()); - // flags = 0 at offset 4..8 - - // sender_pid: {ip(4), port(2), pid(2), utime(4)} at offset 8..20 - p[8..12].copy_from_slice(&our_ip.to_le_bytes()); - p[12..14].copy_from_slice(&our_port.to_le_bytes()); - let pid = (std::process::id() & 0xFFFF) as u16; - p[14..16].copy_from_slice(&pid.to_le_bytes()); - let utime = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs() as u32; - p[16..20].copy_from_slice(&utime.to_le_bytes()); - - // peer_pid: {ip(4), port(2), pid(2), utime(4)} at offset 20..32 - p[20..24].copy_from_slice(&peer_ip.to_le_bytes()); - p[24..26].copy_from_slice(&peer_port.to_le_bytes()); - p -} - -// ========== CBC helpers ========== - -fn cbc_encrypt_padded(key: &[u8; 32], iv: &[u8; 16], plaintext: &[u8]) -> Result<(Vec, [u8; 16])> { - let pad = (16 - (plaintext.len() % 16)) % 16; - let mut buf = plaintext.to_vec(); - let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00]; - for i in 0..pad { - buf.push(pad_pattern[i % 4]); - } - let cipher = AesCbc::new(*key, *iv); - cipher.encrypt_in_place(&mut buf) - .map_err(|e| ProxyError::Crypto(format!("CBC encrypt: {}", e)))?; - let mut new_iv = [0u8; 16]; - if buf.len() >= 16 { - new_iv.copy_from_slice(&buf[buf.len() - 16..]); - } - Ok((buf, new_iv)) -} - -fn cbc_decrypt_inplace(key: &[u8; 32], iv: &[u8; 16], data: &mut [u8]) -> Result<[u8; 16]> { - let mut new_iv = [0u8; 16]; - if data.len() >= 16 { - new_iv.copy_from_slice(&data[data.len() - 16..]); - } - AesCbc::new(*key, *iv) - .decrypt_in_place(data) - .map_err(|e| ProxyError::Crypto(format!("CBC decrypt: {}", e)))?; - Ok(new_iv) -} - -// ========== IPv4 helpers ========== - -fn ipv4_to_mapped_v6(ip: Ipv4Addr) -> [u8; 16] { - let mut buf = [0u8; 16]; - buf[10] = 0xFF; - buf[11] = 0xFF; - let o = ip.octets(); - buf[12] = o[0]; buf[13] = o[1]; buf[14] = o[2]; buf[15] = o[3]; - buf -} - -fn addr_to_ip_u32(addr: &SocketAddr) -> u32 { - match addr.ip() { - IpAddr::V4(v4) => u32::from_be_bytes(v4.octets()), - IpAddr::V6(v6) => { - if let Some(v4) = v6.to_ipv4_mapped() { - u32::from_be_bytes(v4.octets()) - } else { 0 } - } - } -} - -// ========== ME Response ========== - -#[derive(Debug)] -pub enum MeResponse { - Data(Bytes), - Ack(u32), - Close, -} - -// ========== Connection Registry ========== - -pub struct ConnRegistry { - map: RwLock>>, - next_id: AtomicU64, -} - -impl ConnRegistry { - pub fn new() -> Self { - Self { - map: RwLock::new(HashMap::new()), - next_id: AtomicU64::new(1), - } - } - pub async fn register(&self) -> (u64, mpsc::Receiver) { - let id = self.next_id.fetch_add(1, Ordering::Relaxed); - let (tx, rx) = mpsc::channel(256); - self.map.write().await.insert(id, tx); - (id, rx) - } - pub async fn unregister(&self, id: u64) { - self.map.write().await.remove(&id); - } - pub async fn route(&self, id: u64, resp: MeResponse) -> bool { - let m = self.map.read().await; - if let Some(tx) = m.get(&id) { - tx.send(resp).await.is_ok() - } else { false } - } -} - -// ========== RPC Writer (streaming CBC) ========== - -struct RpcWriter { - writer: tokio::io::WriteHalf, - key: [u8; 32], - iv: [u8; 16], - seq_no: i32, -} - -impl RpcWriter { - async fn send(&mut self, payload: &[u8]) -> Result<()> { - let frame = build_rpc_frame(self.seq_no, payload); - self.seq_no += 1; - - let pad = (16 - (frame.len() % 16)) % 16; - let mut buf = frame; - let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00]; - for i in 0..pad { - buf.push(pad_pattern[i % 4]); - } - - let cipher = AesCbc::new(self.key, self.iv); - cipher.encrypt_in_place(&mut buf) - .map_err(|e| ProxyError::Crypto(format!("{}", e)))?; - - if buf.len() >= 16 { - self.iv.copy_from_slice(&buf[buf.len() - 16..]); - } - self.writer.write_all(&buf).await.map_err(ProxyError::Io) - } -} - -// ========== RPC_PROXY_REQ ========== - - -fn build_proxy_req_payload( - conn_id: u64, - client_addr: SocketAddr, - our_addr: SocketAddr, - data: &[u8], - proxy_tag: Option<&[u8]>, - proto_flags: u32, -) -> Vec { - // flags are pre-calculated by proto_flags_for_tag - // We just need to ensure FLAG_HAS_AD_TAG is set if we have a tag (it is set by default in our new function, but let's be safe) - let mut flags = proto_flags; - - // The C code logic: - // flags = (transport_flags) | 0x1000 | 0x20000 | 0x8 (if tag) - // Our proto_flags_for_tag returns: 0x8 | 0x1000 | 0x20000 | transport_flags - // So we are good. - - let b_cap = 128 + data.len(); - let mut b = Vec::with_capacity(b_cap); - - b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes()); - b.extend_from_slice(&flags.to_le_bytes()); - b.extend_from_slice(&conn_id.to_le_bytes()); - - // Client IP (16 bytes IPv4-mapped-v6) + port (4 bytes) - match client_addr.ip() { - IpAddr::V4(v4) => b.extend_from_slice(&ipv4_to_mapped_v6(v4)), - IpAddr::V6(v6) => b.extend_from_slice(&v6.octets()), - } - b.extend_from_slice(&(client_addr.port() as u32).to_le_bytes()); - - // Our IP (16 bytes) + port (4 bytes) - match our_addr.ip() { - IpAddr::V4(v4) => b.extend_from_slice(&ipv4_to_mapped_v6(v4)), - IpAddr::V6(v6) => b.extend_from_slice(&v6.octets()), - } - b.extend_from_slice(&(our_addr.port() as u32).to_le_bytes()); - - // Extra section (proxy_tag) - if flags & 12 != 0 { - let extra_start = b.len(); - b.extend_from_slice(&0u32.to_le_bytes()); // placeholder - - if let Some(tag) = proxy_tag { - b.extend_from_slice(&TL_PROXY_TAG_U32.to_le_bytes()); - // TL string encoding - if tag.len() < 254 { - b.push(tag.len() as u8); - b.extend_from_slice(tag); - let pad = (4 - ((1 + tag.len()) % 4)) % 4; - b.extend(std::iter::repeat(0u8).take(pad)); - } else { - b.push(0xfe); - let len_bytes = (tag.len() as u32).to_le_bytes(); - b.extend_from_slice(&len_bytes[..3]); - b.extend_from_slice(tag); - let pad = (4 - (tag.len() % 4)) % 4; - b.extend(std::iter::repeat(0u8).take(pad)); - } - } - - let extra_bytes = (b.len() - extra_start - 4) as u32; - let eb = extra_bytes.to_le_bytes(); - b[extra_start..extra_start + 4].copy_from_slice(&eb); - } - - b.extend_from_slice(data); - b -} - -// ========== ME Pool ========== - -pub struct MePool { - registry: Arc, - writers: Arc>>>>, - rr: AtomicU64, - proxy_tag: Option>, - /// Telegram proxy-secret (binary, 32-512 bytes) - proxy_secret: Vec, - pool_size: usize, -} - -impl MePool { - pub fn new(proxy_tag: Option>, proxy_secret: Vec) -> Arc { - Arc::new(Self { - registry: Arc::new(ConnRegistry::new()), - writers: Arc::new(RwLock::new(Vec::new())), - rr: AtomicU64::new(0), - proxy_tag, - proxy_secret, - pool_size: 2, - }) - } - - pub fn registry(&self) -> &Arc { - &self.registry - } - - fn writers_arc(&self) -> Arc>>>> { - self.writers.clone() - } - - /// key_selector = first 4 bytes of proxy-secret as LE u32 - /// C: main_secret.key_signature via union { char secret[]; int key_signature; } - fn key_selector(&self) -> u32 { - if self.proxy_secret.len() >= 4 { - u32::from_le_bytes([ - self.proxy_secret[0], self.proxy_secret[1], - self.proxy_secret[2], self.proxy_secret[3], - ]) - } else { 0 } - } - - pub async fn init( - self: &Arc, - pool_size: usize, - rng: &SecureRandom, - ) -> Result<()> { - let addrs = &*TG_MIDDLE_PROXIES_FLAT_V4; - let ks = self.key_selector(); - info!( - me_servers = addrs.len(), - pool_size, - key_selector = format_args!("0x{:08x}", ks), - secret_len = self.proxy_secret.len(), - "Initializing ME pool" - ); - - for &(ip, port) in addrs.iter() { - for i in 0..pool_size { - let addr = SocketAddr::new(ip, port); - match self.connect_one(addr, rng).await { - Ok(()) => info!(%addr, idx = i, "ME connected"), - Err(e) => warn!(%addr, idx = i, error = %e, "ME connect failed"), - } - } - if self.writers.read().await.len() >= pool_size { - break; - } - } - - if self.writers.read().await.is_empty() { - return Err(ProxyError::Proxy("No ME connections".into())); - } - Ok(()) - } - - async fn connect_one( - self: &Arc, - addr: SocketAddr, - rng: &SecureRandom, - ) -> Result<()> { - let secret = &self.proxy_secret; - if secret.len() < 32 { - return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); - } - - // ===== TCP connect ===== - let stream = timeout( - Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), - TcpStream::connect(addr), - ) - .await - .map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })? - .map_err(ProxyError::Io)?; - stream.set_nodelay(true).ok(); - - let local_addr = stream.local_addr().map_err(ProxyError::Io)?; - let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; - let (mut rd, mut wr) = tokio::io::split(stream); - - // ===== 1. Send RPC nonce (plaintext, seq=-2) ===== - let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap(); - let crypto_ts = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs() as u32; - let ks = self.key_selector(); - - let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); - let nonce_frame = build_rpc_frame(-2, &nonce_payload); - - debug!( - %addr, - frame_len = nonce_frame.len(), - key_sel = format_args!("0x{:08x}", ks), - crypto_ts, - "Sending nonce" - ); - - wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?; - wr.flush().await.map_err(ProxyError::Io)?; - - // ===== 2. Read server nonce (plaintext, seq=-2) ===== - let (srv_seq, srv_nonce_payload) = timeout( - Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS), - read_rpc_frame_plaintext(&mut rd), - ) - .await - .map_err(|_| ProxyError::TgHandshakeTimeout)??; - - if srv_seq != -2 { - return Err(ProxyError::InvalidHandshake( - format!("Expected seq=-2, got {}", srv_seq), - )); - } - - let (schema, _srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?; - if schema != RPC_CRYPTO_AES_U32 { - return Err(ProxyError::InvalidHandshake( - format!("Unsupported crypto schema: 0x{:x}", schema), - )); - } - - debug!(%addr, "Nonce exchange OK, deriving keys"); - - // ===== 3. Derive AES-256-CBC keys ===== - // C buffer layout: - // [0..16] nonce_server (srv_nonce) - // [16..32] nonce_client (my_nonce) - // [32..36] client_timestamp - // [36..40] server_ip - // [40..42] client_port - // [42..48] "CLIENT" or "SERVER" - // [48..52] client_ip - // [52..54] server_port - // [54..54+N] secret (proxy-secret binary) - // [54+N..70+N] nonce_server - // nonce_client(16) - - let ts_bytes = crypto_ts.to_le_bytes(); - let server_ip = addr_to_ip_u32(&peer_addr); - let client_ip = addr_to_ip_u32(&local_addr); - let server_ip_bytes = server_ip.to_le_bytes(); - let client_ip_bytes = client_ip.to_le_bytes(); - let server_port_bytes = peer_addr.port().to_le_bytes(); - let client_port_bytes = local_addr.port().to_le_bytes(); - - let (wk, wi) = derive_middleproxy_keys( - &srv_nonce, &my_nonce, &ts_bytes, - Some(&server_ip_bytes), &client_port_bytes, - b"CLIENT", - Some(&client_ip_bytes), &server_port_bytes, - secret, None, None, - ); - let (rk, ri) = derive_middleproxy_keys( - &srv_nonce, &my_nonce, &ts_bytes, - Some(&server_ip_bytes), &client_port_bytes, - b"SERVER", - Some(&client_ip_bytes), &server_port_bytes, - secret, None, None, - ); - - debug!( - %addr, - write_key = %hex::encode(&wk[..8]), - read_key = %hex::encode(&rk[..8]), - "Keys derived" - ); - - // ===== 4. Send encrypted handshake (seq=-1) ===== - let hs_payload = build_handshake_payload( - client_ip, local_addr.port(), - server_ip, peer_addr.port(), - ); - let hs_frame = build_rpc_frame(-1, &hs_payload); - let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; - wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; - wr.flush().await.map_err(ProxyError::Io)?; - - debug!(%addr, enc_len = encrypted_hs.len(), "Sent encrypted handshake"); - - // ===== 5. Read encrypted handshake response (STREAMING) ===== - // Server sends encrypted handshake. C crypto layer may send partial - // blocks (only complete 16-byte blocks get encrypted at a time). - // We read incrementally and decrypt block-by-block. - let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS); - let mut enc_buf = BytesMut::with_capacity(256); - let mut dec_buf = BytesMut::with_capacity(256); - let mut read_iv = ri; - let mut handshake_ok = false; - - while Instant::now() < deadline && !handshake_ok { - let remaining = deadline - Instant::now(); - let mut tmp = [0u8; 256]; - let n = match timeout(remaining, rd.read(&mut tmp)).await { - Ok(Ok(0)) => return Err(ProxyError::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, "ME closed during handshake", - ))), - Ok(Ok(n)) => n, - Ok(Err(e)) => return Err(ProxyError::Io(e)), - Err(_) => return Err(ProxyError::TgHandshakeTimeout), - }; - enc_buf.extend_from_slice(&tmp[..n]); - - // Decrypt complete 16-byte blocks - let blocks = enc_buf.len() / 16 * 16; - if blocks > 0 { - let mut chunk = vec![0u8; blocks]; - chunk.copy_from_slice(&enc_buf[..blocks]); - let new_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?; - read_iv = new_iv; - dec_buf.extend_from_slice(&chunk); - let _ = enc_buf.split_to(blocks); - } - - // Try to parse RPC frame from decrypted data - while dec_buf.len() >= 4 { - let fl = u32::from_le_bytes([ - dec_buf[0], dec_buf[1], dec_buf[2], dec_buf[3], - ]) as usize; - - // Skip noop padding - if fl == 4 { - let _ = dec_buf.split_to(4); - continue; - } - if fl < 12 || fl > (1 << 24) { - return Err(ProxyError::InvalidHandshake( - format!("Bad HS response frame len: {}", fl), - )); - } - if dec_buf.len() < fl { - break; // need more data - } - - let frame = dec_buf.split_to(fl); - - // CRC32 check - let pe = fl - 4; - let ec = u32::from_le_bytes([ - frame[pe], frame[pe + 1], frame[pe + 2], frame[pe + 3], - ]); - let ac = crc32(&frame[..pe]); - if ec != ac { - return Err(ProxyError::InvalidHandshake( - format!("HS CRC mismatch: 0x{:08x} vs 0x{:08x}", ec, ac), - )); - } - - // Check type - let hs_type = u32::from_le_bytes([ - frame[8], frame[9], frame[10], frame[11], - ]); - if hs_type == RPC_HANDSHAKE_ERROR_U32 { - let err_code = if frame.len() >= 16 { - i32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]]) - } else { -1 }; - return Err(ProxyError::InvalidHandshake( - format!("ME rejected handshake (error={})", err_code), - )); - } - if hs_type != RPC_HANDSHAKE_U32 { - return Err(ProxyError::InvalidHandshake( - format!("Expected HANDSHAKE 0x{:08x}, got 0x{:08x}", RPC_HANDSHAKE_U32, hs_type), - )); - } - - handshake_ok = true; - break; - } - } - - if !handshake_ok { - return Err(ProxyError::TgHandshakeTimeout); - } - - info!(%addr, "RPC handshake OK"); - - // ===== 6. Setup writer + reader ===== - let rpc_w = Arc::new(Mutex::new(RpcWriter { - writer: wr, - key: wk, - iv: write_iv, - seq_no: 0, - })); - self.writers.write().await.push(rpc_w.clone()); - - let reg = self.registry.clone(); - let w_pong = rpc_w.clone(); - let w_pool = self.writers_arc(); - tokio::spawn(async move { - if let Err(e) = reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await { - warn!(error = %e, "ME reader ended"); - } - // Remove dead writer from pool - let mut ws = w_pool.write().await; - ws.retain(|w| !Arc::ptr_eq(w, &w_pong)); - info!(remaining = ws.len(), "Dead ME writer removed from pool"); - }); - - Ok(()) - } - - pub async fn send_proxy_req( - &self, - conn_id: u64, - client_addr: SocketAddr, - our_addr: SocketAddr, - data: &[u8], - proto_flags: u32, - ) -> Result<()> { - let payload = build_proxy_req_payload( - conn_id, client_addr, our_addr, data, - self.proxy_tag.as_deref(), proto_flags, - ); - loop { - let ws = self.writers.read().await; - if ws.is_empty() { - return Err(ProxyError::Proxy("All ME connections dead".into())); - } - let idx = self.rr.fetch_add(1, Ordering::Relaxed) as usize % ws.len(); - let w = ws[idx].clone(); - drop(ws); - match w.lock().await.send(&payload).await { - Ok(()) => return Ok(()), - Err(e) => { - warn!(error = %e, "ME write failed, removing dead conn"); - let mut ws = self.writers.write().await; - ws.retain(|o| !Arc::ptr_eq(o, &w)); - if ws.is_empty() { - return Err(ProxyError::Proxy("All ME connections dead".into())); - } - } - } - } - } - - pub async fn send_close(&self, conn_id: u64) -> Result<()> { - let ws = self.writers.read().await; - if !ws.is_empty() { - let w = ws[0].clone(); - drop(ws); - let mut p = Vec::with_capacity(12); - p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); - p.extend_from_slice(&conn_id.to_le_bytes()); - if let Err(e) = w.lock().await.send(&p).await { - debug!(error = %e, "ME close write failed"); - let mut ws = self.writers.write().await; - ws.retain(|o| !Arc::ptr_eq(o, &w)); - } - } - self.registry.unregister(conn_id).await; - Ok(()) - } - - pub fn connection_count(&self) -> usize { - self.writers.try_read().map(|w| w.len()).unwrap_or(0) - } -} - -// ========== Reader Loop ========== - -async fn reader_loop( - mut rd: tokio::io::ReadHalf, - dk: [u8; 32], - mut div: [u8; 16], - reg: Arc, - mut enc_leftover: BytesMut, - mut dec: BytesMut, - writer: Arc>, -) -> Result<()> { - let mut raw = enc_leftover; - loop { - let mut tmp = [0u8; 16384]; - let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?; - if n == 0 { return Ok(()); } - raw.extend_from_slice(&tmp[..n]); - - // Decrypt complete 16-byte blocks - let blocks = raw.len() / 16 * 16; - if blocks > 0 { - let mut new_iv = [0u8; 16]; - new_iv.copy_from_slice(&raw[blocks - 16..blocks]); - let mut chunk = vec![0u8; blocks]; - chunk.copy_from_slice(&raw[..blocks]); - AesCbc::new(dk, div) - .decrypt_in_place(&mut chunk) - .map_err(|e| ProxyError::Crypto(format!("{}", e)))?; - div = new_iv; - dec.extend_from_slice(&chunk); - let _ = raw.split_to(blocks); - } - - // Parse RPC frames - while dec.len() >= 12 { - let fl = u32::from_le_bytes([dec[0], dec[1], dec[2], dec[3]]) as usize; - if fl == 4 { let _ = dec.split_to(4); continue; } - if fl < 12 || fl > (1 << 24) { - warn!(frame_len = fl, "Invalid RPC frame len"); - dec.clear(); - break; - } - if dec.len() < fl { break; } - - let frame = dec.split_to(fl); - let pe = fl - 4; - let ec = u32::from_le_bytes([frame[pe], frame[pe+1], frame[pe+2], frame[pe+3]]); - if crc32(&frame[..pe]) != ec { - warn!("CRC mismatch in data frame"); - continue; - } - - let payload = &frame[8..pe]; - if payload.len() < 4 { continue; } - let pt = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]); - let body = &payload[4..]; - - if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 { - let flags = u32::from_le_bytes(body[0..4].try_into().unwrap()); - let cid = u64::from_le_bytes(body[4..12].try_into().unwrap()); - let data = Bytes::copy_from_slice(&body[12..]); - trace!(cid, len = data.len(), flags, "ANS"); - reg.route(cid, MeResponse::Data(data)).await; - } else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 { - let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); - let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap()); - trace!(cid, cfm, "ACK"); - reg.route(cid, MeResponse::Ack(cfm)).await; - } else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 { - let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); - debug!(cid, "CLOSE_EXT from ME"); - reg.route(cid, MeResponse::Close).await; - reg.unregister(cid).await; - } else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 { - let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); - debug!(cid, "CLOSE_CONN from ME"); - reg.route(cid, MeResponse::Close).await; - reg.unregister(cid).await; - } else if pt == RPC_PING_U32 && body.len() >= 8 { - let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); - trace!(ping_id, "RPC_PING -> PONG"); - let mut pong = Vec::with_capacity(12); - pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes()); - pong.extend_from_slice(&ping_id.to_le_bytes()); - if let Err(e) = writer.lock().await.send(&pong).await { - warn!(error = %e, "PONG send failed"); - break; - } - } else { - debug!(rpc_type = format_args!("0x{:08x}", pt), len = body.len(), "Unknown RPC"); - } - } - } -} - -// ========== Proto flags ========== - -/// Map ProtoTag to C-compatible RPC_PROXY_REQ transport flags. -/// C: RPC_F_COMPACT(0x40000000)=abridged, RPC_F_MEDIUM(0x20000000)=intermediate/secure -/// The 0x1000(magic) and 0x8(proxy_tag) are added inside build_proxy_req_payload. - -pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag) -> u32 { - use crate::protocol::constants::*; - let mut flags = RPC_FLAG_HAS_AD_TAG | RPC_FLAG_MAGIC | RPC_FLAG_EXTMODE2; - match tag { - ProtoTag::Abridged => flags | RPC_FLAG_ABRIDGED, - ProtoTag::Intermediate => flags | RPC_FLAG_INTERMEDIATE, - ProtoTag::Secure => flags | RPC_FLAG_PAD | RPC_FLAG_INTERMEDIATE, - } -} - - -// ========== Health Monitor (Phase 4) ========== - -pub async fn me_health_monitor( - pool: Arc, - rng: Arc, - min_connections: usize, -) { - loop { - tokio::time::sleep(Duration::from_secs(30)).await; - let current = pool.writers.read().await.len(); - if current < min_connections { - warn!(current, min = min_connections, "ME pool below minimum, reconnecting..."); - let addrs = TG_MIDDLE_PROXIES_FLAT_V4.clone(); - for &(ip, port) in addrs.iter() { - let needed = min_connections.saturating_sub(pool.writers.read().await.len()); - if needed == 0 { break; } - for _ in 0..needed { - let addr = SocketAddr::new(ip, port); - match pool.connect_one(addr, &rng).await { - Ok(()) => info!(%addr, "ME reconnected"), - Err(e) => debug!(%addr, error = %e, "ME reconnect failed"), - } - } - } - } - } -} diff --git a/src/transport/socket.rs b/src/transport/socket.rs index a07c21c..f440f44 100644 --- a/src/transport/socket.rs +++ b/src/transport/socket.rs @@ -202,6 +202,51 @@ pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result Result { + use std::os::unix::net::UnixListener; + use std::path::Path; + + let socket_path = Path::new(path); + + if socket_path.exists() { + match std::os::unix::net::UnixStream::connect(socket_path) { + Ok(_) => { + return Err(std::io::Error::new( + std::io::ErrorKind::AddrInUse, + format!("Unix socket {} is already in use by another process", path) + )); + } + Err(_) => { + debug!("Removing stale Unix socket: {}", path); + std::fs::remove_file(socket_path)?; + } + } + } + + let listener = UnixListener::bind(socket_path)?; + listener.set_nonblocking(true)?; + + debug!("Created Unix socket listener at {}", path); + Ok(listener) +} + +/// Remove Unix socket file on shutdown +#[cfg(unix)] +pub fn cleanup_unix_socket(path: &str) { + if std::path::Path::new(path).exists() { + match std::fs::remove_file(path) { + Ok(_) => debug!("Cleaned up Unix socket: {}", path), + Err(e) => debug!("Failed to remove Unix socket {}: {}", path, e), + } + } +} + #[cfg(test)] mod tests { use super::*;