Admission-timeouts + Global Each TCP Connections

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey 2026-03-31 11:14:55 +03:00
parent 5bf56b6dd8
commit b8cf596e7d
No known key found for this signature in database
17 changed files with 275 additions and 71 deletions

View File

@ -370,7 +370,10 @@ async fn handle(
let mut data: Vec<UserActiveIps> = active_ips_map let mut data: Vec<UserActiveIps> = active_ips_map
.into_iter() .into_iter()
.filter(|(_, ips)| !ips.is_empty()) .filter(|(_, ips)| !ips.is_empty())
.map(|(username, active_ips)| UserActiveIps { username, active_ips }) .map(|(username, active_ips)| UserActiveIps {
username,
active_ips,
})
.collect(); .collect();
data.sort_by(|a, b| a.username.cmp(&b.username)); data.sort_by(|a, b| a.username.cmp(&b.username));
Ok(success_response(StatusCode::OK, data, revision)) Ok(success_response(StatusCode::OK, data, revision))

View File

@ -100,6 +100,11 @@ pub(super) struct EffectiveUserIpPolicyLimits {
pub(super) window_secs: u64, pub(super) window_secs: u64,
} }
#[derive(Serialize)]
pub(super) struct EffectiveUserTcpPolicyLimits {
pub(super) global_each: usize,
}
#[derive(Serialize)] #[derive(Serialize)]
pub(super) struct EffectiveLimitsData { pub(super) struct EffectiveLimitsData {
pub(super) update_every_secs: u64, pub(super) update_every_secs: u64,
@ -109,6 +114,7 @@ pub(super) struct EffectiveLimitsData {
pub(super) upstream: EffectiveUpstreamLimits, pub(super) upstream: EffectiveUpstreamLimits,
pub(super) middle_proxy: EffectiveMiddleProxyLimits, pub(super) middle_proxy: EffectiveMiddleProxyLimits,
pub(super) user_ip_policy: EffectiveUserIpPolicyLimits, pub(super) user_ip_policy: EffectiveUserIpPolicyLimits,
pub(super) user_tcp_policy: EffectiveUserTcpPolicyLimits,
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -289,6 +295,9 @@ pub(super) fn build_limits_effective_data(cfg: &ProxyConfig) -> EffectiveLimitsD
mode: user_max_unique_ips_mode_label(cfg.access.user_max_unique_ips_mode), mode: user_max_unique_ips_mode_label(cfg.access.user_max_unique_ips_mode),
window_secs: cfg.access.user_max_unique_ips_window_secs, window_secs: cfg.access.user_max_unique_ips_window_secs,
}, },
user_tcp_policy: EffectiveUserTcpPolicyLimits {
global_each: cfg.access.user_max_tcp_conns_global_each,
},
} }
} }

View File

