diff --git a/src/api/mod.rs b/src/api/mod.rs index c0eab87..788c60c 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -370,7 +370,10 @@ async fn handle( let mut data: Vec = active_ips_map .into_iter() .filter(|(_, ips)| !ips.is_empty()) - .map(|(username, active_ips)| UserActiveIps { username, active_ips }) + .map(|(username, active_ips)| UserActiveIps { + username, + active_ips, + }) .collect(); data.sort_by(|a, b| a.username.cmp(&b.username)); Ok(success_response(StatusCode::OK, data, revision)) diff --git a/src/api/runtime_zero.rs b/src/api/runtime_zero.rs index 52f8d99..d54c50f 100644 --- a/src/api/runtime_zero.rs +++ b/src/api/runtime_zero.rs @@ -100,6 +100,11 @@ pub(super) struct EffectiveUserIpPolicyLimits { pub(super) window_secs: u64, } +#[derive(Serialize)] +pub(super) struct EffectiveUserTcpPolicyLimits { + pub(super) global_each: usize, +} + #[derive(Serialize)] pub(super) struct EffectiveLimitsData { pub(super) update_every_secs: u64, @@ -109,6 +114,7 @@ pub(super) struct EffectiveLimitsData { pub(super) upstream: EffectiveUpstreamLimits, pub(super) middle_proxy: EffectiveMiddleProxyLimits, pub(super) user_ip_policy: EffectiveUserIpPolicyLimits, + pub(super) user_tcp_policy: EffectiveUserTcpPolicyLimits, } #[derive(Serialize)] @@ -289,6 +295,9 @@ pub(super) fn build_limits_effective_data(cfg: &ProxyConfig) -> EffectiveLimitsD mode: user_max_unique_ips_mode_label(cfg.access.user_max_unique_ips_mode), window_secs: cfg.access.user_max_unique_ips_window_secs, }, + user_tcp_policy: EffectiveUserTcpPolicyLimits { + global_each: cfg.access.user_max_tcp_conns_global_each, + }, } } diff --git a/src/api/users.rs b/src/api/users.rs index 2ee8b98..0b8a471 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -144,7 +144,14 @@ pub(super) async fn create_user( .unwrap_or(UserInfo { username: body.username.clone(), user_ad_tag: None, - max_tcp_conns: None, + max_tcp_conns: cfg + .access + .user_max_tcp_conns + .get(&body.username) + .copied() + .filter(|limit| *limit > 0) + .or((cfg.access.user_max_tcp_conns_global_each > 0) + .then_some(cfg.access.user_max_tcp_conns_global_each)), expiration_rfc3339: None, data_quota_bytes: None, max_unique_ips: updated_limit, @@ -395,7 +402,14 @@ pub(super) async fn users_from_config( }); users.push(UserInfo { user_ad_tag: cfg.access.user_ad_tags.get(&username).cloned(), - max_tcp_conns: cfg.access.user_max_tcp_conns.get(&username).copied(), + max_tcp_conns: cfg + .access + .user_max_tcp_conns + .get(&username) + .copied() + .filter(|limit| *limit > 0) + .or((cfg.access.user_max_tcp_conns_global_each > 0) + .then_some(cfg.access.user_max_tcp_conns_global_each)), expiration_rfc3339: cfg .access .user_expirations @@ -572,3 +586,54 @@ fn resolve_tls_domains(cfg: &ProxyConfig) -> Vec<&str> { } domains } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ip_tracker::UserIpTracker; + use crate::stats::Stats; + + #[tokio::test] + async fn users_from_config_reports_effective_tcp_limit_with_global_fallback() { + let mut cfg = ProxyConfig::default(); + cfg.access.users.insert( + "alice".to_string(), + "0123456789abcdef0123456789abcdef".to_string(), + ); + cfg.access.user_max_tcp_conns_global_each = 7; + + let stats = Stats::new(); + let tracker = UserIpTracker::new(); + + let users = users_from_config(&cfg, &stats, &tracker, None, None).await; + let alice = users + .iter() + .find(|entry| entry.username == "alice") + .expect("alice must be present"); + assert_eq!(alice.max_tcp_conns, Some(7)); + + cfg.access.user_max_tcp_conns.insert("alice".to_string(), 5); + let users = users_from_config(&cfg, &stats, &tracker, None, None).await; + let alice = users + .iter() + .find(|entry| entry.username == "alice") + .expect("alice must be present"); + assert_eq!(alice.max_tcp_conns, Some(5)); + + cfg.access.user_max_tcp_conns.insert("alice".to_string(), 0); + let users = users_from_config(&cfg, &stats, &tracker, None, None).await; + let alice = users + .iter() + .find(|entry| entry.username == "alice") + .expect("alice must be present"); + assert_eq!(alice.max_tcp_conns, Some(7)); + + cfg.access.user_max_tcp_conns_global_each = 0; + let users = users_from_config(&cfg, &stats, &tracker, None, None).await; + let alice = users + .iter() + .find(|entry| entry.username == "alice") + .expect("alice must be present"); + assert_eq!(alice.max_tcp_conns, None); + } +} diff --git a/src/cli.rs b/src/cli.rs index fd12176..5a79bae 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -13,7 +13,7 @@ use std::path::{Path, PathBuf}; use std::process::Command; #[cfg(unix)] -use crate::daemon::{self, DaemonOptions, DEFAULT_PID_FILE}; +use crate::daemon::{self, DEFAULT_PID_FILE, DaemonOptions}; /// CLI subcommand to execute. #[derive(Debug, Clone, PartialEq, Eq)] @@ -437,13 +437,13 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box> { eprintln!("[+] Config written to {}", config_path.display()); // 5. Generate and write service file - let exe_path = std::env::current_exe() - .unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt")); + let exe_path = + std::env::current_exe().unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt")); let service_opts = ServiceOptions { exe_path: &exe_path, config_path: &config_path, - user: None, // Let systemd/init handle user + user: None, // Let systemd/init handle user group: None, pid_file: "/var/run/telemt.pid", working_dir: Some("/var/lib/telemt"), @@ -623,6 +623,7 @@ fake_cert_len = 2048 tls_full_cert_ttl_secs = 90 [access] +user_max_tcp_conns_global_each = 0 replay_check_len = 65536 replay_window_secs = 120 ignore_time_skew = false diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 6297a3e..89e72bb 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -811,6 +811,10 @@ pub(crate) fn default_user_max_unique_ips_window_secs() -> u64 { DEFAULT_USER_MAX_UNIQUE_IPS_WINDOW_SECS } +pub(crate) fn default_user_max_tcp_conns_global_each() -> usize { + 0 +} + pub(crate) fn default_user_max_unique_ips_global_each() -> usize { 0 } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 848a2dc..fa42c55 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -117,6 +117,7 @@ pub struct HotFields { pub users: std::collections::HashMap, pub user_ad_tags: std::collections::HashMap, pub user_max_tcp_conns: std::collections::HashMap, + pub user_max_tcp_conns_global_each: usize, pub user_expirations: std::collections::HashMap>, pub user_data_quota: std::collections::HashMap, pub user_max_unique_ips: std::collections::HashMap, @@ -240,6 +241,7 @@ impl HotFields { users: cfg.access.users.clone(), user_ad_tags: cfg.access.user_ad_tags.clone(), user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(), + user_max_tcp_conns_global_each: cfg.access.user_max_tcp_conns_global_each, user_expirations: cfg.access.user_expirations.clone(), user_data_quota: cfg.access.user_data_quota.clone(), user_max_unique_ips: cfg.access.user_max_unique_ips.clone(), @@ -530,6 +532,7 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { cfg.access.users = new.access.users.clone(); cfg.access.user_ad_tags = new.access.user_ad_tags.clone(); cfg.access.user_max_tcp_conns = new.access.user_max_tcp_conns.clone(); + cfg.access.user_max_tcp_conns_global_each = new.access.user_max_tcp_conns_global_each; cfg.access.user_expirations = new.access.user_expirations.clone(); cfg.access.user_data_quota = new.access.user_data_quota.clone(); cfg.access.user_max_unique_ips = new.access.user_max_unique_ips.clone(); @@ -1145,6 +1148,12 @@ fn log_changes( new_hot.user_max_tcp_conns.len() ); } + if old_hot.user_max_tcp_conns_global_each != new_hot.user_max_tcp_conns_global_each { + info!( + "config reload: user_max_tcp_conns policy global_each={}", + new_hot.user_max_tcp_conns_global_each + ); + } if old_hot.user_expirations != new_hot.user_expirations { info!( "config reload: user_expirations updated ({} entries)", diff --git a/src/config/load.rs b/src/config/load.rs index 75ae1e9..cc95f34 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1328,6 +1328,10 @@ mod tests { default_api_runtime_edge_events_capacity() ); assert_eq!(cfg.access.users, default_access_users()); + assert_eq!( + cfg.access.user_max_tcp_conns_global_each, + default_user_max_tcp_conns_global_each() + ); assert_eq!( cfg.access.user_max_unique_ips_mode, UserMaxUniqueIpsMode::default() @@ -1471,6 +1475,10 @@ mod tests { let access = AccessConfig::default(); assert_eq!(access.users, default_access_users()); + assert_eq!( + access.user_max_tcp_conns_global_each, + default_user_max_tcp_conns_global_each() + ); } #[test] diff --git a/src/config/types.rs b/src/config/types.rs index 5f3342f..41b0c2e 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1633,6 +1633,12 @@ pub struct AccessConfig { #[serde(default)] pub user_max_tcp_conns: HashMap, + /// Global per-user TCP connection limit applied when a user has no + /// positive individual override. + /// `0` disables the inherited limit. + #[serde(default = "default_user_max_tcp_conns_global_each")] + pub user_max_tcp_conns_global_each: usize, + #[serde(default)] pub user_expirations: HashMap>, @@ -1669,6 +1675,7 @@ impl Default for AccessConfig { users: default_access_users(), user_ad_tags: HashMap::new(), user_max_tcp_conns: HashMap::new(), + user_max_tcp_conns_global_each: default_user_max_tcp_conns_global_each(), user_expirations: HashMap::new(), user_data_quota: HashMap::new(), user_max_unique_ips: HashMap::new(), diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index f5fed72..8e2481e 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -206,7 +206,9 @@ impl PidFile { let mut contents = String::new(); File::open(&self.path) .and_then(|mut f| f.read_to_string(&mut contents)) - .map_err(|e| DaemonError::PidFile(format!("cannot read {}: {}", self.path.display(), e)))?; + .map_err(|e| { + DaemonError::PidFile(format!("cannot read {}: {}", self.path.display(), e)) + })?; let pid: i32 = contents .trim() @@ -269,12 +271,16 @@ impl PidFile { // Write our PID let pid = getpid(); - let mut file = flock.unlock().map_err(|(_, errno)| { - DaemonError::PidFile(format!("unlock failed: {}", errno)) - })?; + let mut file = flock + .unlock() + .map_err(|(_, errno)| DaemonError::PidFile(format!("unlock failed: {}", errno)))?; writeln!(file, "{}", pid).map_err(|e| { - DaemonError::PidFile(format!("cannot write PID to {}: {}", self.path.display(), e)) + DaemonError::PidFile(format!( + "cannot write PID to {}: {}", + self.path.display(), + e + )) })?; // Re-acquire lock and keep it @@ -373,7 +379,8 @@ pub fn drop_privileges(user: Option<&str>, group: Option<&str>) -> Result<(), Da /// Looks up a user by name and returns their UID. fn lookup_user(name: &str) -> Result { // Use libc getpwnam - let c_name = std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?; + let c_name = + std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?; unsafe { let pwd = libc::getpwnam(c_name.as_ptr()); @@ -387,7 +394,8 @@ fn lookup_user(name: &str) -> Result { /// Looks up a user's primary GID by username. fn lookup_user_primary_gid(name: &str) -> Result { - let c_name = std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?; + let c_name = + std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?; unsafe { let pwd = libc::getpwnam(c_name.as_ptr()); @@ -401,7 +409,8 @@ fn lookup_user_primary_gid(name: &str) -> Result { /// Looks up a group by name and returns its GID. fn lookup_group(name: &str) -> Result { - let c_name = std::ffi::CString::new(name).map_err(|_| DaemonError::GroupNotFound(name.to_string()))?; + let c_name = + std::ffi::CString::new(name).map_err(|_| DaemonError::GroupNotFound(name.to_string()))?; unsafe { let grp = libc::getgrnam(c_name.as_ptr()); @@ -444,9 +453,8 @@ pub fn signal_pid_file>( ))); } - nix::sys::signal::kill(Pid::from_raw(pid), signal).map_err(|e| { - DaemonError::PidFile(format!("cannot signal process {}: {}", pid, e)) - })?; + nix::sys::signal::kill(Pid::from_raw(pid), signal) + .map_err(|e| DaemonError::PidFile(format!("cannot signal process {}: {}", pid, e)))?; Ok(()) } diff --git a/src/logging.rs b/src/logging.rs index f372798..bb381ef 100644 --- a/src/logging.rs +++ b/src/logging.rs @@ -63,7 +63,10 @@ impl LoggingGuard { pub fn init_logging( opts: &LoggingOptions, initial_filter: &str, -) -> (reload::Handle, LoggingGuard) { +) -> ( + reload::Handle, + LoggingGuard, +) { let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new(initial_filter)); match &opts.destination { @@ -101,7 +104,8 @@ pub fn init_logging( // Extract directory and filename prefix let path = Path::new(path); let dir = path.parent().unwrap_or(Path::new("/var/log")); - let prefix = path.file_name() + let prefix = path + .file_name() .and_then(|s| s.to_str()) .unwrap_or("telemt"); @@ -182,7 +186,11 @@ impl std::io::Write for SyslogWriter { .unwrap_or_else(|_| std::ffi::CString::new("(invalid utf8)").unwrap()); unsafe { - libc::syslog(priority, b"%s\0".as_ptr() as *const libc::c_char, c_msg.as_ptr()); + libc::syslog( + priority, + b"%s\0".as_ptr() as *const libc::c_char, + c_msg.as_ptr(), + ); } Ok(buf.len()) @@ -255,7 +263,10 @@ mod tests { #[test] fn test_parse_log_destination_default() { let args: Vec = vec![]; - assert!(matches!(parse_log_destination(&args), LogDestination::Stderr)); + assert!(matches!( + parse_log_destination(&args), + LogDestination::Stderr + )); } #[test] @@ -286,6 +297,9 @@ mod tests { #[test] fn test_parse_log_destination_syslog() { let args = vec!["--syslog".to_string()]; - assert!(matches!(parse_log_destination(&args), LogDestination::Syslog)); + assert!(matches!( + parse_log_destination(&args), + LogDestination::Syslog + )); } } diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index e3c3feb..d9d8e8b 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -149,7 +149,9 @@ fn print_help() { } eprintln!(); eprintln!("Options:"); - eprintln!(" --data-path Set data directory (absolute path; overrides config value)"); + eprintln!( + " --data-path Set data directory (absolute path; overrides config value)" + ); eprintln!(" --silent, -s Suppress info logs"); eprintln!(" --log-level debug|verbose|normal|silent"); eprintln!(" --help, -h Show this help"); @@ -173,16 +175,10 @@ fn print_help() { eprintln!(); } eprintln!("Setup (fire-and-forget):"); - eprintln!( - " --init Generate config, install systemd service, start" - ); + eprintln!(" --init Generate config, install systemd service, start"); eprintln!(" --port Listen port (default: 443)"); - eprintln!( - " --domain TLS domain for masking (default: www.google.com)" - ); - eprintln!( - " --secret 32-char hex secret (auto-generated if omitted)" - ); + eprintln!(" --domain TLS domain for masking (default: www.google.com)"); + eprintln!(" --secret 32-char hex secret (auto-generated if omitted)"); eprintln!(" --user Username (default: user)"); eprintln!(" --config-dir Config directory (default: /etc/telemt)"); eprintln!(" --no-start Don't start the service after install"); diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index baa0f07..aa95cb6 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -83,7 +83,6 @@ pub async fn run() -> std::result::Result<(), Box> { async fn run_inner( daemon_opts: DaemonOptions, ) -> std::result::Result<(), Box> { - // Acquire PID file if daemonizing or if explicitly requested // Keep it alive until shutdown (underscore prefix = intentionally kept for RAII cleanup) let _pid_file = if daemon_opts.daemonize || daemon_opts.pid_file.is_some() { @@ -665,10 +664,7 @@ async fn run_inner( // Drop privileges after binding sockets (which may require root for port < 1024) if daemon_opts.user.is_some() || daemon_opts.group.is_some() { - if let Err(e) = drop_privileges( - daemon_opts.user.as_deref(), - daemon_opts.group.as_deref(), - ) { + if let Err(e) = drop_privileges(daemon_opts.user.as_deref(), daemon_opts.group.as_deref()) { error!(error = %e, "Failed to drop privileges"); std::process::exit(1); } diff --git a/src/maestro/shutdown.rs b/src/maestro/shutdown.rs index cfdee24..f6e50ca 100644 --- a/src/maestro/shutdown.rs +++ b/src/maestro/shutdown.rs @@ -11,10 +11,10 @@ use std::sync::Arc; use std::time::{Duration, Instant}; -#[cfg(unix)] -use tokio::signal::unix::{SignalKind, signal}; #[cfg(not(unix))] use tokio::signal; +#[cfg(unix)] +use tokio::signal::unix::{SignalKind, signal}; use tracing::{info, warn}; use crate::stats::Stats; @@ -94,7 +94,8 @@ async fn perform_shutdown( // Graceful ME pool shutdown if let Some(pool) = &me_pool { - match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all()).await + match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all()) + .await { Ok(total) => { info!( @@ -159,15 +160,12 @@ fn dump_stats(stats: &Stats, process_started_at: Instant) { /// - SIGUSR1: Log rotation acknowledgment (for external log rotation tools) /// - SIGUSR2: Dump runtime status to log #[cfg(unix)] -pub(crate) fn spawn_signal_handlers( - stats: Arc, - process_started_at: Instant, -) { +pub(crate) fn spawn_signal_handlers(stats: Arc, process_started_at: Instant) { tokio::spawn(async move { - let mut sigusr1 = signal(SignalKind::user_defined1()) - .expect("Failed to register SIGUSR1 handler"); - let mut sigusr2 = signal(SignalKind::user_defined2()) - .expect("Failed to register SIGUSR2 handler"); + let mut sigusr1 = + signal(SignalKind::user_defined1()).expect("Failed to register SIGUSR1 handler"); + let mut sigusr2 = + signal(SignalKind::user_defined2()).expect("Failed to register SIGUSR2 handler"); loop { tokio::select! { @@ -184,10 +182,7 @@ pub(crate) fn spawn_signal_handlers( /// No-op on non-Unix platforms. #[cfg(not(unix))] -pub(crate) fn spawn_signal_handlers( - _stats: Arc, - _process_started_at: Instant, -) { +pub(crate) fn spawn_signal_handlers(_stats: Arc, _process_started_at: Instant) { // No SIGUSR1/SIGUSR2 on non-Unix } diff --git a/src/main.rs b/src/main.rs index 0d29981..68c89fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,8 +8,6 @@ mod crypto; mod daemon; mod error; mod ip_tracker; -mod logging; -mod service; #[cfg(test)] #[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"] mod ip_tracker_encapsulation_adversarial_tests; @@ -19,11 +17,13 @@ mod ip_tracker_hotpath_adversarial_tests; #[cfg(test)] #[path = "tests/ip_tracker_regression_tests.rs"] mod ip_tracker_regression_tests; +mod logging; mod maestro; mod metrics; mod network; mod protocol; mod proxy; +mod service; mod startup; mod stats; mod stream; diff --git a/src/proxy/client.rs b/src/proxy/client.rs index d71411a..7472459 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -877,7 +877,8 @@ impl RunningClientHandler { let first_byte = if self.config.timeouts.client_first_byte_idle_secs == 0 { None } else { - let idle_timeout = Duration::from_secs(self.config.timeouts.client_first_byte_idle_secs); + let idle_timeout = + Duration::from_secs(self.config.timeouts.client_first_byte_idle_secs); let mut first_byte = [0u8; 1]; match timeout(idle_timeout, self.stream.read(&mut first_byte)).await { Ok(Ok(0)) => { @@ -1365,7 +1366,11 @@ impl RunningClientHandler { .access .user_max_tcp_conns .get(user) - .map(|v| *v as u64); + .copied() + .filter(|limit| *limit > 0) + .or((config.access.user_max_tcp_conns_global_each > 0) + .then_some(config.access.user_max_tcp_conns_global_each)) + .map(|v| v as u64); if !stats.try_acquire_user_curr_connects(user, limit) { return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string(), @@ -1424,7 +1429,11 @@ impl RunningClientHandler { .access .user_max_tcp_conns .get(user) - .map(|v| *v as u64); + .copied() + .filter(|limit| *limit > 0) + .or((config.access.user_max_tcp_conns_global_each > 0) + .then_some(config.access.user_max_tcp_conns_global_each)) + .map(|v| v as u64); if !stats.try_acquire_user_curr_connects(user, limit) { return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string(), diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 3a66a09..d585326 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -1740,7 +1740,8 @@ async fn fragmented_tls_mtproto_with_interleaved_ccs_is_accepted() { .await .unwrap(); assert_eq!(tls_response_head[0], 0x16); - let tls_response_len = u16::from_be_bytes([tls_response_head[3], tls_response_head[4]]) as usize; + let tls_response_len = + u16::from_be_bytes([tls_response_head[3], tls_response_head[4]]) as usize; let mut tls_response_body = vec![0u8; tls_response_len]; client_side .read_exact(&mut tls_response_body) @@ -2533,14 +2534,16 @@ async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { } #[tokio::test] -async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { +async fn zero_tcp_limit_uses_global_fallback_and_rejects_without_side_effects() { let mut config = ProxyConfig::default(); config .access .user_max_tcp_conns .insert("user".to_string(), 0); + config.access.user_max_tcp_conns_global_each = 1; let stats = Stats::new(); + stats.increment_user_curr_connects("user"); let ip_tracker = UserIpTracker::new(); let peer_addr: SocketAddr = "198.51.100.211:50001".parse().unwrap(); @@ -2557,10 +2560,75 @@ async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { result, Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" )); + assert_eq!( + stats.get_user_curr_connects("user"), + 1, + "TCP-limit rejection must keep pre-existing in-flight connection count unchanged" + ); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + +#[tokio::test] +async fn zero_tcp_limit_with_disabled_global_fallback_is_unlimited() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 0); + config.access.user_max_tcp_conns_global_each = 0; + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.212:50002".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!( + result.is_ok(), + "per-user zero with global fallback disabled must not enforce a TCP limit" + ); assert_eq!(stats.get_user_curr_connects("user"), 0); assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); } +#[tokio::test] +async fn global_tcp_fallback_applies_when_per_user_limit_is_missing() { + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns_global_each = 1; + + let stats = Stats::new(); + stats.increment_user_curr_connects("user"); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.213:50003".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + assert_eq!( + stats.get_user_curr_connects("user"), + 1, + "Global fallback TCP-limit rejection must keep pre-existing counter unchanged" + ); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + #[tokio::test] async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() { let user = "check-helper-user"; diff --git a/src/service/mod.rs b/src/service/mod.rs index 160c36c..7a6e4f6 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -111,11 +111,17 @@ pub fn generate_service_file(init_system: InitSystem, opts: &ServiceOptions) -> /// Generates an enhanced systemd unit file. fn generate_systemd_unit(opts: &ServiceOptions) -> String { let user_line = opts.user.map(|u| format!("User={}", u)).unwrap_or_default(); - let group_line = opts.group.map(|g| format!("Group={}", g)).unwrap_or_default(); - let working_dir = opts.working_dir.map(|d| format!("WorkingDirectory={}", d)).unwrap_or_default(); + let group_line = opts + .group + .map(|g| format!("Group={}", g)) + .unwrap_or_default(); + let working_dir = opts + .working_dir + .map(|d| format!("WorkingDirectory={}", d)) + .unwrap_or_default(); format!( -r#"[Unit] + r#"[Unit] Description={description} Documentation=https://github.com/telemt/telemt After=network-online.target @@ -176,7 +182,7 @@ fn generate_openrc_script(opts: &ServiceOptions) -> String { let group = opts.group.unwrap_or("root"); format!( -r#"#!/sbin/openrc-run + r#"#!/sbin/openrc-run # OpenRC init script for telemt description="{description}" @@ -218,7 +224,7 @@ fn generate_freebsd_rc_script(opts: &ServiceOptions) -> String { let group = opts.group.unwrap_or("wheel"); format!( -r#"#!/bin/sh + r#"#!/bin/sh # # PROVIDE: telemt # REQUIRE: LOGIN NETWORKING @@ -284,7 +290,7 @@ run_rc_command "$1" pub fn installation_instructions(init_system: InitSystem) -> &'static str { match init_system { InitSystem::Systemd => { -r#"To install and enable the service: + r#"To install and enable the service: sudo systemctl daemon-reload sudo systemctl enable telemt sudo systemctl start telemt @@ -300,7 +306,7 @@ To reload configuration: "# } InitSystem::OpenRC => { -r#"To install and enable the service: + r#"To install and enable the service: sudo chmod +x /etc/init.d/telemt sudo rc-update add telemt default sudo rc-service telemt start @@ -313,7 +319,7 @@ To reload configuration: "# } InitSystem::FreeBSDRc => { -r#"To install and enable the service: + r#"To install and enable the service: sudo chmod +x /usr/local/etc/rc.d/telemt sudo sysrc telemt_enable="YES" sudo service telemt start @@ -326,7 +332,7 @@ To reload configuration: "# } InitSystem::Unknown => { -r#"No supported init system detected. + r#"No supported init system detected. You may need to create a service file manually or run telemt directly: telemt start /etc/telemt/config.toml "# @@ -369,8 +375,14 @@ mod tests { #[test] fn test_service_file_paths() { - assert_eq!(service_file_path(InitSystem::Systemd), "/etc/systemd/system/telemt.service"); + assert_eq!( + service_file_path(InitSystem::Systemd), + "/etc/systemd/system/telemt.service" + ); assert_eq!(service_file_path(InitSystem::OpenRC), "/etc/init.d/telemt"); - assert_eq!(service_file_path(InitSystem::FreeBSDRc), "/usr/local/etc/rc.d/telemt"); + assert_eq!( + service_file_path(InitSystem::FreeBSDRc), + "/usr/local/etc/rc.d/telemt" + ); } }