diff --git a/Cargo.lock b/Cargo.lock index cf52770..6846aba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2793,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "telemt" -version = "3.3.32" +version = "3.3.35" dependencies = [ "aes", "anyhow", @@ -2844,6 +2844,7 @@ dependencies = [ "tokio-util", "toml", "tracing", + "tracing-appender", "tracing-subscriber", "url", "webpki-roots", @@ -3170,6 +3171,18 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-appender" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf" +dependencies = [ + "crossbeam-channel", + "thiserror 2.0.18", + "time", + "tracing-subscriber", +] + [[package]] name = "tracing-attributes" version = "0.1.31" diff --git a/Cargo.toml b/Cargo.toml index 62b3b13..a61bccf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.3.32" +version = "3.3.35" edition = "2024" [features] @@ -30,7 +30,13 @@ static_assertions = "1.1" # Network socket2 = { version = "0.6", features = ["all"] } -nix = { version = "0.31", default-features = false, features = ["net", "fs"] } +nix = { version = "0.31", default-features = false, features = [ + "net", + "user", + "process", + "fs", + "signal", +] } shadowsocks = { version = "1.24", features = ["aead-cipher-2022"] } # Serialization @@ -44,6 +50,7 @@ bytes = "1.9" thiserror = "2.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tracing-appender = "0.2" parking_lot = "0.12" dashmap = "6.1" arc-swap = "1.7" @@ -68,8 +75,14 @@ hyper = { version = "1", features = ["server", "http1"] } hyper-util = { version = "0.1", features = ["tokio", "server-auto"] } http-body-util = "0.1" httpdate = "1.0" -tokio-rustls = { version = "0.26", default-features = false, features = ["tls12"] } -rustls = { version = "0.23", default-features = false, features = ["std", "tls12", "ring"] } +tokio-rustls = { version = "0.26", default-features = false, features = [ + "tls12", +] } +rustls = { version = "0.23", default-features = false, features = [ + "std", + "tls12", + "ring", +] } webpki-roots = "1.0" [dev-dependencies] diff --git a/docs/CONFIG_PARAMS.en.md b/docs/CONFIG_PARAMS.en.md index eda2435..4d4467d 100644 --- a/docs/CONFIG_PARAMS.en.md +++ b/docs/CONFIG_PARAMS.en.md @@ -91,6 +91,7 @@ This document lists all configuration keys accepted by `config.toml`. | upstream_connect_retry_attempts | `u32` | `2` | Must be `> 0`. | Connect attempts for selected upstream before error/fallback. | | upstream_connect_retry_backoff_ms | `u64` | `100` | — | Delay between upstream connect attempts (ms). | | upstream_connect_budget_ms | `u64` | `3000` | Must be `> 0`. | Total wall-clock budget for one upstream connect request (ms). | +| tg_connect | `u64` | `10` | Must be `> 0`. | Per-attempt upstream TCP connect timeout to Telegram DC (seconds). | | upstream_unhealthy_fail_threshold | `u32` | `5` | Must be `> 0`. | Consecutive failed requests before upstream is marked unhealthy. | | upstream_connect_failfast_hard_errors | `bool` | `false` | — | Skips additional retries for hard non-transient connect errors. | | stun_iface_mismatch_ignore | `bool` | `false` | none | Reserved compatibility flag in current runtime revision. | @@ -249,7 +250,6 @@ Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers a | relay_client_idle_soft_secs | `u64` | `120` | Must be `> 0`; must be `<= relay_client_idle_hard_secs`. | Soft idle threshold for middle-relay client uplink inactivity (seconds). | | relay_client_idle_hard_secs | `u64` | `360` | Must be `> 0`; must be `>= relay_client_idle_soft_secs`. | Hard idle threshold for middle-relay client uplink inactivity (seconds). | | relay_idle_grace_after_downstream_activity_secs | `u64` | `30` | Must be `<= relay_client_idle_hard_secs`. | Extra hard-idle grace after recent downstream activity (seconds). | -| tg_connect | `u64` | `10` | — | Upstream Telegram connect timeout. | | client_keepalive | `u64` | `15` | — | Client keepalive timeout. | | client_ack | `u64` | `90` | — | Client ACK timeout. | | me_one_retry | `u8` | `12` | none | Fast reconnect attempts budget for single-endpoint DC scenarios. | diff --git a/src/api/mod.rs b/src/api/mod.rs index c1e3557..e60a375 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -37,11 +37,12 @@ mod runtime_watch; mod runtime_zero; mod users; -use config_store::{current_revision, parse_if_match}; +use config_store::{current_revision, load_config_from_disk, parse_if_match}; use events::ApiEventStore; use http_utils::{error_response, read_json, read_optional_json, success_response}; use model::{ - ApiFailure, CreateUserRequest, HealthData, PatchUserRequest, RotateSecretRequest, SummaryData, + ApiFailure, CreateUserRequest, DeleteUserResponse, HealthData, PatchUserRequest, + RotateSecretRequest, SummaryData, UserActiveIps, }; use runtime_edge::{ EdgeConnectionsCacheEntry, build_runtime_connections_summary_data, @@ -362,15 +363,33 @@ async fn handle( ); Ok(success_response(StatusCode::OK, data, revision)) } + ("GET", "/v1/stats/users/active-ips") => { + let revision = current_revision(&shared.config_path).await?; + let usernames: Vec<_> = cfg.access.users.keys().cloned().collect(); + let active_ips_map = shared.ip_tracker.get_active_ips_for_users(&usernames).await; + let mut data: Vec = active_ips_map + .into_iter() + .filter(|(_, ips)| !ips.is_empty()) + .map(|(username, active_ips)| UserActiveIps { + username, + active_ips, + }) + .collect(); + data.sort_by(|a, b| a.username.cmp(&b.username)); + Ok(success_response(StatusCode::OK, data, revision)) + } ("GET", "/v1/stats/users") | ("GET", "/v1/users") => { let revision = current_revision(&shared.config_path).await?; + let disk_cfg = load_config_from_disk(&shared.config_path).await?; + let runtime_cfg = config_rx.borrow().clone(); let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); let users = users_from_config( - &cfg, + &disk_cfg, &shared.stats, &shared.ip_tracker, detected_ip_v4, detected_ip_v6, + Some(runtime_cfg.as_ref()), ) .await; Ok(success_response(StatusCode::OK, users, revision)) @@ -389,7 +408,7 @@ async fn handle( let expected_revision = parse_if_match(req.headers()); let body = read_json::(req.into_body(), body_limit).await?; let result = create_user(body, expected_revision, &shared).await; - let (data, revision) = match result { + let (mut data, revision) = match result { Ok(ok) => ok, Err(error) => { shared @@ -398,11 +417,18 @@ async fn handle( return Err(error); } }; + let runtime_cfg = config_rx.borrow().clone(); + data.user.in_runtime = runtime_cfg.access.users.contains_key(&data.user.username); shared.runtime_events.record( "api.user.create.ok", format!("username={}", data.user.username), ); - Ok(success_response(StatusCode::CREATED, data, revision)) + let status = if data.user.in_runtime { + StatusCode::CREATED + } else { + StatusCode::ACCEPTED + }; + Ok(success_response(status, data, revision)) } _ => { if let Some(user) = path.strip_prefix("/v1/users/") @@ -411,13 +437,16 @@ async fn handle( { if method == Method::GET { let revision = current_revision(&shared.config_path).await?; + let disk_cfg = load_config_from_disk(&shared.config_path).await?; + let runtime_cfg = config_rx.borrow().clone(); let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); let users = users_from_config( - &cfg, + &disk_cfg, &shared.stats, &shared.ip_tracker, detected_ip_v4, detected_ip_v6, + Some(runtime_cfg.as_ref()), ) .await; if let Some(user_info) = @@ -445,7 +474,7 @@ async fn handle( let body = read_json::(req.into_body(), body_limit).await?; let result = patch_user(user, body, expected_revision, &shared).await; - let (data, revision) = match result { + let (mut data, revision) = match result { Ok(ok) => ok, Err(error) => { shared.runtime_events.record( @@ -455,10 +484,17 @@ async fn handle( return Err(error); } }; + let runtime_cfg = config_rx.borrow().clone(); + data.in_runtime = runtime_cfg.access.users.contains_key(&data.username); shared .runtime_events .record("api.user.patch.ok", format!("username={}", data.username)); - return Ok(success_response(StatusCode::OK, data, revision)); + let status = if data.in_runtime { + StatusCode::OK + } else { + StatusCode::ACCEPTED + }; + return Ok(success_response(status, data, revision)); } if method == Method::DELETE { if api_cfg.read_only { @@ -486,7 +522,18 @@ async fn handle( shared .runtime_events .record("api.user.delete.ok", format!("username={}", deleted_user)); - return Ok(success_response(StatusCode::OK, deleted_user, revision)); + let runtime_cfg = config_rx.borrow().clone(); + let in_runtime = runtime_cfg.access.users.contains_key(&deleted_user); + let response = DeleteUserResponse { + username: deleted_user, + in_runtime, + }; + let status = if response.in_runtime { + StatusCode::ACCEPTED + } else { + StatusCode::OK + }; + return Ok(success_response(status, response, revision)); } if method == Method::POST && let Some(base_user) = user.strip_suffix("/rotate-secret") @@ -514,7 +561,7 @@ async fn handle( &shared, ) .await; - let (data, revision) = match result { + let (mut data, revision) = match result { Ok(ok) => ok, Err(error) => { shared.runtime_events.record( @@ -524,11 +571,19 @@ async fn handle( return Err(error); } }; + let runtime_cfg = config_rx.borrow().clone(); + data.user.in_runtime = + runtime_cfg.access.users.contains_key(&data.user.username); shared.runtime_events.record( "api.user.rotate_secret.ok", format!("username={}", base_user), ); - return Ok(success_response(StatusCode::OK, data, revision)); + let status = if data.user.in_runtime { + StatusCode::OK + } else { + StatusCode::ACCEPTED + }; + return Ok(success_response(status, data, revision)); } if method == Method::POST { return Ok(error_response( diff --git a/src/api/model.rs b/src/api/model.rs index 8ae0c0b..ebc67d7 100644 --- a/src/api/model.rs +++ b/src/api/model.rs @@ -428,6 +428,7 @@ pub(super) struct UserLinks { #[derive(Serialize)] pub(super) struct UserInfo { pub(super) username: String, + pub(super) in_runtime: bool, pub(super) user_ad_tag: Option, pub(super) max_tcp_conns: Option, pub(super) expiration_rfc3339: Option, @@ -442,12 +443,24 @@ pub(super) struct UserInfo { pub(super) links: UserLinks, } +#[derive(Serialize)] +pub(super) struct UserActiveIps { + pub(super) username: String, + pub(super) active_ips: Vec, +} + #[derive(Serialize)] pub(super) struct CreateUserResponse { pub(super) user: UserInfo, pub(super) secret: String, } +#[derive(Serialize)] +pub(super) struct DeleteUserResponse { + pub(super) username: String, + pub(super) in_runtime: bool, +} + #[derive(Deserialize)] pub(super) struct CreateUserRequest { pub(super) username: String, diff --git a/src/api/runtime_zero.rs b/src/api/runtime_zero.rs index 0ed84a8..d54c50f 100644 --- a/src/api/runtime_zero.rs +++ b/src/api/runtime_zero.rs @@ -50,6 +50,7 @@ pub(super) struct RuntimeGatesData { #[derive(Serialize)] pub(super) struct EffectiveTimeoutLimits { + pub(super) client_first_byte_idle_secs: u64, pub(super) client_handshake_secs: u64, pub(super) tg_connect_secs: u64, pub(super) client_keepalive_secs: u64, @@ -99,6 +100,11 @@ pub(super) struct EffectiveUserIpPolicyLimits { pub(super) window_secs: u64, } +#[derive(Serialize)] +pub(super) struct EffectiveUserTcpPolicyLimits { + pub(super) global_each: usize, +} + #[derive(Serialize)] pub(super) struct EffectiveLimitsData { pub(super) update_every_secs: u64, @@ -108,6 +114,7 @@ pub(super) struct EffectiveLimitsData { pub(super) upstream: EffectiveUpstreamLimits, pub(super) middle_proxy: EffectiveMiddleProxyLimits, pub(super) user_ip_policy: EffectiveUserIpPolicyLimits, + pub(super) user_tcp_policy: EffectiveUserTcpPolicyLimits, } #[derive(Serialize)] @@ -227,8 +234,9 @@ pub(super) fn build_limits_effective_data(cfg: &ProxyConfig) -> EffectiveLimitsD me_reinit_every_secs: cfg.general.effective_me_reinit_every_secs(), me_pool_force_close_secs: cfg.general.effective_me_pool_force_close_secs(), timeouts: EffectiveTimeoutLimits { + client_first_byte_idle_secs: cfg.timeouts.client_first_byte_idle_secs, client_handshake_secs: cfg.timeouts.client_handshake, - tg_connect_secs: cfg.timeouts.tg_connect, + tg_connect_secs: cfg.general.tg_connect, client_keepalive_secs: cfg.timeouts.client_keepalive, client_ack_secs: cfg.timeouts.client_ack, me_one_retry: cfg.timeouts.me_one_retry, @@ -287,6 +295,9 @@ pub(super) fn build_limits_effective_data(cfg: &ProxyConfig) -> EffectiveLimitsD mode: user_max_unique_ips_mode_label(cfg.access.user_max_unique_ips_mode), window_secs: cfg.access.user_max_unique_ips_window_secs, }, + user_tcp_policy: EffectiveUserTcpPolicyLimits { + global_each: cfg.access.user_max_tcp_conns_global_each, + }, } } diff --git a/src/api/users.rs b/src/api/users.rs index 2ee8b98..5a09714 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -136,6 +136,7 @@ pub(super) async fn create_user( &shared.ip_tracker, detected_ip_v4, detected_ip_v6, + None, ) .await; let user = users @@ -143,8 +144,16 @@ pub(super) async fn create_user( .find(|entry| entry.username == body.username) .unwrap_or(UserInfo { username: body.username.clone(), + in_runtime: false, user_ad_tag: None, - max_tcp_conns: None, + max_tcp_conns: cfg + .access + .user_max_tcp_conns + .get(&body.username) + .copied() + .filter(|limit| *limit > 0) + .or((cfg.access.user_max_tcp_conns_global_each > 0) + .then_some(cfg.access.user_max_tcp_conns_global_each)), expiration_rfc3339: None, data_quota_bytes: None, max_unique_ips: updated_limit, @@ -236,6 +245,7 @@ pub(super) async fn patch_user( &shared.ip_tracker, detected_ip_v4, detected_ip_v6, + None, ) .await; let user_info = users @@ -293,6 +303,7 @@ pub(super) async fn rotate_secret( &shared.ip_tracker, detected_ip_v4, detected_ip_v6, + None, ) .await; let user_info = users @@ -365,6 +376,7 @@ pub(super) async fn users_from_config( ip_tracker: &UserIpTracker, startup_detected_ip_v4: Option, startup_detected_ip_v6: Option, + runtime_cfg: Option<&ProxyConfig>, ) -> Vec { let mut names = cfg.access.users.keys().cloned().collect::>(); names.sort(); @@ -394,8 +406,18 @@ pub(super) async fn users_from_config( tls: Vec::new(), }); users.push(UserInfo { + in_runtime: runtime_cfg + .map(|runtime| runtime.access.users.contains_key(&username)) + .unwrap_or(false), user_ad_tag: cfg.access.user_ad_tags.get(&username).cloned(), - max_tcp_conns: cfg.access.user_max_tcp_conns.get(&username).copied(), + max_tcp_conns: cfg + .access + .user_max_tcp_conns + .get(&username) + .copied() + .filter(|limit| *limit > 0) + .or((cfg.access.user_max_tcp_conns_global_each > 0) + .then_some(cfg.access.user_max_tcp_conns_global_each)), expiration_rfc3339: cfg .access .user_expirations @@ -572,3 +594,94 @@ fn resolve_tls_domains(cfg: &ProxyConfig) -> Vec<&str> { } domains } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ip_tracker::UserIpTracker; + use crate::stats::Stats; + + #[tokio::test] + async fn users_from_config_reports_effective_tcp_limit_with_global_fallback() { + let mut cfg = ProxyConfig::default(); + cfg.access.users.insert( + "alice".to_string(), + "0123456789abcdef0123456789abcdef".to_string(), + ); + cfg.access.user_max_tcp_conns_global_each = 7; + + let stats = Stats::new(); + let tracker = UserIpTracker::new(); + + let users = users_from_config(&cfg, &stats, &tracker, None, None, None).await; + let alice = users + .iter() + .find(|entry| entry.username == "alice") + .expect("alice must be present"); + assert!(!alice.in_runtime); + assert_eq!(alice.max_tcp_conns, Some(7)); + + cfg.access.user_max_tcp_conns.insert("alice".to_string(), 5); + let users = users_from_config(&cfg, &stats, &tracker, None, None, None).await; + let alice = users + .iter() + .find(|entry| entry.username == "alice") + .expect("alice must be present"); + assert!(!alice.in_runtime); + assert_eq!(alice.max_tcp_conns, Some(5)); + + cfg.access.user_max_tcp_conns.insert("alice".to_string(), 0); + let users = users_from_config(&cfg, &stats, &tracker, None, None, None).await; + let alice = users + .iter() + .find(|entry| entry.username == "alice") + .expect("alice must be present"); + assert!(!alice.in_runtime); + assert_eq!(alice.max_tcp_conns, Some(7)); + + cfg.access.user_max_tcp_conns_global_each = 0; + let users = users_from_config(&cfg, &stats, &tracker, None, None, None).await; + let alice = users + .iter() + .find(|entry| entry.username == "alice") + .expect("alice must be present"); + assert!(!alice.in_runtime); + assert_eq!(alice.max_tcp_conns, None); + } + + #[tokio::test] + async fn users_from_config_marks_runtime_membership_when_snapshot_is_provided() { + let mut disk_cfg = ProxyConfig::default(); + disk_cfg.access.users.insert( + "alice".to_string(), + "0123456789abcdef0123456789abcdef".to_string(), + ); + disk_cfg.access.users.insert( + "bob".to_string(), + "fedcba9876543210fedcba9876543210".to_string(), + ); + + let mut runtime_cfg = ProxyConfig::default(); + runtime_cfg.access.users.insert( + "alice".to_string(), + "0123456789abcdef0123456789abcdef".to_string(), + ); + + let stats = Stats::new(); + let tracker = UserIpTracker::new(); + let users = + users_from_config(&disk_cfg, &stats, &tracker, None, None, Some(&runtime_cfg)).await; + + let alice = users + .iter() + .find(|entry| entry.username == "alice") + .expect("alice must be present"); + let bob = users + .iter() + .find(|entry| entry.username == "bob") + .expect("bob must be present"); + + assert!(alice.in_runtime); + assert!(!bob.in_runtime); + } +} diff --git a/src/cli.rs b/src/cli.rs index 6dc0e2a..5a79bae 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,11 +1,270 @@ -//! CLI commands: --init (fire-and-forget setup) +//! CLI commands: --init (fire-and-forget setup), daemon options, subcommands +//! +//! Subcommands: +//! - `start [OPTIONS] [config.toml]` - Start the daemon +//! - `stop [--pid-file PATH]` - Stop a running daemon +//! - `reload [--pid-file PATH]` - Reload configuration (SIGHUP) +//! - `status [--pid-file PATH]` - Check daemon status +//! - `run [OPTIONS] [config.toml]` - Run in foreground (default behavior) use rand::RngExt; use std::fs; use std::path::{Path, PathBuf}; use std::process::Command; +#[cfg(unix)] +use crate::daemon::{self, DEFAULT_PID_FILE, DaemonOptions}; + +/// CLI subcommand to execute. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Subcommand { + /// Run the proxy (default, or explicit `run` subcommand). + Run, + /// Start as daemon (`start` subcommand). + Start, + /// Stop a running daemon (`stop` subcommand). + Stop, + /// Reload configuration (`reload` subcommand). + Reload, + /// Check daemon status (`status` subcommand). + Status, + /// Fire-and-forget setup (`--init`). + Init, +} + +/// Parsed subcommand with its options. +#[derive(Debug)] +pub struct ParsedCommand { + pub subcommand: Subcommand, + pub pid_file: PathBuf, + pub config_path: String, + #[cfg(unix)] + pub daemon_opts: DaemonOptions, + pub init_opts: Option, +} + +impl Default for ParsedCommand { + fn default() -> Self { + Self { + subcommand: Subcommand::Run, + #[cfg(unix)] + pid_file: PathBuf::from(DEFAULT_PID_FILE), + #[cfg(not(unix))] + pid_file: PathBuf::from("/var/run/telemt.pid"), + config_path: "config.toml".to_string(), + #[cfg(unix)] + daemon_opts: DaemonOptions::default(), + init_opts: None, + } + } +} + +/// Parse CLI arguments into a command structure. +pub fn parse_command(args: &[String]) -> ParsedCommand { + let mut cmd = ParsedCommand::default(); + + // Check for --init first (legacy form) + if args.iter().any(|a| a == "--init") { + cmd.subcommand = Subcommand::Init; + cmd.init_opts = parse_init_args(args); + return cmd; + } + + // Check for subcommand as first argument + if let Some(first) = args.first() { + match first.as_str() { + "start" => { + cmd.subcommand = Subcommand::Start; + #[cfg(unix)] + { + cmd.daemon_opts = parse_daemon_args(args); + // Force daemonize for start command + cmd.daemon_opts.daemonize = true; + } + } + "stop" => { + cmd.subcommand = Subcommand::Stop; + } + "reload" => { + cmd.subcommand = Subcommand::Reload; + } + "status" => { + cmd.subcommand = Subcommand::Status; + } + "run" => { + cmd.subcommand = Subcommand::Run; + #[cfg(unix)] + { + cmd.daemon_opts = parse_daemon_args(args); + } + } + _ => { + // No subcommand, default to Run + #[cfg(unix)] + { + cmd.daemon_opts = parse_daemon_args(args); + } + } + } + } + + // Parse remaining options + let mut i = 0; + while i < args.len() { + match args[i].as_str() { + // Skip subcommand names + "start" | "stop" | "reload" | "status" | "run" => {} + // PID file option (for stop/reload/status) + "--pid-file" => { + i += 1; + if i < args.len() { + cmd.pid_file = PathBuf::from(&args[i]); + #[cfg(unix)] + { + cmd.daemon_opts.pid_file = Some(cmd.pid_file.clone()); + } + } + } + s if s.starts_with("--pid-file=") => { + cmd.pid_file = PathBuf::from(s.trim_start_matches("--pid-file=")); + #[cfg(unix)] + { + cmd.daemon_opts.pid_file = Some(cmd.pid_file.clone()); + } + } + // Config path (positional, non-flag argument) + s if !s.starts_with('-') => { + cmd.config_path = s.to_string(); + } + _ => {} + } + i += 1; + } + + cmd +} + +/// Execute a subcommand that doesn't require starting the server. +/// Returns `Some(exit_code)` if the command was handled, `None` if server should start. +#[cfg(unix)] +pub fn execute_subcommand(cmd: &ParsedCommand) -> Option { + match cmd.subcommand { + Subcommand::Stop => Some(cmd_stop(&cmd.pid_file)), + Subcommand::Reload => Some(cmd_reload(&cmd.pid_file)), + Subcommand::Status => Some(cmd_status(&cmd.pid_file)), + Subcommand::Init => { + if let Some(opts) = cmd.init_opts.clone() { + match run_init(opts) { + Ok(()) => Some(0), + Err(e) => { + eprintln!("[telemt] Init failed: {}", e); + Some(1) + } + } + } else { + Some(1) + } + } + // Run and Start need the server + Subcommand::Run | Subcommand::Start => None, + } +} + +#[cfg(not(unix))] +pub fn execute_subcommand(cmd: &ParsedCommand) -> Option { + match cmd.subcommand { + Subcommand::Stop | Subcommand::Reload | Subcommand::Status => { + eprintln!("[telemt] Subcommand not supported on this platform"); + Some(1) + } + Subcommand::Init => { + if let Some(opts) = cmd.init_opts.clone() { + match run_init(opts) { + Ok(()) => Some(0), + Err(e) => { + eprintln!("[telemt] Init failed: {}", e); + Some(1) + } + } + } else { + Some(1) + } + } + Subcommand::Run | Subcommand::Start => None, + } +} + +/// Stop command: send SIGTERM to the running daemon. +#[cfg(unix)] +fn cmd_stop(pid_file: &Path) -> i32 { + use nix::sys::signal::Signal; + + println!("Stopping telemt daemon..."); + + match daemon::signal_pid_file(pid_file, Signal::SIGTERM) { + Ok(()) => { + println!("Stop signal sent successfully"); + + // Wait for process to exit (up to 10 seconds) + for _ in 0..20 { + std::thread::sleep(std::time::Duration::from_millis(500)); + if let daemon::DaemonStatus::NotRunning = daemon::check_status(pid_file) { + println!("Daemon stopped"); + return 0; + } + } + println!("Daemon may still be shutting down"); + 0 + } + Err(e) => { + eprintln!("Failed to stop daemon: {}", e); + 1 + } + } +} + +/// Reload command: send SIGHUP to trigger config reload. +#[cfg(unix)] +fn cmd_reload(pid_file: &Path) -> i32 { + use nix::sys::signal::Signal; + + println!("Reloading telemt configuration..."); + + match daemon::signal_pid_file(pid_file, Signal::SIGHUP) { + Ok(()) => { + println!("Reload signal sent successfully"); + 0 + } + Err(e) => { + eprintln!("Failed to reload daemon: {}", e); + 1 + } + } +} + +/// Status command: check if daemon is running. +#[cfg(unix)] +fn cmd_status(pid_file: &Path) -> i32 { + match daemon::check_status(pid_file) { + daemon::DaemonStatus::Running(pid) => { + println!("telemt is running (pid {})", pid); + 0 + } + daemon::DaemonStatus::Stale(pid) => { + println!("telemt is not running (stale pid file, was pid {})", pid); + // Clean up stale PID file + let _ = std::fs::remove_file(pid_file); + 1 + } + daemon::DaemonStatus::NotRunning => { + println!("telemt is not running"); + 1 + } + } +} + /// Options for the init command +#[derive(Debug, Clone)] pub struct InitOptions { pub port: u16, pub domain: String, @@ -15,6 +274,64 @@ pub struct InitOptions { pub no_start: bool, } +/// Parse daemon-related options from CLI args. +#[cfg(unix)] +pub fn parse_daemon_args(args: &[String]) -> DaemonOptions { + let mut opts = DaemonOptions::default(); + let mut i = 0; + + while i < args.len() { + match args[i].as_str() { + "--daemon" | "-d" => { + opts.daemonize = true; + } + "--foreground" | "-f" => { + opts.foreground = true; + } + "--pid-file" => { + i += 1; + if i < args.len() { + opts.pid_file = Some(PathBuf::from(&args[i])); + } + } + s if s.starts_with("--pid-file=") => { + opts.pid_file = Some(PathBuf::from(s.trim_start_matches("--pid-file="))); + } + "--run-as-user" => { + i += 1; + if i < args.len() { + opts.user = Some(args[i].clone()); + } + } + s if s.starts_with("--run-as-user=") => { + opts.user = Some(s.trim_start_matches("--run-as-user=").to_string()); + } + "--run-as-group" => { + i += 1; + if i < args.len() { + opts.group = Some(args[i].clone()); + } + } + s if s.starts_with("--run-as-group=") => { + opts.group = Some(s.trim_start_matches("--run-as-group=").to_string()); + } + "--working-dir" => { + i += 1; + if i < args.len() { + opts.working_dir = Some(PathBuf::from(&args[i])); + } + } + s if s.starts_with("--working-dir=") => { + opts.working_dir = Some(PathBuf::from(s.trim_start_matches("--working-dir="))); + } + _ => {} + } + i += 1; + } + + opts +} + impl Default for InitOptions { fn default() -> Self { Self { @@ -84,10 +401,16 @@ pub fn parse_init_args(args: &[String]) -> Option { /// Run the fire-and-forget setup. pub fn run_init(opts: InitOptions) -> Result<(), Box> { + use crate::service::{self, InitSystem, ServiceOptions}; + eprintln!("[telemt] Fire-and-forget setup"); eprintln!(); - // 1. Generate or validate secret + // 1. Detect init system + let init_system = service::detect_init_system(); + eprintln!("[+] Detected init system: {}", init_system); + + // 2. Generate or validate secret let secret = match opts.secret { Some(s) => { if s.len() != 32 || !s.chars().all(|c| c.is_ascii_hexdigit()) { @@ -104,72 +427,126 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box> { eprintln!("[+] Port: {}", opts.port); eprintln!("[+] Domain: {}", opts.domain); - // 2. Create config directory + // 3. Create config directory fs::create_dir_all(&opts.config_dir)?; let config_path = opts.config_dir.join("config.toml"); - // 3. Write config + // 4. Write config let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain); fs::write(&config_path, &config_content)?; eprintln!("[+] Config written to {}", config_path.display()); - // 4. Write systemd unit + // 5. Generate and write service file let exe_path = std::env::current_exe().unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt")); - let unit_path = Path::new("/etc/systemd/system/telemt.service"); - let unit_content = generate_systemd_unit(&exe_path, &config_path); + let service_opts = ServiceOptions { + exe_path: &exe_path, + config_path: &config_path, + user: None, // Let systemd/init handle user + group: None, + pid_file: "/var/run/telemt.pid", + working_dir: Some("/var/lib/telemt"), + description: "Telemt MTProxy - Telegram MTProto Proxy", + }; - match fs::write(unit_path, &unit_content) { + let service_path = service::service_file_path(init_system); + let service_content = service::generate_service_file(init_system, &service_opts); + + // Ensure parent directory exists + if let Some(parent) = Path::new(service_path).parent() { + let _ = fs::create_dir_all(parent); + } + + match fs::write(service_path, &service_content) { Ok(()) => { - eprintln!("[+] Systemd unit written to {}", unit_path.display()); + eprintln!("[+] Service file written to {}", service_path); + + // Make script executable for OpenRC/FreeBSD + #[cfg(unix)] + if init_system == InitSystem::OpenRC || init_system == InitSystem::FreeBSDRc { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(service_path)?.permissions(); + perms.set_mode(0o755); + fs::set_permissions(service_path, perms)?; + } } Err(e) => { - eprintln!("[!] Cannot write systemd unit (run as root?): {}", e); - eprintln!("[!] Manual unit file content:"); - eprintln!("{}", unit_content); + eprintln!("[!] Cannot write service file (run as root?): {}", e); + eprintln!("[!] Manual service file content:"); + eprintln!("{}", service_content); - // Still print links and config + // Still print links and installation instructions + eprintln!(); + eprintln!("{}", service::installation_instructions(init_system)); print_links(&opts.username, &secret, opts.port, &opts.domain); return Ok(()); } } - // 5. Reload systemd - run_cmd("systemctl", &["daemon-reload"]); + // 6. Install and enable service based on init system + match init_system { + InitSystem::Systemd => { + run_cmd("systemctl", &["daemon-reload"]); + run_cmd("systemctl", &["enable", "telemt.service"]); + eprintln!("[+] Service enabled"); - // 6. Enable service - run_cmd("systemctl", &["enable", "telemt.service"]); - eprintln!("[+] Service enabled"); + if !opts.no_start { + run_cmd("systemctl", &["start", "telemt.service"]); + eprintln!("[+] Service started"); - // 7. Start service (unless --no-start) - if !opts.no_start { - run_cmd("systemctl", &["start", "telemt.service"]); - eprintln!("[+] Service started"); + std::thread::sleep(std::time::Duration::from_secs(1)); + let status = Command::new("systemctl") + .args(["is-active", "telemt.service"]) + .output(); - // Brief delay then check status - std::thread::sleep(std::time::Duration::from_secs(1)); - let status = Command::new("systemctl") - .args(["is-active", "telemt.service"]) - .output(); - - match status { - Ok(out) if out.status.success() => { - eprintln!("[+] Service is running"); - } - _ => { - eprintln!("[!] Service may not have started correctly"); - eprintln!("[!] Check: journalctl -u telemt.service -n 20"); + match status { + Ok(out) if out.status.success() => { + eprintln!("[+] Service is running"); + } + _ => { + eprintln!("[!] Service may not have started correctly"); + eprintln!("[!] Check: journalctl -u telemt.service -n 20"); + } + } + } else { + eprintln!("[+] Service not started (--no-start)"); + eprintln!("[+] Start manually: systemctl start telemt.service"); } } - } else { - eprintln!("[+] Service not started (--no-start)"); - eprintln!("[+] Start manually: systemctl start telemt.service"); + InitSystem::OpenRC => { + run_cmd("rc-update", &["add", "telemt", "default"]); + eprintln!("[+] Service enabled"); + + if !opts.no_start { + run_cmd("rc-service", &["telemt", "start"]); + eprintln!("[+] Service started"); + } else { + eprintln!("[+] Service not started (--no-start)"); + eprintln!("[+] Start manually: rc-service telemt start"); + } + } + InitSystem::FreeBSDRc => { + run_cmd("sysrc", &["telemt_enable=YES"]); + eprintln!("[+] Service enabled"); + + if !opts.no_start { + run_cmd("service", &["telemt", "start"]); + eprintln!("[+] Service started"); + } else { + eprintln!("[+] Service not started (--no-start)"); + eprintln!("[+] Start manually: service telemt start"); + } + } + InitSystem::Unknown => { + eprintln!("[!] Unknown init system - service file written but not installed"); + eprintln!("[!] You may need to install it manually"); + } } eprintln!(); - // 8. Print links + // 7. Print links print_links(&opts.username, &secret, opts.port, &opts.domain); Ok(()) @@ -207,6 +584,7 @@ me_pool_drain_soft_evict_cooldown_ms = 1000 me_bind_stale_mode = "never" me_pool_min_fresh_ratio = 0.8 me_reinit_drain_timeout_secs = 90 +tg_connect = 10 [network] ipv4 = true @@ -232,8 +610,8 @@ ip = "0.0.0.0" ip = "::" [timeouts] -client_handshake = 15 -tg_connect = 10 +client_first_byte_idle_secs = 300 +client_handshake = 60 client_keepalive = 60 client_ack = 300 @@ -245,6 +623,7 @@ fake_cert_len = 2048 tls_full_cert_ttl_secs = 90 [access] +user_max_tcp_conns_global_each = 0 replay_check_len = 65536 replay_window_secs = 120 ignore_time_skew = false @@ -264,35 +643,6 @@ weight = 10 ) } -fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String { - format!( - r#"[Unit] -Description=Telemt MTProxy -Documentation=https://github.com/telemt/telemt -After=network-online.target -Wants=network-online.target - -[Service] -Type=simple -ExecStart={exe} {config} -Restart=always -RestartSec=5 -LimitNOFILE=65535 -# Security hardening -NoNewPrivileges=true -ProtectSystem=strict -ProtectHome=true -ReadWritePaths=/etc/telemt -PrivateTmp=true - -[Install] -WantedBy=multi-user.target -"#, - exe = exe_path.display(), - config = config_path.display(), - ) -} - fn run_cmd(cmd: &str, args: &[&str]) { match Command::new(cmd).args(args).output() { Ok(output) => { diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 608e1b8..89e72bb 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -110,7 +110,11 @@ pub(crate) fn default_replay_window_secs() -> u64 { } pub(crate) fn default_handshake_timeout() -> u64 { - 30 + 60 +} + +pub(crate) fn default_client_first_byte_idle_secs() -> u64 { + 300 } pub(crate) fn default_relay_idle_policy_v2_enabled() -> bool { @@ -209,6 +213,10 @@ pub(crate) fn default_server_max_connections() -> u32 { 10_000 } +pub(crate) fn default_listen_backlog() -> u32 { + 1024 +} + pub(crate) fn default_accept_permit_timeout_ms() -> u64 { DEFAULT_ACCEPT_PERMIT_TIMEOUT_MS } @@ -803,6 +811,10 @@ pub(crate) fn default_user_max_unique_ips_window_secs() -> u64 { DEFAULT_USER_MAX_UNIQUE_IPS_WINDOW_SECS } +pub(crate) fn default_user_max_tcp_conns_global_each() -> usize { + 0 +} + pub(crate) fn default_user_max_unique_ips_global_each() -> usize { 0 } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 9bd2927..fa42c55 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -117,6 +117,7 @@ pub struct HotFields { pub users: std::collections::HashMap, pub user_ad_tags: std::collections::HashMap, pub user_max_tcp_conns: std::collections::HashMap, + pub user_max_tcp_conns_global_each: usize, pub user_expirations: std::collections::HashMap>, pub user_data_quota: std::collections::HashMap, pub user_max_unique_ips: std::collections::HashMap, @@ -240,6 +241,7 @@ impl HotFields { users: cfg.access.users.clone(), user_ad_tags: cfg.access.user_ad_tags.clone(), user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(), + user_max_tcp_conns_global_each: cfg.access.user_max_tcp_conns_global_each, user_expirations: cfg.access.user_expirations.clone(), user_data_quota: cfg.access.user_data_quota.clone(), user_max_unique_ips: cfg.access.user_max_unique_ips.clone(), @@ -530,6 +532,7 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { cfg.access.users = new.access.users.clone(); cfg.access.user_ad_tags = new.access.user_ad_tags.clone(); cfg.access.user_max_tcp_conns = new.access.user_max_tcp_conns.clone(); + cfg.access.user_max_tcp_conns_global_each = new.access.user_max_tcp_conns_global_each; cfg.access.user_expirations = new.access.user_expirations.clone(); cfg.access.user_data_quota = new.access.user_data_quota.clone(); cfg.access.user_max_unique_ips = new.access.user_max_unique_ips.clone(); @@ -570,6 +573,7 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b } if old.server.proxy_protocol != new.server.proxy_protocol || !listeners_equal(&old.server.listeners, &new.server.listeners) + || old.server.listen_backlog != new.server.listen_backlog || old.server.listen_addr_ipv4 != new.server.listen_addr_ipv4 || old.server.listen_addr_ipv6 != new.server.listen_addr_ipv6 || old.server.listen_tcp != new.server.listen_tcp @@ -695,6 +699,7 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b if old.general.upstream_connect_retry_attempts != new.general.upstream_connect_retry_attempts || old.general.upstream_connect_retry_backoff_ms != new.general.upstream_connect_retry_backoff_ms + || old.general.tg_connect != new.general.tg_connect || old.general.upstream_unhealthy_fail_threshold != new.general.upstream_unhealthy_fail_threshold || old.general.upstream_connect_failfast_hard_errors @@ -1143,6 +1148,12 @@ fn log_changes( new_hot.user_max_tcp_conns.len() ); } + if old_hot.user_max_tcp_conns_global_each != new_hot.user_max_tcp_conns_global_each { + info!( + "config reload: user_max_tcp_conns policy global_each={}", + new_hot.user_max_tcp_conns_global_each + ); + } if old_hot.user_expirations != new_hot.user_expirations { info!( "config reload: user_expirations updated ({} entries)", diff --git a/src/config/load.rs b/src/config/load.rs index 7892e2c..cc95f34 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -346,6 +346,12 @@ impl ProxyConfig { )); } + if config.general.tg_connect == 0 { + return Err(ProxyError::Config( + "general.tg_connect must be > 0".to_string(), + )); + } + if config.general.upstream_unhealthy_fail_threshold == 0 { return Err(ProxyError::Config( "general.upstream_unhealthy_fail_threshold must be > 0".to_string(), @@ -1322,6 +1328,10 @@ mod tests { default_api_runtime_edge_events_capacity() ); assert_eq!(cfg.access.users, default_access_users()); + assert_eq!( + cfg.access.user_max_tcp_conns_global_each, + default_user_max_tcp_conns_global_each() + ); assert_eq!( cfg.access.user_max_unique_ips_mode, UserMaxUniqueIpsMode::default() @@ -1465,6 +1475,10 @@ mod tests { let access = AccessConfig::default(); assert_eq!(access.users, default_access_users()); + assert_eq!( + access.user_max_tcp_conns_global_each, + default_user_max_tcp_conns_global_each() + ); } #[test] @@ -1907,6 +1921,26 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn tg_connect_zero_is_rejected() { + let toml = r#" + [general] + tg_connect = 0 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tg_connect_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("general.tg_connect must be > 0")); + let _ = std::fs::remove_file(path); + } + #[test] fn rpc_proxy_req_every_out_of_range_is_rejected() { let toml = r#" diff --git a/src/config/tests/load_idle_policy_tests.rs b/src/config/tests/load_idle_policy_tests.rs index c6a4e86..0767e8e 100644 --- a/src/config/tests/load_idle_policy_tests.rs +++ b/src/config/tests/load_idle_policy_tests.rs @@ -17,6 +17,28 @@ fn remove_temp_config(path: &PathBuf) { let _ = fs::remove_file(path); } +#[test] +fn default_timeouts_enable_apple_compatible_handshake_profile() { + let cfg = ProxyConfig::default(); + assert_eq!(cfg.timeouts.client_first_byte_idle_secs, 300); + assert_eq!(cfg.timeouts.client_handshake, 60); +} + +#[test] +fn load_accepts_zero_first_byte_idle_timeout_as_legacy_opt_out() { + let path = write_temp_config( + r#" +[timeouts] +client_first_byte_idle_secs = 0 +"#, + ); + + let cfg = ProxyConfig::load(&path).expect("config with zero first-byte idle timeout must load"); + assert_eq!(cfg.timeouts.client_first_byte_idle_secs, 0); + + remove_temp_config(&path); +} + #[test] fn load_rejects_relay_hard_idle_smaller_than_soft_idle_with_clear_error() { let path = write_temp_config( diff --git a/src/config/types.rs b/src/config/types.rs index cb14747..41b0c2e 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -663,6 +663,10 @@ pub struct GeneralConfig { #[serde(default = "default_upstream_connect_budget_ms")] pub upstream_connect_budget_ms: u64, + /// Per-attempt TCP connect timeout to Telegram DC (seconds). + #[serde(default = "default_connect_timeout")] + pub tg_connect: u64, + /// Consecutive failed requests before upstream is marked unhealthy. #[serde(default = "default_upstream_unhealthy_fail_threshold")] pub upstream_unhealthy_fail_threshold: u32, @@ -1007,6 +1011,7 @@ impl Default for GeneralConfig { upstream_connect_retry_attempts: default_upstream_connect_retry_attempts(), upstream_connect_retry_backoff_ms: default_upstream_connect_retry_backoff_ms(), upstream_connect_budget_ms: default_upstream_connect_budget_ms(), + tg_connect: default_connect_timeout(), upstream_unhealthy_fail_threshold: default_upstream_unhealthy_fail_threshold(), upstream_connect_failfast_hard_errors: default_upstream_connect_failfast_hard_errors(), stun_iface_mismatch_ignore: false, @@ -1272,6 +1277,11 @@ pub struct ServerConfig { #[serde(default)] pub listeners: Vec, + /// TCP `listen(2)` backlog for client-facing sockets (also used for the metrics HTTP listener). + /// The effective queue is capped by the kernel (for example `somaxconn` on Linux). + #[serde(default = "default_listen_backlog")] + pub listen_backlog: u32, + /// Maximum number of concurrent client connections. /// 0 means unlimited. #[serde(default = "default_server_max_connections")] @@ -1300,6 +1310,7 @@ impl Default for ServerConfig { metrics_whitelist: default_metrics_whitelist(), api: ApiConfig::default(), listeners: Vec::new(), + listen_backlog: default_listen_backlog(), max_connections: default_server_max_connections(), accept_permit_timeout_ms: default_accept_permit_timeout_ms(), } @@ -1308,6 +1319,12 @@ impl Default for ServerConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TimeoutsConfig { + /// Maximum idle wait in seconds for the first client byte before handshake parsing starts. + /// `0` disables the separate idle phase and keeps legacy timeout behavior. + #[serde(default = "default_client_first_byte_idle_secs")] + pub client_first_byte_idle_secs: u64, + + /// Maximum active handshake duration in seconds after the first client byte is received. #[serde(default = "default_handshake_timeout")] pub client_handshake: u64, @@ -1329,9 +1346,6 @@ pub struct TimeoutsConfig { #[serde(default = "default_relay_idle_grace_after_downstream_activity_secs")] pub relay_idle_grace_after_downstream_activity_secs: u64, - #[serde(default = "default_connect_timeout")] - pub tg_connect: u64, - #[serde(default = "default_keepalive")] pub client_keepalive: u64, @@ -1350,13 +1364,13 @@ pub struct TimeoutsConfig { impl Default for TimeoutsConfig { fn default() -> Self { Self { + client_first_byte_idle_secs: default_client_first_byte_idle_secs(), client_handshake: default_handshake_timeout(), relay_idle_policy_v2_enabled: default_relay_idle_policy_v2_enabled(), relay_client_idle_soft_secs: default_relay_client_idle_soft_secs(), relay_client_idle_hard_secs: default_relay_client_idle_hard_secs(), relay_idle_grace_after_downstream_activity_secs: default_relay_idle_grace_after_downstream_activity_secs(), - tg_connect: default_connect_timeout(), client_keepalive: default_keepalive(), client_ack: default_ack_timeout(), me_one_retry: default_me_one_retry(), @@ -1619,6 +1633,12 @@ pub struct AccessConfig { #[serde(default)] pub user_max_tcp_conns: HashMap, + /// Global per-user TCP connection limit applied when a user has no + /// positive individual override. + /// `0` disables the inherited limit. + #[serde(default = "default_user_max_tcp_conns_global_each")] + pub user_max_tcp_conns_global_each: usize, + #[serde(default)] pub user_expirations: HashMap>, @@ -1655,6 +1675,7 @@ impl Default for AccessConfig { users: default_access_users(), user_ad_tags: HashMap::new(), user_max_tcp_conns: HashMap::new(), + user_max_tcp_conns_global_each: default_user_max_tcp_conns_global_each(), user_expirations: HashMap::new(), user_data_quota: HashMap::new(), user_max_unique_ips: HashMap::new(), diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs new file mode 100644 index 0000000..8e2481e --- /dev/null +++ b/src/daemon/mod.rs @@ -0,0 +1,541 @@ +//! Unix daemon support for telemt. +//! +//! Provides classic Unix daemonization (double-fork), PID file management, +//! and privilege dropping for running telemt as a background service. + +use std::fs::{self, File, OpenOptions}; +use std::io::{self, Read, Write}; +use std::os::unix::fs::OpenOptionsExt; +use std::path::{Path, PathBuf}; + +use nix::fcntl::{Flock, FlockArg}; +use nix::unistd::{self, ForkResult, Gid, Pid, Uid, chdir, close, fork, getpid, setsid}; +use tracing::{debug, info, warn}; + +/// Default PID file location. +pub const DEFAULT_PID_FILE: &str = "/var/run/telemt.pid"; + +/// Daemon configuration options parsed from CLI. +#[derive(Debug, Clone, Default)] +pub struct DaemonOptions { + /// Run as daemon (fork to background). + pub daemonize: bool, + /// Path to PID file. + pub pid_file: Option, + /// User to run as after binding sockets. + pub user: Option, + /// Group to run as after binding sockets. + pub group: Option, + /// Working directory for the daemon. + pub working_dir: Option, + /// Explicit foreground mode (for systemd Type=simple). + pub foreground: bool, +} + +impl DaemonOptions { + /// Returns the effective PID file path. + pub fn pid_file_path(&self) -> &Path { + self.pid_file + .as_deref() + .unwrap_or(Path::new(DEFAULT_PID_FILE)) + } + + /// Returns true if we should actually daemonize. + /// Foreground flag takes precedence. + pub fn should_daemonize(&self) -> bool { + self.daemonize && !self.foreground + } +} + +/// Error types for daemon operations. +#[derive(Debug, thiserror::Error)] +pub enum DaemonError { + #[error("fork failed: {0}")] + ForkFailed(#[source] nix::Error), + + #[error("setsid failed: {0}")] + SetsidFailed(#[source] nix::Error), + + #[error("chdir failed: {0}")] + ChdirFailed(#[source] nix::Error), + + #[error("failed to open /dev/null: {0}")] + DevNullFailed(#[source] io::Error), + + #[error("failed to redirect stdio: {0}")] + RedirectFailed(#[source] nix::Error), + + #[error("PID file error: {0}")] + PidFile(String), + + #[error("another instance is already running (pid {0})")] + AlreadyRunning(i32), + + #[error("user '{0}' not found")] + UserNotFound(String), + + #[error("group '{0}' not found")] + GroupNotFound(String), + + #[error("failed to set uid/gid: {0}")] + PrivilegeDrop(#[source] nix::Error), + + #[error("io error: {0}")] + Io(#[from] io::Error), +} + +/// Result of a successful daemonize() call. +#[derive(Debug)] +pub enum DaemonizeResult { + /// We are the parent process and should exit. + Parent, + /// We are the daemon child process and should continue. + Child, +} + +/// Performs classic Unix double-fork daemonization. +/// +/// This detaches the process from the controlling terminal: +/// 1. First fork - parent exits, child continues +/// 2. setsid() - become session leader +/// 3. Second fork - ensure we can never acquire a controlling terminal +/// 4. chdir("/") - don't hold any directory open +/// 5. Redirect stdin/stdout/stderr to /dev/null +/// +/// Returns `DaemonizeResult::Parent` in the original parent (which should exit), +/// or `DaemonizeResult::Child` in the final daemon child. +pub fn daemonize(working_dir: Option<&Path>) -> Result { + // First fork + match unsafe { fork() } { + Ok(ForkResult::Parent { .. }) => { + // Parent exits + return Ok(DaemonizeResult::Parent); + } + Ok(ForkResult::Child) => { + // Child continues + } + Err(e) => return Err(DaemonError::ForkFailed(e)), + } + + // Create new session, become session leader + setsid().map_err(DaemonError::SetsidFailed)?; + + // Second fork to ensure we can never acquire a controlling terminal + match unsafe { fork() } { + Ok(ForkResult::Parent { .. }) => { + // Intermediate parent exits + std::process::exit(0); + } + Ok(ForkResult::Child) => { + // Final daemon child continues + } + Err(e) => return Err(DaemonError::ForkFailed(e)), + } + + // Change working directory + let target_dir = working_dir.unwrap_or(Path::new("/")); + chdir(target_dir).map_err(DaemonError::ChdirFailed)?; + + // Redirect stdin, stdout, stderr to /dev/null + redirect_stdio_to_devnull()?; + + Ok(DaemonizeResult::Child) +} + +/// Redirects stdin, stdout, and stderr to /dev/null. +fn redirect_stdio_to_devnull() -> Result<(), DaemonError> { + let devnull = File::options() + .read(true) + .write(true) + .open("/dev/null") + .map_err(DaemonError::DevNullFailed)?; + + let devnull_fd = std::os::unix::io::AsRawFd::as_raw_fd(&devnull); + + // Use libc::dup2 directly for redirecting standard file descriptors + // nix 0.31's dup2 requires OwnedFd which doesn't work well with stdio fds + unsafe { + // Redirect stdin (fd 0) + if libc::dup2(devnull_fd, 0) < 0 { + return Err(DaemonError::RedirectFailed(nix::errno::Errno::last())); + } + // Redirect stdout (fd 1) + if libc::dup2(devnull_fd, 1) < 0 { + return Err(DaemonError::RedirectFailed(nix::errno::Errno::last())); + } + // Redirect stderr (fd 2) + if libc::dup2(devnull_fd, 2) < 0 { + return Err(DaemonError::RedirectFailed(nix::errno::Errno::last())); + } + } + + // Close original devnull fd if it's not one of the standard fds + if devnull_fd > 2 { + let _ = close(devnull_fd); + } + + Ok(()) +} + +/// PID file manager with flock-based locking. +pub struct PidFile { + path: PathBuf, + file: Option, + locked: bool, +} + +impl PidFile { + /// Creates a new PID file manager for the given path. + pub fn new>(path: P) -> Self { + Self { + path: path.as_ref().to_path_buf(), + file: None, + locked: false, + } + } + + /// Checks if another instance is already running. + /// + /// Returns the PID of the running instance if one exists. + pub fn check_running(&self) -> Result, DaemonError> { + if !self.path.exists() { + return Ok(None); + } + + // Try to read existing PID + let mut contents = String::new(); + File::open(&self.path) + .and_then(|mut f| f.read_to_string(&mut contents)) + .map_err(|e| { + DaemonError::PidFile(format!("cannot read {}: {}", self.path.display(), e)) + })?; + + let pid: i32 = contents + .trim() + .parse() + .map_err(|_| DaemonError::PidFile(format!("invalid PID in {}", self.path.display())))?; + + // Check if process is still running + if is_process_running(pid) { + Ok(Some(pid)) + } else { + // Stale PID file + debug!(pid, path = %self.path.display(), "Removing stale PID file"); + let _ = fs::remove_file(&self.path); + Ok(None) + } + } + + /// Acquires the PID file lock and writes the current PID. + /// + /// Fails if another instance is already running. + pub fn acquire(&mut self) -> Result<(), DaemonError> { + // Check for running instance first + if let Some(pid) = self.check_running()? { + return Err(DaemonError::AlreadyRunning(pid)); + } + + // Ensure parent directory exists + if let Some(parent) = self.path.parent() { + if !parent.exists() { + fs::create_dir_all(parent).map_err(|e| { + DaemonError::PidFile(format!( + "cannot create directory {}: {}", + parent.display(), + e + )) + })?; + } + } + + // Open/create PID file with exclusive lock + let file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .mode(0o644) + .open(&self.path) + .map_err(|e| { + DaemonError::PidFile(format!("cannot open {}: {}", self.path.display(), e)) + })?; + + // Try to acquire exclusive lock (non-blocking) + let flock = Flock::lock(file, FlockArg::LockExclusiveNonblock).map_err(|(_, errno)| { + // Check if another instance grabbed the lock + if let Some(pid) = self.check_running().ok().flatten() { + DaemonError::AlreadyRunning(pid) + } else { + DaemonError::PidFile(format!("cannot lock {}: {}", self.path.display(), errno)) + } + })?; + + // Write our PID + let pid = getpid(); + let mut file = flock + .unlock() + .map_err(|(_, errno)| DaemonError::PidFile(format!("unlock failed: {}", errno)))?; + + writeln!(file, "{}", pid).map_err(|e| { + DaemonError::PidFile(format!( + "cannot write PID to {}: {}", + self.path.display(), + e + )) + })?; + + // Re-acquire lock and keep it + let flock = Flock::lock(file, FlockArg::LockExclusiveNonblock).map_err(|(_, errno)| { + DaemonError::PidFile(format!("cannot re-lock {}: {}", self.path.display(), errno)) + })?; + + self.file = Some(flock.unlock().map_err(|(_, errno)| { + DaemonError::PidFile(format!("unlock for storage failed: {}", errno)) + })?); + self.locked = true; + + info!(pid = pid.as_raw(), path = %self.path.display(), "PID file created"); + Ok(()) + } + + /// Releases the PID file lock and removes the file. + pub fn release(&mut self) -> Result<(), DaemonError> { + if let Some(file) = self.file.take() { + drop(file); + } + self.locked = false; + + if self.path.exists() { + fs::remove_file(&self.path).map_err(|e| { + DaemonError::PidFile(format!("cannot remove {}: {}", self.path.display(), e)) + })?; + debug!(path = %self.path.display(), "PID file removed"); + } + + Ok(()) + } + + /// Returns the path to this PID file. + #[allow(dead_code)] + pub fn path(&self) -> &Path { + &self.path + } +} + +impl Drop for PidFile { + fn drop(&mut self) { + if self.locked { + if let Err(e) = self.release() { + warn!(error = %e, "Failed to clean up PID file on drop"); + } + } + } +} + +/// Checks if a process with the given PID is running. +fn is_process_running(pid: i32) -> bool { + // kill(pid, 0) checks if process exists without sending a signal + nix::sys::signal::kill(Pid::from_raw(pid), None).is_ok() +} + +/// Drops privileges to the specified user and group. +/// +/// This should be called after binding privileged ports but before +/// entering the main event loop. +pub fn drop_privileges(user: Option<&str>, group: Option<&str>) -> Result<(), DaemonError> { + // Look up group first (need to do this while still root) + let target_gid = if let Some(group_name) = group { + Some(lookup_group(group_name)?) + } else if let Some(user_name) = user { + // If no group specified but user is, use user's primary group + Some(lookup_user_primary_gid(user_name)?) + } else { + None + }; + + // Look up user + let target_uid = if let Some(user_name) = user { + Some(lookup_user(user_name)?) + } else { + None + }; + + // Drop privileges: set GID first, then UID + // (Setting UID first would prevent us from setting GID) + if let Some(gid) = target_gid { + unistd::setgid(gid).map_err(DaemonError::PrivilegeDrop)?; + // Also set supplementary groups to just this one + unistd::setgroups(&[gid]).map_err(DaemonError::PrivilegeDrop)?; + info!(gid = gid.as_raw(), "Dropped group privileges"); + } + + if let Some(uid) = target_uid { + unistd::setuid(uid).map_err(DaemonError::PrivilegeDrop)?; + info!(uid = uid.as_raw(), "Dropped user privileges"); + } + + Ok(()) +} + +/// Looks up a user by name and returns their UID. +fn lookup_user(name: &str) -> Result { + // Use libc getpwnam + let c_name = + std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?; + + unsafe { + let pwd = libc::getpwnam(c_name.as_ptr()); + if pwd.is_null() { + Err(DaemonError::UserNotFound(name.to_string())) + } else { + Ok(Uid::from_raw((*pwd).pw_uid)) + } + } +} + +/// Looks up a user's primary GID by username. +fn lookup_user_primary_gid(name: &str) -> Result { + let c_name = + std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?; + + unsafe { + let pwd = libc::getpwnam(c_name.as_ptr()); + if pwd.is_null() { + Err(DaemonError::UserNotFound(name.to_string())) + } else { + Ok(Gid::from_raw((*pwd).pw_gid)) + } + } +} + +/// Looks up a group by name and returns its GID. +fn lookup_group(name: &str) -> Result { + let c_name = + std::ffi::CString::new(name).map_err(|_| DaemonError::GroupNotFound(name.to_string()))?; + + unsafe { + let grp = libc::getgrnam(c_name.as_ptr()); + if grp.is_null() { + Err(DaemonError::GroupNotFound(name.to_string())) + } else { + Ok(Gid::from_raw((*grp).gr_gid)) + } + } +} + +/// Reads PID from a PID file. +#[allow(dead_code)] +pub fn read_pid_file>(path: P) -> Result { + let path = path.as_ref(); + let mut contents = String::new(); + File::open(path) + .and_then(|mut f| f.read_to_string(&mut contents)) + .map_err(|e| DaemonError::PidFile(format!("cannot read {}: {}", path.display(), e)))?; + + contents + .trim() + .parse() + .map_err(|_| DaemonError::PidFile(format!("invalid PID in {}", path.display()))) +} + +/// Sends a signal to the process specified in a PID file. +#[allow(dead_code)] +pub fn signal_pid_file>( + path: P, + signal: nix::sys::signal::Signal, +) -> Result<(), DaemonError> { + let pid = read_pid_file(&path)?; + + if !is_process_running(pid) { + return Err(DaemonError::PidFile(format!( + "process {} from {} is not running", + pid, + path.as_ref().display() + ))); + } + + nix::sys::signal::kill(Pid::from_raw(pid), signal) + .map_err(|e| DaemonError::PidFile(format!("cannot signal process {}: {}", pid, e)))?; + + Ok(()) +} + +/// Returns the status of the daemon based on PID file. +#[allow(dead_code)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DaemonStatus { + /// Daemon is running with the given PID. + Running(i32), + /// PID file exists but process is not running. + Stale(i32), + /// No PID file exists. + NotRunning, +} + +/// Checks the daemon status from a PID file. +#[allow(dead_code)] +pub fn check_status>(path: P) -> DaemonStatus { + let path = path.as_ref(); + + if !path.exists() { + return DaemonStatus::NotRunning; + } + + match read_pid_file(path) { + Ok(pid) => { + if is_process_running(pid) { + DaemonStatus::Running(pid) + } else { + DaemonStatus::Stale(pid) + } + } + Err(_) => DaemonStatus::NotRunning, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_daemon_options_default() { + let opts = DaemonOptions::default(); + assert!(!opts.daemonize); + assert!(!opts.should_daemonize()); + assert_eq!(opts.pid_file_path(), Path::new(DEFAULT_PID_FILE)); + } + + #[test] + fn test_daemon_options_foreground_overrides() { + let opts = DaemonOptions { + daemonize: true, + foreground: true, + ..Default::default() + }; + assert!(!opts.should_daemonize()); + } + + #[test] + fn test_check_status_not_running() { + let path = "/tmp/telemt_test_nonexistent.pid"; + assert_eq!(check_status(path), DaemonStatus::NotRunning); + } + + #[test] + fn test_pid_file_basic() { + let path = "/tmp/telemt_test_pidfile.pid"; + let _ = fs::remove_file(path); + + let mut pf = PidFile::new(path); + assert!(pf.check_running().unwrap().is_none()); + + pf.acquire().unwrap(); + assert!(Path::new(path).exists()); + + // Read it back + let pid = read_pid_file(path).unwrap(); + assert_eq!(pid, std::process::id() as i32); + + pf.release().unwrap(); + assert!(!Path::new(path).exists()); + } +} diff --git a/src/logging.rs b/src/logging.rs new file mode 100644 index 0000000..bb381ef --- /dev/null +++ b/src/logging.rs @@ -0,0 +1,305 @@ +//! Logging configuration for telemt. +//! +//! Supports multiple log destinations: +//! - stderr (default, works with systemd journald) +//! - syslog (Unix only, for traditional init systems) +//! - file (with optional rotation) + +#![allow(dead_code)] // Infrastructure module - used via CLI flags + +use std::path::Path; + +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{EnvFilter, fmt, reload}; + +/// Log destination configuration. +#[derive(Debug, Clone, Default)] +pub enum LogDestination { + /// Log to stderr (default, captured by systemd journald). + #[default] + Stderr, + /// Log to syslog (Unix only). + #[cfg(unix)] + Syslog, + /// Log to a file with optional rotation. + File { + path: String, + /// Rotate daily if true. + rotate_daily: bool, + }, +} + +/// Logging options parsed from CLI/config. +#[derive(Debug, Clone, Default)] +pub struct LoggingOptions { + /// Where to send logs. + pub destination: LogDestination, + /// Disable ANSI colors. + pub disable_colors: bool, +} + +/// Guard that must be held to keep file logging active. +/// When dropped, flushes and closes log files. +pub struct LoggingGuard { + _guard: Option, +} + +impl LoggingGuard { + fn new(guard: Option) -> Self { + Self { _guard: guard } + } + + /// Creates a no-op guard for stderr/syslog logging. + pub fn noop() -> Self { + Self { _guard: None } + } +} + +/// Initialize the tracing subscriber with the specified options. +/// +/// Returns a reload handle for dynamic log level changes and a guard +/// that must be kept alive for file logging. +pub fn init_logging( + opts: &LoggingOptions, + initial_filter: &str, +) -> ( + reload::Handle, + LoggingGuard, +) { + let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new(initial_filter)); + + match &opts.destination { + LogDestination::Stderr => { + let fmt_layer = fmt::Layer::default() + .with_ansi(!opts.disable_colors) + .with_target(true); + + tracing_subscriber::registry() + .with(filter_layer) + .with(fmt_layer) + .init(); + + (filter_handle, LoggingGuard::noop()) + } + + #[cfg(unix)] + LogDestination::Syslog => { + // Use a custom fmt layer that writes to syslog + let fmt_layer = fmt::Layer::default() + .with_ansi(false) + .with_target(true) + .with_writer(SyslogWriter::new); + + tracing_subscriber::registry() + .with(filter_layer) + .with(fmt_layer) + .init(); + + (filter_handle, LoggingGuard::noop()) + } + + LogDestination::File { path, rotate_daily } => { + let (non_blocking, guard) = if *rotate_daily { + // Extract directory and filename prefix + let path = Path::new(path); + let dir = path.parent().unwrap_or(Path::new("/var/log")); + let prefix = path + .file_name() + .and_then(|s| s.to_str()) + .unwrap_or("telemt"); + + let file_appender = tracing_appender::rolling::daily(dir, prefix); + tracing_appender::non_blocking(file_appender) + } else { + let file = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(path) + .expect("Failed to open log file"); + tracing_appender::non_blocking(file) + }; + + let fmt_layer = fmt::Layer::default() + .with_ansi(false) + .with_target(true) + .with_writer(non_blocking); + + tracing_subscriber::registry() + .with(filter_layer) + .with(fmt_layer) + .init(); + + (filter_handle, LoggingGuard::new(Some(guard))) + } + } +} + +/// Syslog writer for tracing. +#[cfg(unix)] +struct SyslogWriter { + _private: (), +} + +#[cfg(unix)] +impl SyslogWriter { + fn new() -> Self { + // Open syslog connection on first use + static INIT: std::sync::Once = std::sync::Once::new(); + INIT.call_once(|| { + unsafe { + // Open syslog with ident "telemt", LOG_PID, LOG_DAEMON facility + let ident = b"telemt\0".as_ptr() as *const libc::c_char; + libc::openlog(ident, libc::LOG_PID | libc::LOG_NDELAY, libc::LOG_DAEMON); + } + }); + Self { _private: () } + } +} + +#[cfg(unix)] +impl std::io::Write for SyslogWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + // Convert to C string, stripping newlines + let msg = String::from_utf8_lossy(buf); + let msg = msg.trim_end(); + + if msg.is_empty() { + return Ok(buf.len()); + } + + // Determine priority based on log level in the message + let priority = if msg.contains(" ERROR ") || msg.contains(" error ") { + libc::LOG_ERR + } else if msg.contains(" WARN ") || msg.contains(" warn ") { + libc::LOG_WARNING + } else if msg.contains(" INFO ") || msg.contains(" info ") { + libc::LOG_INFO + } else if msg.contains(" DEBUG ") || msg.contains(" debug ") { + libc::LOG_DEBUG + } else { + libc::LOG_INFO + }; + + // Write to syslog + let c_msg = std::ffi::CString::new(msg.as_bytes()) + .unwrap_or_else(|_| std::ffi::CString::new("(invalid utf8)").unwrap()); + + unsafe { + libc::syslog( + priority, + b"%s\0".as_ptr() as *const libc::c_char, + c_msg.as_ptr(), + ); + } + + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +#[cfg(unix)] +impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for SyslogWriter { + type Writer = SyslogWriter; + + fn make_writer(&'a self) -> Self::Writer { + SyslogWriter::new() + } +} + +/// Parse log destination from CLI arguments. +pub fn parse_log_destination(args: &[String]) -> LogDestination { + let mut i = 0; + while i < args.len() { + match args[i].as_str() { + #[cfg(unix)] + "--syslog" => { + return LogDestination::Syslog; + } + "--log-file" => { + i += 1; + if i < args.len() { + return LogDestination::File { + path: args[i].clone(), + rotate_daily: false, + }; + } + } + s if s.starts_with("--log-file=") => { + return LogDestination::File { + path: s.trim_start_matches("--log-file=").to_string(), + rotate_daily: false, + }; + } + "--log-file-daily" => { + i += 1; + if i < args.len() { + return LogDestination::File { + path: args[i].clone(), + rotate_daily: true, + }; + } + } + s if s.starts_with("--log-file-daily=") => { + return LogDestination::File { + path: s.trim_start_matches("--log-file-daily=").to_string(), + rotate_daily: true, + }; + } + _ => {} + } + i += 1; + } + LogDestination::Stderr +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_log_destination_default() { + let args: Vec = vec![]; + assert!(matches!( + parse_log_destination(&args), + LogDestination::Stderr + )); + } + + #[test] + fn test_parse_log_destination_file() { + let args = vec!["--log-file".to_string(), "/var/log/telemt.log".to_string()]; + match parse_log_destination(&args) { + LogDestination::File { path, rotate_daily } => { + assert_eq!(path, "/var/log/telemt.log"); + assert!(!rotate_daily); + } + _ => panic!("Expected File destination"), + } + } + + #[test] + fn test_parse_log_destination_file_daily() { + let args = vec!["--log-file-daily=/var/log/telemt".to_string()]; + match parse_log_destination(&args) { + LogDestination::File { path, rotate_daily } => { + assert_eq!(path, "/var/log/telemt"); + assert!(rotate_daily); + } + _ => panic!("Expected File destination"), + } + } + + #[cfg(unix)] + #[test] + fn test_parse_log_destination_syslog() { + let args = vec!["--syslog".to_string()]; + assert!(matches!( + parse_log_destination(&args), + LogDestination::Syslog + )); + } +} diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index 032460c..d9d8e8b 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -8,6 +8,7 @@ use tracing::{debug, error, info, warn}; use crate::cli; use crate::config::ProxyConfig; +use crate::logging::LogDestination; use crate::transport::UpstreamManager; use crate::transport::middle_proxy::{ ProxyConfigData, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache, @@ -27,7 +28,16 @@ pub(crate) fn resolve_runtime_config_path( absolute.canonicalize().unwrap_or(absolute) } -pub(crate) fn parse_cli() -> (String, Option, bool, Option) { +/// Parsed CLI arguments. +pub(crate) struct CliArgs { + pub config_path: String, + pub data_path: Option, + pub silent: bool, + pub log_level: Option, + pub log_destination: LogDestination, +} + +pub(crate) fn parse_cli() -> CliArgs { let mut config_path = "config.toml".to_string(); let mut data_path: Option = None; let mut silent = false; @@ -35,6 +45,9 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { let args: Vec = std::env::args().skip(1).collect(); + // Parse log destination + let log_destination = crate::logging::parse_log_destination(&args); + // Check for --init first (handled before tokio) if let Some(init_opts) = cli::parse_init_args(&args) { if let Err(e) = cli::run_init(init_opts) { @@ -74,36 +87,35 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { log_level = Some(s.trim_start_matches("--log-level=").to_string()); } "--help" | "-h" => { - eprintln!("Usage: telemt [config.toml] [OPTIONS]"); - eprintln!(); - eprintln!("Options:"); - eprintln!( - " --data-path Set data directory (absolute path; overrides config value)" - ); - eprintln!(" --silent, -s Suppress info logs"); - eprintln!(" --log-level debug|verbose|normal|silent"); - eprintln!(" --help, -h Show this help"); - eprintln!(); - eprintln!("Setup (fire-and-forget):"); - eprintln!( - " --init Generate config, install systemd service, start" - ); - eprintln!(" --port Listen port (default: 443)"); - eprintln!( - " --domain TLS domain for masking (default: www.google.com)" - ); - eprintln!( - " --secret 32-char hex secret (auto-generated if omitted)" - ); - eprintln!(" --user Username (default: user)"); - eprintln!(" --config-dir Config directory (default: /etc/telemt)"); - eprintln!(" --no-start Don't start the service after install"); + print_help(); std::process::exit(0); } "--version" | "-V" => { println!("telemt {}", env!("CARGO_PKG_VERSION")); std::process::exit(0); } + // Skip daemon-related flags (already parsed) + "--daemon" | "-d" | "--foreground" | "-f" => {} + s if s.starts_with("--pid-file") => { + if !s.contains('=') { + i += 1; // skip value + } + } + s if s.starts_with("--run-as-user") => { + if !s.contains('=') { + i += 1; + } + } + s if s.starts_with("--run-as-group") => { + if !s.contains('=') { + i += 1; + } + } + s if s.starts_with("--working-dir") => { + if !s.contains('=') { + i += 1; + } + } s if !s.starts_with('-') => { config_path = s.to_string(); } @@ -114,7 +126,73 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { i += 1; } - (config_path, data_path, silent, log_level) + CliArgs { + config_path, + data_path, + silent, + log_level, + log_destination, + } +} + +fn print_help() { + eprintln!("Usage: telemt [COMMAND] [OPTIONS] [config.toml]"); + eprintln!(); + eprintln!("Commands:"); + eprintln!(" run Run in foreground (default if no command given)"); + #[cfg(unix)] + { + eprintln!(" start Start as background daemon"); + eprintln!(" stop Stop a running daemon"); + eprintln!(" reload Reload configuration (send SIGHUP)"); + eprintln!(" status Check if daemon is running"); + } + eprintln!(); + eprintln!("Options:"); + eprintln!( + " --data-path Set data directory (absolute path; overrides config value)" + ); + eprintln!(" --silent, -s Suppress info logs"); + eprintln!(" --log-level debug|verbose|normal|silent"); + eprintln!(" --help, -h Show this help"); + eprintln!(" --version, -V Show version"); + eprintln!(); + eprintln!("Logging options:"); + eprintln!(" --log-file Log to file (default: stderr)"); + eprintln!(" --log-file-daily Log to file with daily rotation"); + #[cfg(unix)] + eprintln!(" --syslog Log to syslog (Unix only)"); + eprintln!(); + #[cfg(unix)] + { + eprintln!("Daemon options (Unix only):"); + eprintln!(" --daemon, -d Fork to background (daemonize)"); + eprintln!(" --foreground, -f Explicit foreground mode (for systemd)"); + eprintln!(" --pid-file PID file path (default: /var/run/telemt.pid)"); + eprintln!(" --run-as-user Drop privileges to this user after binding"); + eprintln!(" --run-as-group Drop privileges to this group after binding"); + eprintln!(" --working-dir Working directory for daemon mode"); + eprintln!(); + } + eprintln!("Setup (fire-and-forget):"); + eprintln!(" --init Generate config, install systemd service, start"); + eprintln!(" --port Listen port (default: 443)"); + eprintln!(" --domain TLS domain for masking (default: www.google.com)"); + eprintln!(" --secret 32-char hex secret (auto-generated if omitted)"); + eprintln!(" --user Username (default: user)"); + eprintln!(" --config-dir Config directory (default: /etc/telemt)"); + eprintln!(" --no-start Don't start the service after install"); + #[cfg(unix)] + { + eprintln!(); + eprintln!("Examples:"); + eprintln!(" telemt config.toml Run in foreground"); + eprintln!(" telemt start config.toml Start as daemon"); + eprintln!(" telemt start --pid-file /tmp/t.pid Start with custom PID file"); + eprintln!(" telemt stop Stop daemon"); + eprintln!(" telemt reload Reload configuration"); + eprintln!(" telemt status Check daemon status"); + } } #[cfg(test)] diff --git a/src/maestro/listeners.rs b/src/maestro/listeners.rs index effaff8..3b2a92f 100644 --- a/src/maestro/listeners.rs +++ b/src/maestro/listeners.rs @@ -72,6 +72,7 @@ pub(crate) async fn bind_listeners( let options = ListenOptions { reuse_port: listener_conf.reuse_allow, ipv6_only: listener_conf.ip.is_ipv6(), + backlog: config.server.listen_backlog, ..Default::default() }; diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index 5f3fd3a..aa95cb6 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -47,8 +47,55 @@ use crate::transport::UpstreamManager; use crate::transport::middle_proxy::MePool; use helpers::{parse_cli, resolve_runtime_config_path}; +#[cfg(unix)] +use crate::daemon::{DaemonOptions, PidFile, drop_privileges}; + /// Runs the full telemt runtime startup pipeline and blocks until shutdown. +/// +/// On Unix, daemon options should be handled before calling this function +/// (daemonization must happen before tokio runtime starts). +#[cfg(unix)] +pub async fn run_with_daemon( + daemon_opts: DaemonOptions, +) -> std::result::Result<(), Box> { + run_inner(daemon_opts).await +} + +/// Runs the full telemt runtime startup pipeline and blocks until shutdown. +/// +/// This is the main entry point for non-daemon mode or when called as a library. +#[allow(dead_code)] pub async fn run() -> std::result::Result<(), Box> { + #[cfg(unix)] + { + // Parse CLI to get daemon options even in simple run() path + let args: Vec = std::env::args().skip(1).collect(); + let daemon_opts = crate::cli::parse_daemon_args(&args); + run_inner(daemon_opts).await + } + #[cfg(not(unix))] + { + run_inner().await + } +} + +#[cfg(unix)] +async fn run_inner( + daemon_opts: DaemonOptions, +) -> std::result::Result<(), Box> { + // Acquire PID file if daemonizing or if explicitly requested + // Keep it alive until shutdown (underscore prefix = intentionally kept for RAII cleanup) + let _pid_file = if daemon_opts.daemonize || daemon_opts.pid_file.is_some() { + let mut pf = PidFile::new(daemon_opts.pid_file_path()); + if let Err(e) = pf.acquire() { + eprintln!("[telemt] {}", e); + std::process::exit(1); + } + Some(pf) + } else { + None + }; + let process_started_at = Instant::now(); let process_started_at_epoch_secs = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -61,7 +108,12 @@ pub async fn run() -> std::result::Result<(), Box> { Some("load and validate config".to_string()), ) .await; - let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli(); + let cli_args = parse_cli(); + let config_path_cli = cli_args.config_path; + let data_path = cli_args.data_path; + let cli_silent = cli_args.silent; + let cli_log_level = cli_args.log_level; + let log_destination = cli_args.log_destination; let startup_cwd = match std::env::current_dir() { Ok(cwd) => cwd, Err(e) => { @@ -159,17 +211,43 @@ pub async fn run() -> std::result::Result<(), Box> { ) .await; - // Configure color output based on config - let fmt_layer = if config.general.disable_colors { - fmt::Layer::default().with_ansi(false) - } else { - fmt::Layer::default().with_ansi(true) - }; + // Initialize logging based on destination + let _logging_guard: Option; + match log_destination { + crate::logging::LogDestination::Stderr => { + // Default: log to stderr (works with systemd journald) + let fmt_layer = if config.general.disable_colors { + fmt::Layer::default().with_ansi(false) + } else { + fmt::Layer::default().with_ansi(true) + }; + tracing_subscriber::registry() + .with(filter_layer) + .with(fmt_layer) + .init(); + _logging_guard = None; + } + #[cfg(unix)] + crate::logging::LogDestination::Syslog => { + // Syslog: for OpenRC/FreeBSD + let logging_opts = crate::logging::LoggingOptions { + destination: log_destination, + disable_colors: true, + }; + let (_, guard) = crate::logging::init_logging(&logging_opts, "info"); + _logging_guard = Some(guard); + } + crate::logging::LogDestination::File { .. } => { + // File logging with optional rotation + let logging_opts = crate::logging::LoggingOptions { + destination: log_destination, + disable_colors: true, + }; + let (_, guard) = crate::logging::init_logging(&logging_opts, "info"); + _logging_guard = Some(guard); + } + } - tracing_subscriber::registry() - .with(filter_layer) - .with(fmt_layer) - .init(); startup_tracker .complete_component( COMPONENT_TRACING_INIT, @@ -223,6 +301,7 @@ pub async fn run() -> std::result::Result<(), Box> { config.general.upstream_connect_retry_attempts, config.general.upstream_connect_retry_backoff_ms, config.general.upstream_connect_budget_ms, + config.general.tg_connect, config.general.upstream_unhealthy_fail_threshold, config.general.upstream_connect_failfast_hard_errors, stats.clone(), @@ -583,6 +662,14 @@ pub async fn run() -> std::result::Result<(), Box> { std::process::exit(1); } + // Drop privileges after binding sockets (which may require root for port < 1024) + if daemon_opts.user.is_some() || daemon_opts.group.is_some() { + if let Err(e) = drop_privileges(daemon_opts.user.as_deref(), daemon_opts.group.as_deref()) { + error!(error = %e, "Failed to drop privileges"); + std::process::exit(1); + } + } + runtime_tasks::apply_runtime_log_filter( has_rust_log, &effective_log_level, @@ -603,6 +690,9 @@ pub async fn run() -> std::result::Result<(), Box> { runtime_tasks::mark_runtime_ready(&startup_tracker).await; + // Spawn signal handlers for SIGUSR1/SIGUSR2 (non-shutdown signals) + shutdown::spawn_signal_handlers(stats.clone(), process_started_at); + listeners::spawn_tcp_accept_loops( listeners, config_rx.clone(), @@ -620,7 +710,7 @@ pub async fn run() -> std::result::Result<(), Box> { max_connections.clone(), ); - shutdown::wait_for_shutdown(process_started_at, me_pool).await; + shutdown::wait_for_shutdown(process_started_at, me_pool, stats).await; Ok(()) } diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index d553eb9..b8b10da 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -323,10 +323,12 @@ pub(crate) async fn spawn_metrics_if_configured( let config_rx_metrics = config_rx.clone(); let ip_tracker_metrics = ip_tracker.clone(); let whitelist = config.server.metrics_whitelist.clone(); + let listen_backlog = config.server.listen_backlog; tokio::spawn(async move { metrics::serve( port, listen, + listen_backlog, stats, beobachten, ip_tracker_metrics, diff --git a/src/maestro/shutdown.rs b/src/maestro/shutdown.rs index 243c772..f6e50ca 100644 --- a/src/maestro/shutdown.rs +++ b/src/maestro/shutdown.rs @@ -1,45 +1,206 @@ +//! Shutdown and signal handling for telemt. +//! +//! Handles graceful shutdown on various signals: +//! - SIGINT (Ctrl+C) / SIGTERM: Graceful shutdown +//! - SIGQUIT: Graceful shutdown with stats dump +//! - SIGUSR1: Reserved for log rotation (logs acknowledgment) +//! - SIGUSR2: Dump runtime status to log +//! +//! SIGHUP is handled separately in config/hot_reload.rs for config reload. + use std::sync::Arc; use std::time::{Duration, Instant}; +#[cfg(not(unix))] use tokio::signal; -use tracing::{error, info, warn}; +#[cfg(unix)] +use tokio::signal::unix::{SignalKind, signal}; +use tracing::{info, warn}; +use crate::stats::Stats; use crate::transport::middle_proxy::MePool; use super::helpers::{format_uptime, unit_label}; -pub(crate) async fn wait_for_shutdown(process_started_at: Instant, me_pool: Option>) { - match signal::ctrl_c().await { - Ok(()) => { - let shutdown_started_at = Instant::now(); - info!("Shutting down..."); - let uptime_secs = process_started_at.elapsed().as_secs(); - info!("Uptime: {}", format_uptime(uptime_secs)); - if let Some(pool) = &me_pool { - match tokio::time::timeout( - Duration::from_secs(2), - pool.shutdown_send_close_conn_all(), - ) - .await - { - Ok(total) => { - info!( - close_conn_sent = total, - "ME shutdown: RPC_CLOSE_CONN broadcast completed" - ); - } - Err(_) => { - warn!("ME shutdown: RPC_CLOSE_CONN broadcast timed out"); - } - } - } - let shutdown_secs = shutdown_started_at.elapsed().as_secs(); - info!( - "Shutdown completed successfully in {} {}.", - shutdown_secs, - unit_label(shutdown_secs, "second", "seconds") - ); +/// Signal that triggered shutdown. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ShutdownSignal { + /// SIGINT (Ctrl+C) + Interrupt, + /// SIGTERM + Terminate, + /// SIGQUIT (with stats dump) + Quit, +} + +impl std::fmt::Display for ShutdownSignal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ShutdownSignal::Interrupt => write!(f, "SIGINT"), + ShutdownSignal::Terminate => write!(f, "SIGTERM"), + ShutdownSignal::Quit => write!(f, "SIGQUIT"), } - Err(e) => error!("Signal error: {}", e), } } + +/// Waits for a shutdown signal and performs graceful shutdown. +pub(crate) async fn wait_for_shutdown( + process_started_at: Instant, + me_pool: Option>, + stats: Arc, +) { + let signal = wait_for_shutdown_signal().await; + perform_shutdown(signal, process_started_at, me_pool, &stats).await; +} + +/// Waits for any shutdown signal (SIGINT, SIGTERM, SIGQUIT). +#[cfg(unix)] +async fn wait_for_shutdown_signal() -> ShutdownSignal { + let mut sigint = signal(SignalKind::interrupt()).expect("Failed to register SIGINT handler"); + let mut sigterm = signal(SignalKind::terminate()).expect("Failed to register SIGTERM handler"); + let mut sigquit = signal(SignalKind::quit()).expect("Failed to register SIGQUIT handler"); + + tokio::select! { + _ = sigint.recv() => ShutdownSignal::Interrupt, + _ = sigterm.recv() => ShutdownSignal::Terminate, + _ = sigquit.recv() => ShutdownSignal::Quit, + } +} + +#[cfg(not(unix))] +async fn wait_for_shutdown_signal() -> ShutdownSignal { + signal::ctrl_c().await.expect("Failed to listen for Ctrl+C"); + ShutdownSignal::Interrupt +} + +/// Performs graceful shutdown sequence. +async fn perform_shutdown( + signal: ShutdownSignal, + process_started_at: Instant, + me_pool: Option>, + stats: &Stats, +) { + let shutdown_started_at = Instant::now(); + info!(signal = %signal, "Received shutdown signal"); + + // Dump stats if SIGQUIT + if signal == ShutdownSignal::Quit { + dump_stats(stats, process_started_at); + } + + info!("Shutting down..."); + let uptime_secs = process_started_at.elapsed().as_secs(); + info!("Uptime: {}", format_uptime(uptime_secs)); + + // Graceful ME pool shutdown + if let Some(pool) = &me_pool { + match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all()) + .await + { + Ok(total) => { + info!( + close_conn_sent = total, + "ME shutdown: RPC_CLOSE_CONN broadcast completed" + ); + } + Err(_) => { + warn!("ME shutdown: RPC_CLOSE_CONN broadcast timed out"); + } + } + } + + let shutdown_secs = shutdown_started_at.elapsed().as_secs(); + info!( + "Shutdown completed successfully in {} {}.", + shutdown_secs, + unit_label(shutdown_secs, "second", "seconds") + ); +} + +/// Dumps runtime statistics to the log. +fn dump_stats(stats: &Stats, process_started_at: Instant) { + let uptime_secs = process_started_at.elapsed().as_secs(); + + info!("=== Runtime Statistics Dump ==="); + info!("Uptime: {}", format_uptime(uptime_secs)); + + // Connection stats + info!( + "Connections: total={}, current={} (direct={}, me={}), bad={}", + stats.get_connects_all(), + stats.get_current_connections_total(), + stats.get_current_connections_direct(), + stats.get_current_connections_me(), + stats.get_connects_bad(), + ); + + // ME pool stats + info!( + "ME keepalive: sent={}, pong={}, failed={}, timeout={}", + stats.get_me_keepalive_sent(), + stats.get_me_keepalive_pong(), + stats.get_me_keepalive_failed(), + stats.get_me_keepalive_timeout(), + ); + + // Relay stats + info!( + "Relay idle: soft_mark={}, hard_close={}, pressure_evict={}", + stats.get_relay_idle_soft_mark_total(), + stats.get_relay_idle_hard_close_total(), + stats.get_relay_pressure_evict_total(), + ); + + info!("=== End Statistics Dump ==="); +} + +/// Spawns a background task to handle operational signals (SIGUSR1, SIGUSR2). +/// +/// These signals don't trigger shutdown but perform specific actions: +/// - SIGUSR1: Log rotation acknowledgment (for external log rotation tools) +/// - SIGUSR2: Dump runtime status to log +#[cfg(unix)] +pub(crate) fn spawn_signal_handlers(stats: Arc, process_started_at: Instant) { + tokio::spawn(async move { + let mut sigusr1 = + signal(SignalKind::user_defined1()).expect("Failed to register SIGUSR1 handler"); + let mut sigusr2 = + signal(SignalKind::user_defined2()).expect("Failed to register SIGUSR2 handler"); + + loop { + tokio::select! { + _ = sigusr1.recv() => { + handle_sigusr1(); + } + _ = sigusr2.recv() => { + handle_sigusr2(&stats, process_started_at); + } + } + } + }); +} + +/// No-op on non-Unix platforms. +#[cfg(not(unix))] +pub(crate) fn spawn_signal_handlers(_stats: Arc, _process_started_at: Instant) { + // No SIGUSR1/SIGUSR2 on non-Unix +} + +/// Handles SIGUSR1 - log rotation signal. +/// +/// This signal is typically sent by logrotate or similar tools after +/// rotating log files. Since tracing-subscriber doesn't natively support +/// reopening files, we just acknowledge the signal. If file logging is +/// added in the future, this would reopen log file handles. +#[cfg(unix)] +fn handle_sigusr1() { + info!("SIGUSR1 received - log rotation acknowledged"); + // Future: If using file-based logging, reopen file handles here +} + +/// Handles SIGUSR2 - dump runtime status. +#[cfg(unix)] +fn handle_sigusr2(stats: &Stats, process_started_at: Instant) { + info!("SIGUSR2 received - dumping runtime status"); + dump_stats(stats, process_started_at); +} diff --git a/src/main.rs b/src/main.rs index e5d931f..68c89fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,8 @@ mod api; mod cli; mod config; mod crypto; +#[cfg(unix)] +mod daemon; mod error; mod ip_tracker; #[cfg(test)] @@ -15,11 +17,13 @@ mod ip_tracker_hotpath_adversarial_tests; #[cfg(test)] #[path = "tests/ip_tracker_regression_tests.rs"] mod ip_tracker_regression_tests; +mod logging; mod maestro; mod metrics; mod network; mod protocol; mod proxy; +mod service; mod startup; mod stats; mod stream; @@ -27,8 +31,49 @@ mod tls_front; mod transport; mod util; -#[tokio::main] -async fn main() -> std::result::Result<(), Box> { +fn main() -> std::result::Result<(), Box> { + // Install rustls crypto provider early let _ = rustls::crypto::ring::default_provider().install_default(); - maestro::run().await + + let args: Vec = std::env::args().skip(1).collect(); + let cmd = cli::parse_command(&args); + + // Handle subcommands that don't need the server (stop, reload, status, init) + if let Some(exit_code) = cli::execute_subcommand(&cmd) { + std::process::exit(exit_code); + } + + #[cfg(unix)] + { + let daemon_opts = cmd.daemon_opts; + + // Daemonize BEFORE runtime + if daemon_opts.should_daemonize() { + match daemon::daemonize(daemon_opts.working_dir.as_deref()) { + Ok(daemon::DaemonizeResult::Parent) => { + std::process::exit(0); + } + Ok(daemon::DaemonizeResult::Child) => { + // continue + } + Err(e) => { + eprintln!("[telemt] Daemonization failed: {}", e); + std::process::exit(1); + } + } + } + + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()? + .block_on(maestro::run_with_daemon(daemon_opts)) + } + + #[cfg(not(unix))] + { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()? + .block_on(maestro::run()) + } } diff --git a/src/metrics.rs b/src/metrics.rs index 2c87ed6..3a88a5b 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -22,6 +22,7 @@ use crate::transport::{ListenOptions, create_listener}; pub async fn serve( port: u16, listen: Option, + listen_backlog: u32, stats: Arc, beobachten: Arc, ip_tracker: Arc, @@ -40,7 +41,7 @@ pub async fn serve( } }; let is_ipv6 = addr.is_ipv6(); - match bind_metrics_listener(addr, is_ipv6) { + match bind_metrics_listener(addr, is_ipv6, listen_backlog) { Ok(listener) => { info!("Metrics endpoint: http://{}/metrics and /beobachten", addr); serve_listener( @@ -60,7 +61,7 @@ pub async fn serve( let mut listener_v6 = None; let addr_v4 = SocketAddr::from(([0, 0, 0, 0], port)); - match bind_metrics_listener(addr_v4, false) { + match bind_metrics_listener(addr_v4, false, listen_backlog) { Ok(listener) => { info!( "Metrics endpoint: http://{}/metrics and /beobachten", @@ -74,7 +75,7 @@ pub async fn serve( } let addr_v6 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], port)); - match bind_metrics_listener(addr_v6, true) { + match bind_metrics_listener(addr_v6, true, listen_backlog) { Ok(listener) => { info!( "Metrics endpoint: http://[::]:{}/metrics and /beobachten", @@ -122,10 +123,15 @@ pub async fn serve( } } -fn bind_metrics_listener(addr: SocketAddr, ipv6_only: bool) -> std::io::Result { +fn bind_metrics_listener( + addr: SocketAddr, + ipv6_only: bool, + listen_backlog: u32, +) -> std::io::Result { let options = ListenOptions { reuse_port: false, ipv6_only, + backlog: listen_backlog, ..Default::default() }; let socket = create_listener(addr, &options)?; diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 8ce3e96..7472459 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -416,16 +416,68 @@ where debug!(peer = %real_peer, "New connection (generic stream)"); + let first_byte = if config.timeouts.client_first_byte_idle_secs == 0 { + None + } else { + let idle_timeout = Duration::from_secs(config.timeouts.client_first_byte_idle_secs); + let mut first_byte = [0u8; 1]; + match timeout(idle_timeout, stream.read(&mut first_byte)).await { + Ok(Ok(0)) => { + debug!(peer = %real_peer, "Connection closed before first client byte"); + return Ok(()); + } + Ok(Ok(_)) => Some(first_byte[0]), + Ok(Err(e)) + if matches!( + e.kind(), + std::io::ErrorKind::UnexpectedEof + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::BrokenPipe + | std::io::ErrorKind::NotConnected + ) => + { + debug!( + peer = %real_peer, + error = %e, + "Connection closed before first client byte" + ); + return Ok(()); + } + Ok(Err(e)) => { + debug!( + peer = %real_peer, + error = %e, + "Failed while waiting for first client byte" + ); + return Err(ProxyError::Io(e)); + } + Err(_) => { + debug!( + peer = %real_peer, + idle_secs = config.timeouts.client_first_byte_idle_secs, + "Closing idle pooled connection before first client byte" + ); + return Ok(()); + } + } + }; + let handshake_timeout = handshake_timeout_with_mask_grace(&config); let stats_for_timeout = stats.clone(); let config_for_timeout = config.clone(); let beobachten_for_timeout = beobachten.clone(); let peer_for_timeout = real_peer.ip(); - // Phase 1: handshake (with timeout) + // Phase 2: active handshake (with timeout after the first client byte) let outcome = match timeout(handshake_timeout, async { let mut first_bytes = [0u8; 5]; - stream.read_exact(&mut first_bytes).await?; + if let Some(first_byte) = first_byte { + first_bytes[0] = first_byte; + stream.read_exact(&mut first_bytes[1..]).await?; + } else { + stream.read_exact(&mut first_bytes).await?; + } let is_tls = tls::is_tls_handshake(&first_bytes[..3]); debug!(peer = %real_peer, is_tls = is_tls, "Handshake type detected"); @@ -736,36 +788,9 @@ impl RunningClientHandler { debug!(peer = %peer, error = %e, "Failed to configure client socket"); } - let handshake_timeout = handshake_timeout_with_mask_grace(&self.config); - let stats = self.stats.clone(); - let config_for_timeout = self.config.clone(); - let beobachten_for_timeout = self.beobachten.clone(); - let peer_for_timeout = peer.ip(); - - // Phase 1: handshake (with timeout) - let outcome = match timeout(handshake_timeout, self.do_handshake()).await { - Ok(Ok(outcome)) => outcome, - Ok(Err(e)) => { - debug!(peer = %peer, error = %e, "Handshake failed"); - record_handshake_failure_class( - &beobachten_for_timeout, - &config_for_timeout, - peer_for_timeout, - &e, - ); - return Err(e); - } - Err(_) => { - stats.increment_handshake_timeouts(); - debug!(peer = %peer, "Handshake timeout"); - record_beobachten_class( - &beobachten_for_timeout, - &config_for_timeout, - peer_for_timeout, - "other", - ); - return Err(ProxyError::TgHandshakeTimeout); - } + let outcome = match self.do_handshake().await? { + Some(outcome) => outcome, + None => return Ok(()), }; // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) @@ -774,7 +799,7 @@ impl RunningClientHandler { } } - async fn do_handshake(mut self) -> Result { + async fn do_handshake(mut self) -> Result> { let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; if self.proxy_protocol_enabled { @@ -849,19 +874,108 @@ impl RunningClientHandler { } } - let mut first_bytes = [0u8; 5]; - self.stream.read_exact(&mut first_bytes).await?; - - let is_tls = tls::is_tls_handshake(&first_bytes[..3]); - let peer = self.peer; - - debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); - - if is_tls { - self.handle_tls_client(first_bytes, local_addr).await + let first_byte = if self.config.timeouts.client_first_byte_idle_secs == 0 { + None } else { - self.handle_direct_client(first_bytes, local_addr).await - } + let idle_timeout = + Duration::from_secs(self.config.timeouts.client_first_byte_idle_secs); + let mut first_byte = [0u8; 1]; + match timeout(idle_timeout, self.stream.read(&mut first_byte)).await { + Ok(Ok(0)) => { + debug!(peer = %self.peer, "Connection closed before first client byte"); + return Ok(None); + } + Ok(Ok(_)) => Some(first_byte[0]), + Ok(Err(e)) + if matches!( + e.kind(), + std::io::ErrorKind::UnexpectedEof + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::BrokenPipe + | std::io::ErrorKind::NotConnected + ) => + { + debug!( + peer = %self.peer, + error = %e, + "Connection closed before first client byte" + ); + return Ok(None); + } + Ok(Err(e)) => { + debug!( + peer = %self.peer, + error = %e, + "Failed while waiting for first client byte" + ); + return Err(ProxyError::Io(e)); + } + Err(_) => { + debug!( + peer = %self.peer, + idle_secs = self.config.timeouts.client_first_byte_idle_secs, + "Closing idle pooled connection before first client byte" + ); + return Ok(None); + } + } + }; + + let handshake_timeout = handshake_timeout_with_mask_grace(&self.config); + let stats = self.stats.clone(); + let config_for_timeout = self.config.clone(); + let beobachten_for_timeout = self.beobachten.clone(); + let peer_for_timeout = self.peer.ip(); + let peer_for_log = self.peer; + + let outcome = match timeout(handshake_timeout, async { + let mut first_bytes = [0u8; 5]; + if let Some(first_byte) = first_byte { + first_bytes[0] = first_byte; + self.stream.read_exact(&mut first_bytes[1..]).await?; + } else { + self.stream.read_exact(&mut first_bytes).await?; + } + + let is_tls = tls::is_tls_handshake(&first_bytes[..3]); + let peer = self.peer; + + debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); + + if is_tls { + self.handle_tls_client(first_bytes, local_addr).await + } else { + self.handle_direct_client(first_bytes, local_addr).await + } + }) + .await + { + Ok(Ok(outcome)) => outcome, + Ok(Err(e)) => { + debug!(peer = %peer_for_log, error = %e, "Handshake failed"); + record_handshake_failure_class( + &beobachten_for_timeout, + &config_for_timeout, + peer_for_timeout, + &e, + ); + return Err(e); + } + Err(_) => { + stats.increment_handshake_timeouts(); + debug!(peer = %peer_for_log, "Handshake timeout"); + record_beobachten_class( + &beobachten_for_timeout, + &config_for_timeout, + peer_for_timeout, + "other", + ); + return Err(ProxyError::TgHandshakeTimeout); + } + }; + + Ok(Some(outcome)) } async fn handle_tls_client( @@ -1252,7 +1366,11 @@ impl RunningClientHandler { .access .user_max_tcp_conns .get(user) - .map(|v| *v as u64); + .copied() + .filter(|limit| *limit > 0) + .or((config.access.user_max_tcp_conns_global_each > 0) + .then_some(config.access.user_max_tcp_conns_global_each)) + .map(|v| v as u64); if !stats.try_acquire_user_curr_connects(user, limit) { return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string(), @@ -1311,7 +1429,11 @@ impl RunningClientHandler { .access .user_max_tcp_conns .get(user) - .map(|v| *v as u64); + .copied() + .filter(|limit| *limit > 0) + .or((config.access.user_max_tcp_conns_global_each > 0) + .then_some(config.access.user_max_tcp_conns_global_each)) + .map(|v| v as u64); if !stats.try_acquire_user_curr_connects(user, limit) { return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string(), diff --git a/src/proxy/tests/client_clever_advanced_tests.rs b/src/proxy/tests/client_clever_advanced_tests.rs index f462ed8..51beb24 100644 --- a/src/proxy/tests/client_clever_advanced_tests.rs +++ b/src/proxy/tests/client_clever_advanced_tests.rs @@ -94,6 +94,7 @@ async fn adversarial_tls_handshake_timeout_during_masking_delay() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -141,6 +142,7 @@ async fn blackhat_proxy_protocol_slowloris_timeout() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -193,6 +195,7 @@ async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -239,6 +242,7 @@ async fn edge_client_stream_exactly_4_bytes_eof() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -282,6 +286,7 @@ async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -328,6 +333,7 @@ async fn integration_non_tls_modes_disabled_immediately_masks() { 1, 1, 1, + 10, 1, false, stats.clone(), diff --git a/src/proxy/tests/client_deep_invariants_tests.rs b/src/proxy/tests/client_deep_invariants_tests.rs index e57f817..7a08ccc 100644 --- a/src/proxy/tests/client_deep_invariants_tests.rs +++ b/src/proxy/tests/client_deep_invariants_tests.rs @@ -47,6 +47,7 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() 1, 1, 1, + 10, 1, false, stats.clone(), @@ -177,6 +178,7 @@ async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() { 1, 1, 1, + 10, 1, false, stats.clone(), diff --git a/src/proxy/tests/client_masking_blackhat_campaign_tests.rs b/src/proxy/tests/client_masking_blackhat_campaign_tests.rs index 88d4a58..917e799 100644 --- a/src/proxy/tests/client_masking_blackhat_campaign_tests.rs +++ b/src/proxy/tests/client_masking_blackhat_campaign_tests.rs @@ -40,6 +40,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_budget_security_tests.rs b/src/proxy/tests/client_masking_budget_security_tests.rs index d98c780..332451c 100644 --- a/src/proxy/tests/client_masking_budget_security_tests.rs +++ b/src/proxy/tests/client_masking_budget_security_tests.rs @@ -36,6 +36,7 @@ fn build_harness(config: ProxyConfig) -> PipelineHarness { 1, 1, 1, + 10, 1, false, stats.clone(), diff --git a/src/proxy/tests/client_masking_diagnostics_security_tests.rs b/src/proxy/tests/client_masking_diagnostics_security_tests.rs index 0d9ca99..67b797b 100644 --- a/src/proxy/tests/client_masking_diagnostics_security_tests.rs +++ b/src/proxy/tests/client_masking_diagnostics_security_tests.rs @@ -20,6 +20,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs index d7ac4ef..8fa2689 100644 --- a/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs +++ b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs @@ -20,6 +20,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_hard_adversarial_tests.rs b/src/proxy/tests/client_masking_hard_adversarial_tests.rs index 65e66d3..c6b0e98 100644 --- a/src/proxy/tests/client_masking_hard_adversarial_tests.rs +++ b/src/proxy/tests/client_masking_hard_adversarial_tests.rs @@ -34,6 +34,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs index 3036f95..b5a8b4d 100644 --- a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs +++ b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs @@ -20,6 +20,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs index e64dc03..b3fd5cb 100644 --- a/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs +++ b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs @@ -20,6 +20,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs index b49db3c..b57ad51 100644 --- a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs +++ b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs @@ -47,6 +47,7 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { 1, 1, 1, + 10, 1, false, stats.clone(), diff --git a/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs b/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs index f7229ce..9ab5f78 100644 --- a/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs +++ b/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs @@ -25,6 +25,7 @@ fn make_test_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs index 50aa44c..2b6f600 100644 --- a/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs +++ b/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs @@ -48,6 +48,7 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> RedTeamHarness { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -237,6 +238,7 @@ async fn redteam_03_masking_duration_must_be_less_than_1ms_when_backend_down() { 1, 1, 1, + 10, 1, false, Arc::new(Stats::new()), @@ -477,6 +479,7 @@ async fn measure_invalid_probe_duration_ms(delay_ms: u64, tls_len: u16, body_sen 1, 1, 1, + 10, 1, false, Arc::new(Stats::new()), @@ -550,6 +553,7 @@ async fn capture_forwarded_probe_len(tls_len: u16, body_sent: usize) -> usize { 1, 1, 1, + 10, 1, false, Arc::new(Stats::new()), diff --git a/src/proxy/tests/client_masking_replay_timing_security_tests.rs b/src/proxy/tests/client_masking_replay_timing_security_tests.rs index c3339e8..97ed52a 100644 --- a/src/proxy/tests/client_masking_replay_timing_security_tests.rs +++ b/src/proxy/tests/client_masking_replay_timing_security_tests.rs @@ -22,6 +22,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs index 3a01a69..c4dd4db 100644 --- a/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs +++ b/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs @@ -20,6 +20,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs index 48e94a5..2cf98c4 100644 --- a/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs +++ b/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs @@ -20,6 +20,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs index f91e687..b0bf73e 100644 --- a/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs +++ b/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs @@ -20,6 +20,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_shape_hardening_security_tests.rs b/src/proxy/tests/client_masking_shape_hardening_security_tests.rs index f2bec42..7d2380b 100644 --- a/src/proxy/tests/client_masking_shape_hardening_security_tests.rs +++ b/src/proxy/tests/client_masking_shape_hardening_security_tests.rs @@ -20,6 +20,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_masking_stress_adversarial_tests.rs b/src/proxy/tests/client_masking_stress_adversarial_tests.rs index 5c00c63..1c8b599 100644 --- a/src/proxy/tests/client_masking_stress_adversarial_tests.rs +++ b/src/proxy/tests/client_masking_stress_adversarial_tests.rs @@ -34,6 +34,7 @@ fn new_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_more_advanced_tests.rs b/src/proxy/tests/client_more_advanced_tests.rs index 8f9d832..21cf96c 100644 --- a/src/proxy/tests/client_more_advanced_tests.rs +++ b/src/proxy/tests/client_more_advanced_tests.rs @@ -100,6 +100,7 @@ async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -146,6 +147,7 @@ async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -195,6 +197,7 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() { 1, 1, 1, + 10, 1, false, stats.clone(), diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 1b46c6d..d585326 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -1,8 +1,10 @@ use super::*; use crate::config::{UpstreamConfig, UpstreamType}; -use crate::crypto::AesCtr; -use crate::crypto::sha256_hmac; -use crate::protocol::constants::ProtoTag; +use crate::crypto::{AesCtr, sha256, sha256_hmac}; +use crate::protocol::constants::{ + DC_IDX_POS, HANDSHAKE_LEN, IV_LEN, PREKEY_LEN, PROTO_TAG_POS, ProtoTag, SKIP_LEN, + TLS_RECORD_CHANGE_CIPHER, +}; use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; @@ -339,6 +341,7 @@ async fn relay_task_abort_releases_user_gate_and_ip_reservation() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -452,6 +455,7 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -575,6 +579,7 @@ async fn integration_route_cutover_and_quota_overlap_fails_closed_and_releases_s 1, 1, 1, + 10, 1, false, stats.clone(), @@ -744,6 +749,7 @@ async fn proxy_protocol_header_is_rejected_when_trust_list_is_empty() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -820,6 +826,7 @@ async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load( 1, 1, 1, + 10, 1, false, stats.clone(), @@ -979,6 +986,7 @@ async fn short_tls_probe_is_masked_through_client_pipeline() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1066,6 +1074,7 @@ async fn tls12_record_probe_is_masked_through_client_pipeline() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1151,6 +1160,7 @@ async fn handle_client_stream_increments_connects_all_exactly_once() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1243,6 +1253,7 @@ async fn running_client_handler_increments_connects_all_exactly_once() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1310,6 +1321,163 @@ async fn running_client_handler_increments_connects_all_exactly_once() { ); } +#[tokio::test(start_paused = true)] +async fn idle_pooled_connection_closes_cleanly_in_generic_stream_path() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_first_byte_idle_secs = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 10, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, _client_side) = duplex(4096); + let peer: SocketAddr = "198.51.100.169:55200".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + // Let the spawned handler arm the idle-phase timeout before advancing paused time. + tokio::task::yield_now().await; + tokio::time::advance(Duration::from_secs(2)).await; + tokio::task::yield_now().await; + + let result = tokio::time::timeout(Duration::from_secs(1), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + assert_eq!(stats.get_handshake_timeouts(), 0); + assert_eq!(stats.get_connects_bad(), 0); +} + +#[tokio::test(start_paused = true)] +async fn idle_pooled_connection_closes_cleanly_in_client_handler_path() { + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_first_byte_idle_secs = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 10, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let _client = TcpStream::connect(front_addr).await.unwrap(); + + // Let the accepted connection reach the idle wait before advancing paused time. + tokio::task::yield_now().await; + tokio::time::advance(Duration::from_secs(2)).await; + tokio::task::yield_now().await; + + let result = tokio::time::timeout(Duration::from_secs(1), server_task) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + assert_eq!(stats.get_handshake_timeouts(), 0); + assert_eq!(stats.get_connects_bad(), 0); +} + #[tokio::test] async fn partial_tls_header_stall_triggers_handshake_timeout() { let mut cfg = ProxyConfig::default(); @@ -1332,6 +1500,7 @@ async fn partial_tls_header_stall_triggers_handshake_timeout() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1477,6 +1646,148 @@ fn wrap_tls_application_data(payload: &[u8]) -> Vec { record } +fn wrap_tls_ccs_record() -> Vec { + let mut record = Vec::with_capacity(6); + record.push(TLS_RECORD_CHANGE_CIPHER); + record.extend_from_slice(&[0x03, 0x03]); + record.extend_from_slice(&1u16.to_be_bytes()); + record.push(0x01); + record +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode for mtproto test helper"); + + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +#[tokio::test] +async fn fragmented_tls_mtproto_with_interleaved_ccs_is_accepted() { + let secret_hex = "55555555555555555555555555555555"; + let secret = [0x55u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + let mtproto_handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let rng = SecureRandom::new(); + + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.85:55007".parse().unwrap(); + let (read_half, write_half) = tokio::io::split(server_side); + + let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake( + &client_hello, + read_half, + write_half, + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + { + HandshakeResult::Success(result) => result, + _ => panic!("expected successful TLS handshake"), + }; + + let mut tls_response_head = [0u8; 5]; + client_side + .read_exact(&mut tls_response_head) + .await + .unwrap(); + assert_eq!(tls_response_head[0], 0x16); + let tls_response_len = + u16::from_be_bytes([tls_response_head[3], tls_response_head[4]]) as usize; + let mut tls_response_body = vec![0u8; tls_response_len]; + client_side + .read_exact(&mut tls_response_body) + .await + .unwrap(); + + client_side + .write_all(&wrap_tls_application_data(&mtproto_handshake[..13])) + .await + .unwrap(); + client_side.write_all(&wrap_tls_ccs_record()).await.unwrap(); + client_side + .write_all(&wrap_tls_application_data(&mtproto_handshake[13..37])) + .await + .unwrap(); + client_side.write_all(&wrap_tls_ccs_record()).await.unwrap(); + client_side + .write_all(&wrap_tls_application_data(&mtproto_handshake[37..])) + .await + .unwrap(); + + let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await.unwrap(); + assert_eq!(&mtproto_data[..], &mtproto_handshake); + + let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into().unwrap(); + let (_, _, success) = match handle_mtproto_handshake( + &mtproto_handshake, + tls_reader, + tls_writer, + peer, + &config, + &replay_checker, + true, + Some(tls_user.as_str()), + ) + .await + { + HandshakeResult::Success(result) => result, + _ => panic!("expected successful MTProto handshake"), + }; + + assert_eq!(success.user, "user"); + assert_eq!(success.proto_tag, ProtoTag::Secure); + assert_eq!(success.dc_idx, 2); +} + #[tokio::test] async fn valid_tls_path_does_not_fall_back_to_mask_backend() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -1514,6 +1825,7 @@ async fn valid_tls_path_does_not_fall_back_to_mask_backend() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1622,6 +1934,7 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1728,6 +2041,7 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1849,6 +2163,7 @@ async fn alpn_mismatch_tls_probe_is_masked_through_client_pipeline() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1941,6 +2256,7 @@ async fn invalid_hmac_tls_probe_is_masked_through_client_pipeline() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -2039,6 +2355,7 @@ async fn burst_invalid_tls_probes_are_masked_verbatim() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -2217,14 +2534,16 @@ async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { } #[tokio::test] -async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { +async fn zero_tcp_limit_uses_global_fallback_and_rejects_without_side_effects() { let mut config = ProxyConfig::default(); config .access .user_max_tcp_conns .insert("user".to_string(), 0); + config.access.user_max_tcp_conns_global_each = 1; let stats = Stats::new(); + stats.increment_user_curr_connects("user"); let ip_tracker = UserIpTracker::new(); let peer_addr: SocketAddr = "198.51.100.211:50001".parse().unwrap(); @@ -2241,10 +2560,75 @@ async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { result, Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" )); + assert_eq!( + stats.get_user_curr_connects("user"), + 1, + "TCP-limit rejection must keep pre-existing in-flight connection count unchanged" + ); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + +#[tokio::test] +async fn zero_tcp_limit_with_disabled_global_fallback_is_unlimited() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 0); + config.access.user_max_tcp_conns_global_each = 0; + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.212:50002".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!( + result.is_ok(), + "per-user zero with global fallback disabled must not enforce a TCP limit" + ); assert_eq!(stats.get_user_curr_connects("user"), 0); assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); } +#[tokio::test] +async fn global_tcp_fallback_applies_when_per_user_limit_is_missing() { + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns_global_each = 1; + + let stats = Stats::new(); + stats.increment_user_curr_connects("user"); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.213:50003".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + assert_eq!( + stats.get_user_curr_connects("user"), + 1, + "Global fallback TCP-limit rejection must keep pre-existing counter unchanged" + ); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + #[tokio::test] async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() { let user = "check-helper-user"; @@ -2876,6 +3260,7 @@ async fn relay_connect_error_releases_user_and_ip_before_return() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -3436,6 +3821,7 @@ async fn untrusted_proxy_header_source_is_rejected() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -3505,6 +3891,7 @@ async fn empty_proxy_trusted_cidrs_rejects_proxy_header_by_default() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -3601,6 +3988,7 @@ async fn oversized_tls_record_is_masked_in_generic_stream_pipeline() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -3703,6 +4091,7 @@ async fn oversized_tls_record_is_masked_in_client_handler_pipeline() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -3819,6 +4208,7 @@ async fn tls_record_len_min_minus_1_is_rejected_in_generic_stream_pipeline() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -3921,6 +4311,7 @@ async fn tls_record_len_min_minus_1_is_rejected_in_client_handler_pipeline() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -4026,6 +4417,7 @@ async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -4126,6 +4518,7 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() { 1, 1, 1, + 10, 1, false, stats.clone(), diff --git a/src/proxy/tests/client_timing_profile_adversarial_tests.rs b/src/proxy/tests/client_timing_profile_adversarial_tests.rs index 69a9ff4..d8df19f 100644 --- a/src/proxy/tests/client_timing_profile_adversarial_tests.rs +++ b/src/proxy/tests/client_timing_profile_adversarial_tests.rs @@ -33,6 +33,7 @@ fn make_test_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_tls_clienthello_size_security_tests.rs b/src/proxy/tests/client_tls_clienthello_size_security_tests.rs index 0c864e7..14c24b7 100644 --- a/src/proxy/tests/client_tls_clienthello_size_security_tests.rs +++ b/src/proxy/tests/client_tls_clienthello_size_security_tests.rs @@ -35,6 +35,7 @@ fn make_test_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs b/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs index 79a8640..c757999 100644 --- a/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs +++ b/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs @@ -36,6 +36,7 @@ fn make_test_upstream_manager(stats: Arc) -> Arc { 1, 1, 1, + 10, 1, false, stats, diff --git a/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs index 95e49f7..a4d5df8 100644 --- a/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs +++ b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs @@ -50,6 +50,7 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { 1, 1, 1, + 10, 1, false, stats.clone(), diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs index a731830..e139923 100644 --- a/src/proxy/tests/direct_relay_security_tests.rs +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -1302,6 +1302,7 @@ async fn direct_relay_abort_midflight_releases_route_gauge() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1408,6 +1409,7 @@ async fn direct_relay_cutover_midflight_releases_route_gauge() { 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1529,6 +1531,7 @@ async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_relea 1, 1, 1, + 10, 1, false, stats.clone(), @@ -1761,6 +1764,7 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() { 1, 100, 5000, + 10, 3, false, stats.clone(), @@ -1851,6 +1855,7 @@ async fn adversarial_direct_relay_cutover_integrity() { 1, 100, 5000, + 10, 3, false, stats.clone(), diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs index a977409..1a705ee 100644 --- a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs +++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs @@ -562,9 +562,10 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u if low_info_pair_count > 0 { let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64; let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64; + let low_info_avg_jitter_budget = 0.40 + acc_quant_step; assert!( - low_info_hardened_avg <= low_info_baseline_avg + 0.40, - "normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}" + low_info_hardened_avg <= low_info_baseline_avg + low_info_avg_jitter_budget, + "normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3} tolerated={low_info_avg_jitter_budget:.3}" ); } diff --git a/src/service/mod.rs b/src/service/mod.rs new file mode 100644 index 0000000..7a6e4f6 --- /dev/null +++ b/src/service/mod.rs @@ -0,0 +1,388 @@ +//! Service manager integration for telemt. +//! +//! Supports generating service files for: +//! - systemd (Linux) +//! - OpenRC (Alpine, Gentoo) +//! - rc.d (FreeBSD) + +use std::path::Path; + +/// Detected init/service system. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InitSystem { + /// systemd (most modern Linux distributions) + Systemd, + /// OpenRC (Alpine, Gentoo, some BSDs) + OpenRC, + /// FreeBSD rc.d + FreeBSDRc, + /// No known init system detected + Unknown, +} + +impl std::fmt::Display for InitSystem { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + InitSystem::Systemd => write!(f, "systemd"), + InitSystem::OpenRC => write!(f, "OpenRC"), + InitSystem::FreeBSDRc => write!(f, "FreeBSD rc.d"), + InitSystem::Unknown => write!(f, "unknown"), + } + } +} + +/// Detects the init system in use on the current host. +pub fn detect_init_system() -> InitSystem { + // Check for systemd first (most common on Linux) + if Path::new("/run/systemd/system").exists() { + return InitSystem::Systemd; + } + + // Check for OpenRC + if Path::new("/sbin/openrc-run").exists() || Path::new("/sbin/openrc").exists() { + return InitSystem::OpenRC; + } + + // Check for FreeBSD rc.d + if Path::new("/etc/rc.subr").exists() && Path::new("/etc/rc.d").exists() { + return InitSystem::FreeBSDRc; + } + + // Fallback: check if systemctl exists even without /run/systemd + if Path::new("/usr/bin/systemctl").exists() || Path::new("/bin/systemctl").exists() { + return InitSystem::Systemd; + } + + InitSystem::Unknown +} + +/// Returns the default service file path for the given init system. +pub fn service_file_path(init_system: InitSystem) -> &'static str { + match init_system { + InitSystem::Systemd => "/etc/systemd/system/telemt.service", + InitSystem::OpenRC => "/etc/init.d/telemt", + InitSystem::FreeBSDRc => "/usr/local/etc/rc.d/telemt", + InitSystem::Unknown => "/etc/init.d/telemt", + } +} + +/// Options for generating service files. +pub struct ServiceOptions<'a> { + /// Path to the telemt executable + pub exe_path: &'a Path, + /// Path to the configuration file + pub config_path: &'a Path, + /// User to run as (optional) + pub user: Option<&'a str>, + /// Group to run as (optional) + pub group: Option<&'a str>, + /// PID file path + pub pid_file: &'a str, + /// Working directory + pub working_dir: Option<&'a str>, + /// Description + pub description: &'a str, +} + +impl<'a> Default for ServiceOptions<'a> { + fn default() -> Self { + Self { + exe_path: Path::new("/usr/local/bin/telemt"), + config_path: Path::new("/etc/telemt/config.toml"), + user: Some("telemt"), + group: Some("telemt"), + pid_file: "/var/run/telemt.pid", + working_dir: Some("/var/lib/telemt"), + description: "Telemt MTProxy - Telegram MTProto Proxy", + } + } +} + +/// Generates a service file for the given init system. +pub fn generate_service_file(init_system: InitSystem, opts: &ServiceOptions) -> String { + match init_system { + InitSystem::Systemd => generate_systemd_unit(opts), + InitSystem::OpenRC => generate_openrc_script(opts), + InitSystem::FreeBSDRc => generate_freebsd_rc_script(opts), + InitSystem::Unknown => generate_systemd_unit(opts), // Default to systemd format + } +} + +/// Generates an enhanced systemd unit file. +fn generate_systemd_unit(opts: &ServiceOptions) -> String { + let user_line = opts.user.map(|u| format!("User={}", u)).unwrap_or_default(); + let group_line = opts + .group + .map(|g| format!("Group={}", g)) + .unwrap_or_default(); + let working_dir = opts + .working_dir + .map(|d| format!("WorkingDirectory={}", d)) + .unwrap_or_default(); + + format!( + r#"[Unit] +Description={description} +Documentation=https://github.com/telemt/telemt +After=network-online.target +Wants=network-online.target + +[Service] +Type=simple +ExecStart={exe} --foreground --pid-file {pid_file} {config} +ExecReload=/bin/kill -HUP $MAINPID +PIDFile={pid_file} +Restart=always +RestartSec=5 +{user} +{group} +{working_dir} + +# Resource limits +LimitNOFILE=65535 +LimitNPROC=4096 + +# Security hardening +NoNewPrivileges=true +ProtectSystem=strict +ProtectHome=true +PrivateTmp=true +PrivateDevices=true +ProtectKernelTunables=true +ProtectKernelModules=true +ProtectControlGroups=true +RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX +RestrictNamespaces=true +RestrictRealtime=true +RestrictSUIDSGID=true +MemoryDenyWriteExecute=true +LockPersonality=true + +# Allow binding to privileged ports and writing to specific paths +AmbientCapabilities=CAP_NET_BIND_SERVICE +CapabilityBoundingSet=CAP_NET_BIND_SERVICE +ReadWritePaths=/etc/telemt /var/run /var/lib/telemt + +[Install] +WantedBy=multi-user.target +"#, + description = opts.description, + exe = opts.exe_path.display(), + config = opts.config_path.display(), + pid_file = opts.pid_file, + user = user_line, + group = group_line, + working_dir = working_dir, + ) +} + +/// Generates an OpenRC init script. +fn generate_openrc_script(opts: &ServiceOptions) -> String { + let user = opts.user.unwrap_or("root"); + let group = opts.group.unwrap_or("root"); + + format!( + r#"#!/sbin/openrc-run +# OpenRC init script for telemt + +description="{description}" +command="{exe}" +command_args="--daemon --syslog --pid-file {pid_file} {config}" +command_user="{user}:{group}" +pidfile="{pid_file}" + +depend() {{ + need net + use logger + after firewall +}} + +start_pre() {{ + checkpath --directory --owner {user}:{group} --mode 0755 /var/run + checkpath --directory --owner {user}:{group} --mode 0755 /var/lib/telemt + checkpath --directory --owner {user}:{group} --mode 0755 /var/log/telemt +}} + +reload() {{ + ebegin "Reloading ${{RC_SVCNAME}}" + start-stop-daemon --signal HUP --pidfile "${{pidfile}}" + eend $? +}} +"#, + description = opts.description, + exe = opts.exe_path.display(), + config = opts.config_path.display(), + pid_file = opts.pid_file, + user = user, + group = group, + ) +} + +/// Generates a FreeBSD rc.d script. +fn generate_freebsd_rc_script(opts: &ServiceOptions) -> String { + let user = opts.user.unwrap_or("root"); + let group = opts.group.unwrap_or("wheel"); + + format!( + r#"#!/bin/sh +# +# PROVIDE: telemt +# REQUIRE: LOGIN NETWORKING +# KEYWORD: shutdown +# +# Add the following lines to /etc/rc.conf to enable telemt: +# +# telemt_enable="YES" +# telemt_config="/etc/telemt/config.toml" # optional +# telemt_user="telemt" # optional +# telemt_group="telemt" # optional +# + +. /etc/rc.subr + +name="telemt" +rcvar="telemt_enable" +desc="{description}" + +load_rc_config $name + +: ${{telemt_enable:="NO"}} +: ${{telemt_config:="{config}"}} +: ${{telemt_user:="{user}"}} +: ${{telemt_group:="{group}"}} +: ${{telemt_pidfile:="{pid_file}"}} + +pidfile="${{telemt_pidfile}}" +command="{exe}" +command_args="--daemon --syslog --pid-file ${{telemt_pidfile}} ${{telemt_config}}" + +start_precmd="telemt_prestart" +reload_cmd="telemt_reload" +extra_commands="reload" + +telemt_prestart() {{ + install -d -o ${{telemt_user}} -g ${{telemt_group}} -m 755 /var/run + install -d -o ${{telemt_user}} -g ${{telemt_group}} -m 755 /var/lib/telemt +}} + +telemt_reload() {{ + if [ -f "${{pidfile}}" ]; then + echo "Reloading ${{name}} configuration." + kill -HUP $(cat ${{pidfile}}) + else + echo "${{name}} is not running." + return 1 + fi +}} + +run_rc_command "$1" +"#, + description = opts.description, + exe = opts.exe_path.display(), + config = opts.config_path.display(), + pid_file = opts.pid_file, + user = user, + group = group, + ) +} + +/// Installation instructions for each init system. +pub fn installation_instructions(init_system: InitSystem) -> &'static str { + match init_system { + InitSystem::Systemd => { + r#"To install and enable the service: + sudo systemctl daemon-reload + sudo systemctl enable telemt + sudo systemctl start telemt + +To check status: + sudo systemctl status telemt + +To view logs: + journalctl -u telemt -f + +To reload configuration: + sudo systemctl reload telemt +"# + } + InitSystem::OpenRC => { + r#"To install and enable the service: + sudo chmod +x /etc/init.d/telemt + sudo rc-update add telemt default + sudo rc-service telemt start + +To check status: + sudo rc-service telemt status + +To reload configuration: + sudo rc-service telemt reload +"# + } + InitSystem::FreeBSDRc => { + r#"To install and enable the service: + sudo chmod +x /usr/local/etc/rc.d/telemt + sudo sysrc telemt_enable="YES" + sudo service telemt start + +To check status: + sudo service telemt status + +To reload configuration: + sudo service telemt reload +"# + } + InitSystem::Unknown => { + r#"No supported init system detected. +You may need to create a service file manually or run telemt directly: + telemt start /etc/telemt/config.toml +"# + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_systemd_unit_generation() { + let opts = ServiceOptions::default(); + let unit = generate_systemd_unit(&opts); + assert!(unit.contains("[Unit]")); + assert!(unit.contains("[Service]")); + assert!(unit.contains("[Install]")); + assert!(unit.contains("ExecReload=")); + assert!(unit.contains("PIDFile=")); + } + + #[test] + fn test_openrc_script_generation() { + let opts = ServiceOptions::default(); + let script = generate_openrc_script(&opts); + assert!(script.contains("#!/sbin/openrc-run")); + assert!(script.contains("depend()")); + assert!(script.contains("reload()")); + } + + #[test] + fn test_freebsd_rc_script_generation() { + let opts = ServiceOptions::default(); + let script = generate_freebsd_rc_script(&opts); + assert!(script.contains("#!/bin/sh")); + assert!(script.contains("PROVIDE: telemt")); + assert!(script.contains("run_rc_command")); + } + + #[test] + fn test_service_file_paths() { + assert_eq!( + service_file_path(InitSystem::Systemd), + "/etc/systemd/system/telemt.service" + ); + assert_eq!(service_file_path(InitSystem::OpenRC), "/etc/init.d/telemt"); + assert_eq!( + service_file_path(InitSystem::FreeBSDRc), + "/usr/local/etc/rc.d/telemt" + ); + } +} diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 1120eae..674f0f0 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -34,8 +34,6 @@ const NUM_DCS: usize = 5; /// Timeout for individual DC ping attempt const DC_PING_TIMEOUT_SECS: u64 = 5; -/// Timeout for direct TG DC TCP connect readiness. -const DIRECT_CONNECT_TIMEOUT_SECS: u64 = 10; /// Interval between upstream health-check cycles. const HEALTH_CHECK_INTERVAL_SECS: u64 = 30; /// Timeout for a single health-check connect attempt. @@ -319,6 +317,8 @@ pub struct UpstreamManager { connect_retry_attempts: u32, connect_retry_backoff: Duration, connect_budget: Duration, + /// Per-attempt TCP connect timeout to Telegram DC (`[general] tg_connect`, seconds). + tg_connect_timeout_secs: u64, unhealthy_fail_threshold: u32, connect_failfast_hard_errors: bool, no_upstreams_warn_epoch_ms: Arc, @@ -332,6 +332,7 @@ impl UpstreamManager { connect_retry_attempts: u32, connect_retry_backoff_ms: u64, connect_budget_ms: u64, + tg_connect_timeout_secs: u64, unhealthy_fail_threshold: u32, connect_failfast_hard_errors: bool, stats: Arc, @@ -347,6 +348,7 @@ impl UpstreamManager { connect_retry_attempts: connect_retry_attempts.max(1), connect_retry_backoff: Duration::from_millis(connect_retry_backoff_ms), connect_budget: Duration::from_millis(connect_budget_ms.max(1)), + tg_connect_timeout_secs: tg_connect_timeout_secs.max(1), unhealthy_fail_threshold: unhealthy_fail_threshold.max(1), connect_failfast_hard_errors, no_upstreams_warn_epoch_ms: Arc::new(AtomicU64::new(0)), @@ -798,7 +800,7 @@ impl UpstreamManager { } let remaining_budget = self.connect_budget.saturating_sub(elapsed); let attempt_timeout = - Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS).min(remaining_budget); + Duration::from_secs(self.tg_connect_timeout_secs).min(remaining_budget); if attempt_timeout.is_zero() { last_error = Some(ProxyError::ConnectionTimeout { addr: target.to_string(), @@ -1901,6 +1903,7 @@ mod tests { 1, 100, 1000, + 10, 1, false, Arc::new(Stats::new()),