@ -144,7 +144,14 @@ pub(super) async fn create_user(
.unwrap_or(UserInfo { .unwrap_or(UserInfo {
username: body.username.clone(), username: body.username.clone(),
user_ad_tag: None, user_ad_tag: None,
max_tcp_conns: None, max_tcp_conns: cfg
.access
.user_max_tcp_conns
.get(&body.username)
.copied()
.filter(|limit| *limit > 0)
.or((cfg.access.user_max_tcp_conns_global_each > 0)
.then_some(cfg.access.user_max_tcp_conns_global_each)),
expiration_rfc3339: None, expiration_rfc3339: None,
data_quota_bytes: None, data_quota_bytes: None,
max_unique_ips: updated_limit, max_unique_ips: updated_limit,
@ -395,7 +402,14 @@ pub(super) async fn users_from_config(
}); });
users.push(UserInfo { users.push(UserInfo {
user_ad_tag: cfg.access.user_ad_tags.get(&username).cloned(), user_ad_tag: cfg.access.user_ad_tags.get(&username).cloned(),
max_tcp_conns: cfg.access.user_max_tcp_conns.get(&username).copied(), max_tcp_conns: cfg
.access
.user_max_tcp_conns
.get(&username)
.copied()
.filter(|limit| *limit > 0)
.or((cfg.access.user_max_tcp_conns_global_each > 0)
.then_some(cfg.access.user_max_tcp_conns_global_each)),
expiration_rfc3339: cfg expiration_rfc3339: cfg
.access .access
.user_expirations .user_expirations
@ -572,3 +586,54 @@ fn resolve_tls_domains(cfg: &ProxyConfig) -> Vec<&str> {
} }
domains domains
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::ip_tracker::UserIpTracker;
use crate::stats::Stats;
#[tokio::test]
async fn users_from_config_reports_effective_tcp_limit_with_global_fallback() {
let mut cfg = ProxyConfig::default();
cfg.access.users.insert(
"alice".to_string(),
"0123456789abcdef0123456789abcdef".to_string(),
);
cfg.access.user_max_tcp_conns_global_each = 7;
let stats = Stats::new();
let tracker = UserIpTracker::new();
let users = users_from_config(&cfg, &stats, &tracker, None, None).await;
let alice = users
.iter()
.find(|entry| entry.username == "alice")
.expect("alice must be present");
assert_eq!(alice.max_tcp_conns, Some(7));
cfg.access.user_max_tcp_conns.insert("alice".to_string(), 5);
let users = users_from_config(&cfg, &stats, &tracker, None, None).await;
let alice = users
.iter()
.find(|entry| entry.username == "alice")
.expect("alice must be present");
assert_eq!(alice.max_tcp_conns, Some(5));
cfg.access.user_max_tcp_conns.insert("alice".to_string(), 0);
let users = users_from_config(&cfg, &stats, &tracker, None, None).await;
let alice = users
.iter()
.find(|entry| entry.username == "alice")
.expect("alice must be present");
assert_eq!(alice.max_tcp_conns, Some(7));
cfg.access.user_max_tcp_conns_global_each = 0;
let users = users_from_config(&cfg, &stats, &tracker, None, None).await;
let alice = users
.iter()
.find(|entry| entry.username == "alice")
.expect("alice must be present");
assert_eq!(alice.max_tcp_conns, None);
}
}

View File

@ -13,7 +13,7 @@ use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;
#[cfg(unix)] #[cfg(unix)]
use crate::daemon::{self, DaemonOptions, DEFAULT_PID_FILE}; use crate::daemon::{self, DEFAULT_PID_FILE, DaemonOptions};
/// CLI subcommand to execute. /// CLI subcommand to execute.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@ -437,13 +437,13 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[+] Config written to {}", config_path.display()); eprintln!("[+] Config written to {}", config_path.display());
// 5. Generate and write service file // 5. Generate and write service file
let exe_path = std::env::current_exe() let exe_path =
.unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt")); std::env::current_exe().unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
let service_opts = ServiceOptions { let service_opts = ServiceOptions {
exe_path: &exe_path, exe_path: &exe_path,
config_path: &config_path, config_path: &config_path,
user: None, // Let systemd/init handle user user: None, // Let systemd/init handle user
group: None, group: None,
pid_file: "/var/run/telemt.pid", pid_file: "/var/run/telemt.pid",
working_dir: Some("/var/lib/telemt"), working_dir: Some("/var/lib/telemt"),
@ -623,6 +623,7 @@ fake_cert_len = 2048
tls_full_cert_ttl_secs = 90 tls_full_cert_ttl_secs = 90
[access] [access]
user_max_tcp_conns_global_each = 0
replay_check_len = 65536 replay_check_len = 65536
replay_window_secs = 120 replay_window_secs = 120
ignore_time_skew = false ignore_time_skew = false

View File

@ -811,6 +811,10 @@ pub(crate) fn default_user_max_unique_ips_window_secs() -> u64 {
DEFAULT_USER_MAX_UNIQUE_IPS_WINDOW_SECS DEFAULT_USER_MAX_UNIQUE_IPS_WINDOW_SECS
} }
pub(crate) fn default_user_max_tcp_conns_global_each() -> usize {
0
}
pub(crate) fn default_user_max_unique_ips_global_each() -> usize { pub(crate) fn default_user_max_unique_ips_global_each() -> usize {
0 0
} }

View File

@ -117,6 +117,7 @@ pub struct HotFields {
pub users: std::collections::HashMap<String, String>, pub users: std::collections::HashMap<String, String>,
pub user_ad_tags: std::collections::HashMap<String, String>, pub user_ad_tags: std::collections::HashMap<String, String>,
pub user_max_tcp_conns: std::collections::HashMap<String, usize>, pub user_max_tcp_conns: std::collections::HashMap<String, usize>,
pub user_max_tcp_conns_global_each: usize,
pub user_expirations: std::collections::HashMap<String, chrono::DateTime<chrono::Utc>>, pub user_expirations: std::collections::HashMap<String, chrono::DateTime<chrono::Utc>>,
pub user_data_quota: std::collections::HashMap<String, u64>, pub user_data_quota: std::collections::HashMap<String, u64>,
pub user_max_unique_ips: std::collections::HashMap<String, usize>, pub user_max_unique_ips: std::collections::HashMap<String, usize>,
@ -240,6 +241,7 @@ impl HotFields {
users: cfg.access.users.clone(), users: cfg.access.users.clone(),
user_ad_tags: cfg.access.user_ad_tags.clone(), user_ad_tags: cfg.access.user_ad_tags.clone(),
user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(), user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(),
user_max_tcp_conns_global_each: cfg.access.user_max_tcp_conns_global_each,
user_expirations: cfg.access.user_expirations.clone(), user_expirations: cfg.access.user_expirations.clone(),
user_data_quota: cfg.access.user_data_quota.clone(), user_data_quota: cfg.access.user_data_quota.clone(),
user_max_unique_ips: cfg.access.user_max_unique_ips.clone(), user_max_unique_ips: cfg.access.user_max_unique_ips.clone(),
@ -530,6 +532,7 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
cfg.access.users = new.access.users.clone(); cfg.access.users = new.access.users.clone();
cfg.access.user_ad_tags = new.access.user_ad_tags.clone(); cfg.access.user_ad_tags = new.access.user_ad_tags.clone();
cfg.access.user_max_tcp_conns = new.access.user_max_tcp_conns.clone(); cfg.access.user_max_tcp_conns = new.access.user_max_tcp_conns.clone();
cfg.access.user_max_tcp_conns_global_each = new.access.user_max_tcp_conns_global_each;
cfg.access.user_expirations = new.access.user_expirations.clone(); cfg.access.user_expirations = new.access.user_expirations.clone();
cfg.access.user_data_quota = new.access.user_data_quota.clone(); cfg.access.user_data_quota = new.access.user_data_quota.clone();
cfg.access.user_max_unique_ips = new.access.user_max_unique_ips.clone(); cfg.access.user_max_unique_ips = new.access.user_max_unique_ips.clone();
@ -1145,6 +1148,12 @@ fn log_changes(
new_hot.user_max_tcp_conns.len() new_hot.user_max_tcp_conns.len()
); );
} }
if old_hot.user_max_tcp_conns_global_each != new_hot.user_max_tcp_conns_global_each {
info!(
"config reload: user_max_tcp_conns policy global_each={}",
new_hot.user_max_tcp_conns_global_each
);
}
if old_hot.user_expirations != new_hot.user_expirations { if old_hot.user_expirations != new_hot.user_expirations {
info!( info!(
"config reload: user_expirations updated ({} entries)", "config reload: user_expirations updated ({} entries)",

View File

@ -1328,6 +1328,10 @@ mod tests {
default_api_runtime_edge_events_capacity() default_api_runtime_edge_events_capacity()
); );
assert_eq!(cfg.access.users, default_access_users()); assert_eq!(cfg.access.users, default_access_users());
assert_eq!(
cfg.access.user_max_tcp_conns_global_each,
default_user_max_tcp_conns_global_each()
);
assert_eq!( assert_eq!(
cfg.access.user_max_unique_ips_mode, cfg.access.user_max_unique_ips_mode,
UserMaxUniqueIpsMode::default() UserMaxUniqueIpsMode::default()
@ -1471,6 +1475,10 @@ mod tests {
let access = AccessConfig::default(); let access = AccessConfig::default();
assert_eq!(access.users, default_access_users()); assert_eq!(access.users, default_access_users());
assert_eq!(
access.user_max_tcp_conns_global_each,
default_user_max_tcp_conns_global_each()
);
} }
#[test] #[test]

View File

@ -1633,6 +1633,12 @@ pub struct AccessConfig {
#[serde(default)] #[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>, pub user_max_tcp_conns: HashMap<String, usize>,
/// Global per-user TCP connection limit applied when a user has no
/// positive individual override.
/// `0` disables the inherited limit.
#[serde(default = "default_user_max_tcp_conns_global_each")]
pub user_max_tcp_conns_global_each: usize,
#[serde(default)] #[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>, pub user_expirations: HashMap<String, DateTime<Utc>>,
@ -1669,6 +1675,7 @@ impl Default for AccessConfig {
users: default_access_users(), users: default_access_users(),
user_ad_tags: HashMap::new(), user_ad_tags: HashMap::new(),
user_max_tcp_conns: HashMap::new(), user_max_tcp_conns: HashMap::new(),
user_max_tcp_conns_global_each: default_user_max_tcp_conns_global_each(),
user_expirations: HashMap::new(), user_expirations: HashMap::new(),
user_data_quota: HashMap::new(), user_data_quota: HashMap::new(),
user_max_unique_ips: HashMap::new(), user_max_unique_ips: HashMap::new(),

View File

@ -206,7 +206,9 @@ impl PidFile {
let mut contents = String::new(); let mut contents = String::new();
File::open(&self.path) File::open(&self.path)
.and_then(|mut f| f.read_to_string(&mut contents)) .and_then(|mut f| f.read_to_string(&mut contents))
.map_err(|e| DaemonError::PidFile(format!("cannot read {}: {}", self.path.display(), e)))?; .map_err(|e| {
DaemonError::PidFile(format!("cannot read {}: {}", self.path.display(), e))
})?;
let pid: i32 = contents let pid: i32 = contents
.trim() .trim()
@ -269,12 +271,16 @@ impl PidFile {
// Write our PID // Write our PID
let pid = getpid(); let pid = getpid();
let mut file = flock.unlock().map_err(|(_, errno)| { let mut file = flock
DaemonError::PidFile(format!("unlock failed: {}", errno)) .unlock()
})?; .map_err(|(_, errno)| DaemonError::PidFile(format!("unlock failed: {}", errno)))?;
writeln!(file, "{}", pid).map_err(|e| { writeln!(file, "{}", pid).map_err(|e| {
DaemonError::PidFile(format!("cannot write PID to {}: {}", self.path.display(), e)) DaemonError::PidFile(format!(
"cannot write PID to {}: {}",
self.path.display(),
e
))
})?; })?;
// Re-acquire lock and keep it // Re-acquire lock and keep it
@ -373,7 +379,8 @@ pub fn drop_privileges(user: Option<&str>, group: Option<&str>) -> Result<(), Da
/// Looks up a user by name and returns their UID. /// Looks up a user by name and returns their UID.
fn lookup_user(name: &str) -> Result<Uid, DaemonError> { fn lookup_user(name: &str) -> Result<Uid, DaemonError> {
// Use libc getpwnam // Use libc getpwnam
let c_name = std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?; let c_name =
std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?;
unsafe { unsafe {
let pwd = libc::getpwnam(c_name.as_ptr()); let pwd = libc::getpwnam(c_name.as_ptr());
@ -387,7 +394,8 @@ fn lookup_user(name: &str) -> Result<Uid, DaemonError> {
/// Looks up a user's primary GID by username. /// Looks up a user's primary GID by username.
fn lookup_user_primary_gid(name: &str) -> Result<Gid, DaemonError> { fn lookup_user_primary_gid(name: &str) -> Result<Gid, DaemonError> {
let c_name = std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?; let c_name =
std::ffi::CString::new(name).map_err(|_| DaemonError::UserNotFound(name.to_string()))?;
unsafe { unsafe {
let pwd = libc::getpwnam(c_name.as_ptr()); let pwd = libc::getpwnam(c_name.as_ptr());
@ -401,7 +409,8 @@ fn lookup_user_primary_gid(name: &str) -> Result<Gid, DaemonError> {
/// Looks up a group by name and returns its GID. /// Looks up a group by name and returns its GID.
fn lookup_group(name: &str) -> Result<Gid, DaemonError> { fn lookup_group(name: &str) -> Result<Gid, DaemonError> {
let c_name = std::ffi::CString::new(name).map_err(|_| DaemonError::GroupNotFound(name.to_string()))?; let c_name =
std::ffi::CString::new(name).map_err(|_| DaemonError::GroupNotFound(name.to_string()))?;
unsafe { unsafe {
let grp = libc::getgrnam(c_name.as_ptr()); let grp = libc::getgrnam(c_name.as_ptr());
@ -444,9 +453,8 @@ pub fn signal_pid_file<P: AsRef<Path>>(
))); )));
} }
nix::sys::signal::kill(Pid::from_raw(pid), signal).map_err(|e| { nix::sys::signal::kill(Pid::from_raw(pid), signal)
DaemonError::PidFile(format!("cannot signal process {}: {}", pid, e)) .map_err(|e| DaemonError::PidFile(format!("cannot signal process {}: {}", pid, e)))?;
})?;
Ok(()) Ok(())
} }

View File

@ -63,7 +63,10 @@ impl LoggingGuard {
pub fn init_logging( pub fn init_logging(
opts: &LoggingOptions, opts: &LoggingOptions,
initial_filter: &str, initial_filter: &str,
) -> (reload::Handle<EnvFilter, impl tracing::Subscriber + Send + Sync>, LoggingGuard) { ) -> (
reload::Handle<EnvFilter, impl tracing::Subscriber + Send + Sync>,
LoggingGuard,
) {
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new(initial_filter)); let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new(initial_filter));
match &opts.destination { match &opts.destination {
@ -101,7 +104,8 @@ pub fn init_logging(
// Extract directory and filename prefix // Extract directory and filename prefix
let path = Path::new(path); let path = Path::new(path);
let dir = path.parent().unwrap_or(Path::new("/var/log")); let dir = path.parent().unwrap_or(Path::new("/var/log"));
let prefix = path.file_name() let prefix = path
.file_name()
.and_then(|s| s.to_str()) .and_then(|s| s.to_str())
.unwrap_or("telemt"); .unwrap_or("telemt");
@ -182,7 +186,11 @@ impl std::io::Write for SyslogWriter {
.unwrap_or_else(|_| std::ffi::CString::new("(invalid utf8)").unwrap()); .unwrap_or_else(|_| std::ffi::CString::new("(invalid utf8)").unwrap());
unsafe { unsafe {
libc::syslog(priority, b"%s\0".as_ptr() as *const libc::c_char, c_msg.as_ptr()); libc::syslog(
priority,
b"%s\0".as_ptr() as *const libc::c_char,
c_msg.as_ptr(),
);
} }
Ok(buf.len()) Ok(buf.len())
@ -255,7 +263,10 @@ mod tests {
#[test] #[test]
fn test_parse_log_destination_default() { fn test_parse_log_destination_default() {
let args: Vec<String> = vec![]; let args: Vec<String> = vec![];
assert!(matches!(parse_log_destination(&args), LogDestination::Stderr)); assert!(matches!(
parse_log_destination(&args),
LogDestination::Stderr
));
} }
#[test] #[test]
@ -286,6 +297,9 @@ mod tests {
#[test] #[test]
fn test_parse_log_destination_syslog() { fn test_parse_log_destination_syslog() {
let args = vec!["--syslog".to_string()]; let args = vec!["--syslog".to_string()];
assert!(matches!(parse_log_destination(&args), LogDestination::Syslog)); assert!(matches!(
parse_log_destination(&args),
LogDestination::Syslog
));
} }
} }

View File

@ -149,7 +149,9 @@ fn print_help() {
} }
eprintln!(); eprintln!();
eprintln!("Options:"); eprintln!("Options:");
eprintln!(" --data-path <DIR> Set data directory (absolute path; overrides config value)"); eprintln!(
" --data-path <DIR> Set data directory (absolute path; overrides config value)"
);
eprintln!(" --silent, -s Suppress info logs"); eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent"); eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help"); eprintln!(" --help, -h Show this help");
@ -173,16 +175,10 @@ fn print_help() {
eprintln!(); eprintln!();
} }
eprintln!("Setup (fire-and-forget):"); eprintln!("Setup (fire-and-forget):");
eprintln!( eprintln!(" --init Generate config, install systemd service, start");
" --init Generate config, install systemd service, start"
);
eprintln!(" --port <PORT> Listen port (default: 443)"); eprintln!(" --port <PORT> Listen port (default: 443)");
eprintln!( eprintln!(" --domain <DOMAIN> TLS domain for masking (default: www.google.com)");
" --domain <DOMAIN> TLS domain for masking (default: www.google.com)" eprintln!(" --secret <HEX> 32-char hex secret (auto-generated if omitted)");
);
eprintln!(
" --secret <HEX> 32-char hex secret (auto-generated if omitted)"
);
eprintln!(" --user <NAME> Username (default: user)"); eprintln!(" --user <NAME> Username (default: user)");
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)"); eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
eprintln!(" --no-start Don't start the service after install"); eprintln!(" --no-start Don't start the service after install");

View File

@ -83,7 +83,6 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
async fn run_inner( async fn run_inner(
daemon_opts: DaemonOptions, daemon_opts: DaemonOptions,
) -> std::result::Result<(), Box<dyn std::error::Error>> { ) -> std::result::Result<(), Box<dyn std::error::Error>> {
// Acquire PID file if daemonizing or if explicitly requested // Acquire PID file if daemonizing or if explicitly requested
// Keep it alive until shutdown (underscore prefix = intentionally kept for RAII cleanup) // Keep it alive until shutdown (underscore prefix = intentionally kept for RAII cleanup)
let _pid_file = if daemon_opts.daemonize || daemon_opts.pid_file.is_some() { let _pid_file = if daemon_opts.daemonize || daemon_opts.pid_file.is_some() {
@ -665,10 +664,7 @@ async fn run_inner(
// Drop privileges after binding sockets (which may require root for port < 1024) // Drop privileges after binding sockets (which may require root for port < 1024)
if daemon_opts.user.is_some() || daemon_opts.group.is_some() { if daemon_opts.user.is_some() || daemon_opts.group.is_some() {
if let Err(e) = drop_privileges( if let Err(e) = drop_privileges(daemon_opts.user.as_deref(), daemon_opts.group.as_deref()) {
daemon_opts.user.as_deref(),
daemon_opts.group.as_deref(),
) {
error!(error = %e, "Failed to drop privileges"); error!(error = %e, "Failed to drop privileges");
std::process::exit(1); std::process::exit(1);
} }

View File

@ -11,10 +11,10 @@
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
#[cfg(unix)]
use tokio::signal::unix::{SignalKind, signal};
#[cfg(not(unix))] #[cfg(not(unix))]
use tokio::signal; use tokio::signal;
#[cfg(unix)]
use tokio::signal::unix::{SignalKind, signal};
use tracing::{info, warn}; use tracing::{info, warn};
use crate::stats::Stats; use crate::stats::Stats;
@ -94,7 +94,8 @@ async fn perform_shutdown(
// Graceful ME pool shutdown // Graceful ME pool shutdown
if let Some(pool) = &me_pool { if let Some(pool) = &me_pool {
match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all()).await match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all())
.await
{ {
Ok(total) => { Ok(total) => {
info!( info!(
@ -159,15 +160,12 @@ fn dump_stats(stats: &Stats, process_started_at: Instant) {
/// - SIGUSR1: Log rotation acknowledgment (for external log rotation tools) /// - SIGUSR1: Log rotation acknowledgment (for external log rotation tools)
/// - SIGUSR2: Dump runtime status to log /// - SIGUSR2: Dump runtime status to log
#[cfg(unix)] #[cfg(unix)]
pub(crate) fn spawn_signal_handlers( pub(crate) fn spawn_signal_handlers(stats: Arc<Stats>, process_started_at: Instant) {
stats: Arc<Stats>,
process_started_at: Instant,
) {
tokio::spawn(async move { tokio::spawn(async move {
let mut sigusr1 = signal(SignalKind::user_defined1()) let mut sigusr1 =
.expect("Failed to register SIGUSR1 handler"); signal(SignalKind::user_defined1()).expect("Failed to register SIGUSR1 handler");
let mut sigusr2 = signal(SignalKind::user_defined2()) let mut sigusr2 =
.expect("Failed to register SIGUSR2 handler"); signal(SignalKind::user_defined2()).expect("Failed to register SIGUSR2 handler");
loop { loop {
tokio::select! { tokio::select! {
@ -184,10 +182,7 @@ pub(crate) fn spawn_signal_handlers(
/// No-op on non-Unix platforms. /// No-op on non-Unix platforms.
#[cfg(not(unix))] #[cfg(not(unix))]
pub(crate) fn spawn_signal_handlers( pub(crate) fn spawn_signal_handlers(_stats: Arc<Stats>, _process_started_at: Instant) {
_stats: Arc<Stats>,
_process_started_at: Instant,
) {
// No SIGUSR1/SIGUSR2 on non-Unix // No SIGUSR1/SIGUSR2 on non-Unix
} }

View File

@ -8,8 +8,6 @@ mod crypto;
mod daemon; mod daemon;
mod error; mod error;
mod ip_tracker; mod ip_tracker;
mod logging;
mod service;
#[cfg(test)] #[cfg(test)]
#[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"] #[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"]
mod ip_tracker_encapsulation_adversarial_tests; mod ip_tracker_encapsulation_adversarial_tests;
@ -19,11 +17,13 @@ mod ip_tracker_hotpath_adversarial_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tests/ip_tracker_regression_tests.rs"] #[path = "tests/ip_tracker_regression_tests.rs"]
mod ip_tracker_regression_tests; mod ip_tracker_regression_tests;
mod logging;
mod maestro; mod maestro;
mod metrics; mod metrics;
mod network; mod network;
mod protocol; mod protocol;
mod proxy; mod proxy;
mod service;
mod startup; mod startup;
mod stats; mod stats;
mod stream; mod stream;

View File

@ -877,7 +877,8 @@ impl RunningClientHandler {
let first_byte = if self.config.timeouts.client_first_byte_idle_secs == 0 { let first_byte = if self.config.timeouts.client_first_byte_idle_secs == 0 {
None None
} else { } else {
let idle_timeout = Duration::from_secs(self.config.timeouts.client_first_byte_idle_secs); let idle_timeout =
Duration::from_secs(self.config.timeouts.client_first_byte_idle_secs);
let mut first_byte = [0u8; 1]; let mut first_byte = [0u8; 1];
match timeout(idle_timeout, self.stream.read(&mut first_byte)).await { match timeout(idle_timeout, self.stream.read(&mut first_byte)).await {
Ok(Ok(0)) => { Ok(Ok(0)) => {
@ -1365,7 +1366,11 @@ impl RunningClientHandler {
.access .access
.user_max_tcp_conns .user_max_tcp_conns
.get(user) .get(user)
.map(|v| *v as u64); .copied()
.filter(|limit| *limit > 0)
.or((config.access.user_max_tcp_conns_global_each > 0)
.then_some(config.access.user_max_tcp_conns_global_each))
.map(|v| v as u64);
if !stats.try_acquire_user_curr_connects(user, limit) { if !stats.try_acquire_user_curr_connects(user, limit) {
return Err(ProxyError::ConnectionLimitExceeded { return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(), user: user.to_string(),
@ -1424,7 +1429,11 @@ impl RunningClientHandler {
.access .access
.user_max_tcp_conns .user_max_tcp_conns
.get(user) .get(user)
.map(|v| *v as u64); .copied()
.filter(|limit| *limit > 0)
.or((config.access.user_max_tcp_conns_global_each > 0)
.then_some(config.access.user_max_tcp_conns_global_each))
.map(|v| v as u64);
if !stats.try_acquire_user_curr_connects(user, limit) { if !stats.try_acquire_user_curr_connects(user, limit) {
return Err(ProxyError::ConnectionLimitExceeded { return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(), user: user.to_string(),

View File

@ -1740,7 +1740,8 @@ async fn fragmented_tls_mtproto_with_interleaved_ccs_is_accepted() {
.await .await
.unwrap(); .unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
let tls_response_len = u16::from_be_bytes([tls_response_head[3], tls_response_head[4]]) as usize; let tls_response_len =
u16::from_be_bytes([tls_response_head[3], tls_response_head[4]]) as usize;
let mut tls_response_body = vec![0u8; tls_response_len]; let mut tls_response_body = vec![0u8; tls_response_len];
client_side client_side
.read_exact(&mut tls_response_body) .read_exact(&mut tls_response_body)
@ -2533,14 +2534,16 @@ async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() {
} }
#[tokio::test] #[tokio::test]
async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { async fn zero_tcp_limit_uses_global_fallback_and_rejects_without_side_effects() {
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config
.access .access
.user_max_tcp_conns .user_max_tcp_conns
.insert("user".to_string(), 0); .insert("user".to_string(), 0);
config.access.user_max_tcp_conns_global_each = 1;
let stats = Stats::new(); let stats = Stats::new();
stats.increment_user_curr_connects("user");
let ip_tracker = UserIpTracker::new(); let ip_tracker = UserIpTracker::new();
let peer_addr: SocketAddr = "198.51.100.211:50001".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.211:50001".parse().unwrap();
@ -2557,10 +2560,75 @@ async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() {
result, result,
Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user"
)); ));
assert_eq!(
stats.get_user_curr_connects("user"),
1,
"TCP-limit rejection must keep pre-existing in-flight connection count unchanged"
);
assert_eq!(ip_tracker.get_active_ip_count("user").await, 0);
}
#[tokio::test]
async fn zero_tcp_limit_with_disabled_global_fallback_is_unlimited() {
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert("user".to_string(), 0);
config.access.user_max_tcp_conns_global_each = 0;
let stats = Stats::new();
let ip_tracker = UserIpTracker::new();
let peer_addr: SocketAddr = "198.51.100.212:50002".parse().unwrap();
let result = RunningClientHandler::check_user_limits_static(
"user",
&config,
&stats,
peer_addr,
&ip_tracker,
)
.await;
assert!(
result.is_ok(),
"per-user zero with global fallback disabled must not enforce a TCP limit"
);
assert_eq!(stats.get_user_curr_connects("user"), 0); assert_eq!(stats.get_user_curr_connects("user"), 0);
assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); assert_eq!(ip_tracker.get_active_ip_count("user").await, 0);
} }
#[tokio::test]
async fn global_tcp_fallback_applies_when_per_user_limit_is_missing() {
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns_global_each = 1;
let stats = Stats::new();
stats.increment_user_curr_connects("user");
let ip_tracker = UserIpTracker::new();
let peer_addr: SocketAddr = "198.51.100.213:50003".parse().unwrap();
let result = RunningClientHandler::check_user_limits_static(
"user",
&config,
&stats,
peer_addr,
&ip_tracker,
)
.await;
assert!(matches!(
result,
Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user"
));
assert_eq!(
stats.get_user_curr_connects("user"),
1,
"Global fallback TCP-limit rejection must keep pre-existing counter unchanged"
);
assert_eq!(ip_tracker.get_active_ip_count("user").await, 0);
}
#[tokio::test] #[tokio::test]
async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() { async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() {
let user = "check-helper-user"; let user = "check-helper-user";

View File

@ -111,11 +111,17 @@ pub fn generate_service_file(init_system: InitSystem, opts: &ServiceOptions) ->
/// Generates an enhanced systemd unit file. /// Generates an enhanced systemd unit file.
fn generate_systemd_unit(opts: &ServiceOptions) -> String { fn generate_systemd_unit(opts: &ServiceOptions) -> String {
let user_line = opts.user.map(|u| format!("User={}", u)).unwrap_or_default(); let user_line = opts.user.map(|u| format!("User={}", u)).unwrap_or_default();
let group_line = opts.group.map(|g| format!("Group={}", g)).unwrap_or_default(); let group_line = opts
let working_dir = opts.working_dir.map(|d| format!("WorkingDirectory={}", d)).unwrap_or_default(); .group
.map(|g| format!("Group={}", g))
.unwrap_or_default();
let working_dir = opts
.working_dir
.map(|d| format!("WorkingDirectory={}", d))
.unwrap_or_default();
format!( format!(
r#"[Unit] r#"[Unit]
Description={description} Description={description}
Documentation=https://github.com/telemt/telemt Documentation=https://github.com/telemt/telemt
After=network-online.target After=network-online.target
@ -176,7 +182,7 @@ fn generate_openrc_script(opts: &ServiceOptions) -> String {
let group = opts.group.unwrap_or("root"); let group = opts.group.unwrap_or("root");
format!( format!(
r#"#!/sbin/openrc-run r#"#!/sbin/openrc-run
# OpenRC init script for telemt # OpenRC init script for telemt
description="{description}" description="{description}"
@ -218,7 +224,7 @@ fn generate_freebsd_rc_script(opts: &ServiceOptions) -> String {
let group = opts.group.unwrap_or("wheel"); let group = opts.group.unwrap_or("wheel");
format!( format!(
r#"#!/bin/sh r#"#!/bin/sh
# #
# PROVIDE: telemt # PROVIDE: telemt
# REQUIRE: LOGIN NETWORKING # REQUIRE: LOGIN NETWORKING
@ -284,7 +290,7 @@ run_rc_command "$1"
pub fn installation_instructions(init_system: InitSystem) -> &'static str { pub fn installation_instructions(init_system: InitSystem) -> &'static str {
match init_system { match init_system {
InitSystem::Systemd => { InitSystem::Systemd => {
r#"To install and enable the service: r#"To install and enable the service:
sudo systemctl daemon-reload sudo systemctl daemon-reload
sudo systemctl enable telemt sudo systemctl enable telemt
sudo systemctl start telemt sudo systemctl start telemt
@ -300,7 +306,7 @@ To reload configuration:
"# "#
} }
InitSystem::OpenRC => { InitSystem::OpenRC => {
r#"To install and enable the service: r#"To install and enable the service:
sudo chmod +x /etc/init.d/telemt sudo chmod +x /etc/init.d/telemt
sudo rc-update add telemt default sudo rc-update add telemt default
sudo rc-service telemt start sudo rc-service telemt start
@ -313,7 +319,7 @@ To reload configuration:
"# "#
} }
InitSystem::FreeBSDRc => { InitSystem::FreeBSDRc => {
r#"To install and enable the service: r#"To install and enable the service:
sudo chmod +x /usr/local/etc/rc.d/telemt sudo chmod +x /usr/local/etc/rc.d/telemt
sudo sysrc telemt_enable="YES" sudo sysrc telemt_enable="YES"
sudo service telemt start sudo service telemt start
@ -326,7 +332,7 @@ To reload configuration:
"# "#
} }
InitSystem::Unknown => { InitSystem::Unknown => {
r#"No supported init system detected. r#"No supported init system detected.
You may need to create a service file manually or run telemt directly: You may need to create a service file manually or run telemt directly:
telemt start /etc/telemt/config.toml telemt start /etc/telemt/config.toml
"# "#
@ -369,8 +375,14 @@ mod tests {
#[test] #[test]
fn test_service_file_paths() { fn test_service_file_paths() {
assert_eq!(service_file_path(InitSystem::Systemd), "/etc/systemd/system/telemt.service"); assert_eq!(
service_file_path(InitSystem::Systemd),
"/etc/systemd/system/telemt.service"
);
assert_eq!(service_file_path(InitSystem::OpenRC), "/etc/init.d/telemt"); assert_eq!(service_file_path(InitSystem::OpenRC), "/etc/init.d/telemt");
assert_eq!(service_file_path(InitSystem::FreeBSDRc), "/usr/local/etc/rc.d/telemt"); assert_eq!(
service_file_path(InitSystem::FreeBSDRc),
"/usr/local/etc/rc.d/telemt"
);
} }
} }