mirror of
https://github.com/telemt/telemt.git
synced 2026-05-01 17:34:09 +03:00
IP Limit fixes
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
#![allow(dead_code)]
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
@@ -13,7 +13,7 @@ use crate::config::UserMaxUniqueIpsMode;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UserIpTracker {
|
||||
active_ips: Arc<RwLock<HashMap<String, HashSet<IpAddr>>>>,
|
||||
active_ips: Arc<RwLock<HashMap<String, HashMap<IpAddr, usize>>>>,
|
||||
recent_ips: Arc<RwLock<HashMap<String, HashMap<IpAddr, Instant>>>>,
|
||||
max_ips: Arc<RwLock<HashMap<String, usize>>>,
|
||||
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
|
||||
@@ -74,7 +74,7 @@ impl UserIpTracker {
|
||||
let mut active_ips = self.active_ips.write().await;
|
||||
let user_active = active_ips
|
||||
.entry(username.to_string())
|
||||
.or_insert_with(HashSet::new);
|
||||
.or_insert_with(HashMap::new);
|
||||
|
||||
let mut recent_ips = self.recent_ips.write().await;
|
||||
let user_recent = recent_ips
|
||||
@@ -82,7 +82,8 @@ impl UserIpTracker {
|
||||
.or_insert_with(HashMap::new);
|
||||
Self::prune_recent(user_recent, now, window);
|
||||
|
||||
if user_active.contains(&ip) {
|
||||
if let Some(count) = user_active.get_mut(&ip) {
|
||||
*count = count.saturating_add(1);
|
||||
user_recent.insert(ip, now);
|
||||
return Ok(());
|
||||
}
|
||||
@@ -109,7 +110,7 @@ impl UserIpTracker {
|
||||
}
|
||||
}
|
||||
|
||||
user_active.insert(ip);
|
||||
user_active.insert(ip, 1);
|
||||
user_recent.insert(ip, now);
|
||||
Ok(())
|
||||
}
|
||||
@@ -117,7 +118,13 @@ impl UserIpTracker {
|
||||
pub async fn remove_ip(&self, username: &str, ip: IpAddr) {
|
||||
let mut active_ips = self.active_ips.write().await;
|
||||
if let Some(user_ips) = active_ips.get_mut(username) {
|
||||
user_ips.remove(&ip);
|
||||
if let Some(count) = user_ips.get_mut(&ip) {
|
||||
if *count > 1 {
|
||||
*count -= 1;
|
||||
} else {
|
||||
user_ips.remove(&ip);
|
||||
}
|
||||
}
|
||||
if user_ips.is_empty() {
|
||||
active_ips.remove(username);
|
||||
}
|
||||
@@ -144,6 +151,41 @@ impl UserIpTracker {
|
||||
counts
|
||||
}
|
||||
|
||||
pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
|
||||
let active_ips = self.active_ips.read().await;
|
||||
let mut out = HashMap::with_capacity(users.len());
|
||||
for user in users {
|
||||
let mut ips = active_ips
|
||||
.get(user)
|
||||
.map(|per_ip| per_ip.keys().copied().collect::<Vec<_>>())
|
||||
.unwrap_or_else(Vec::new);
|
||||
ips.sort();
|
||||
out.insert(user.clone(), ips);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap<String, Vec<IpAddr>> {
|
||||
let window = *self.limit_window.read().await;
|
||||
let now = Instant::now();
|
||||
let mut recent_ips = self.recent_ips.write().await;
|
||||
|
||||
let mut out = HashMap::with_capacity(users.len());
|
||||
for user in users {
|
||||
let mut ips = if let Some(user_recent) = recent_ips.get_mut(user) {
|
||||
Self::prune_recent(user_recent, now, window);
|
||||
user_recent.keys().copied().collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
ips.sort();
|
||||
out.insert(user.clone(), ips);
|
||||
}
|
||||
|
||||
recent_ips.retain(|_, user_recent| !user_recent.is_empty());
|
||||
out
|
||||
}
|
||||
|
||||
pub async fn get_active_ip_count(&self, username: &str) -> usize {
|
||||
let active_ips = self.active_ips.read().await;
|
||||
active_ips.get(username).map(|ips| ips.len()).unwrap_or(0)
|
||||
@@ -153,7 +195,7 @@ impl UserIpTracker {
|
||||
let active_ips = self.active_ips.read().await;
|
||||
active_ips
|
||||
.get(username)
|
||||
.map(|ips| ips.iter().copied().collect())
|
||||
.map(|ips| ips.keys().copied().collect())
|
||||
.unwrap_or_else(Vec::new)
|
||||
}
|
||||
|
||||
@@ -193,7 +235,7 @@ impl UserIpTracker {
|
||||
let active_ips = self.active_ips.read().await;
|
||||
active_ips
|
||||
.get(username)
|
||||
.map(|ips| ips.contains(&ip))
|
||||
.map(|ips| ips.contains_key(&ip))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
@@ -269,6 +311,26 @@ mod tests {
|
||||
assert_eq!(tracker.get_active_ip_count("test_user").await, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_active_window_rejects_new_ip_and_keeps_existing_session() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("test_user", 1).await;
|
||||
tracker
|
||||
.set_limit_policy(UserMaxUniqueIpsMode::ActiveWindow, 30)
|
||||
.await;
|
||||
|
||||
let ip1 = test_ipv4(10, 10, 10, 1);
|
||||
let ip2 = test_ipv4(10, 10, 10, 2);
|
||||
|
||||
assert!(tracker.check_and_add("test_user", ip1).await.is_ok());
|
||||
assert!(tracker.is_ip_active("test_user", ip1).await);
|
||||
assert!(tracker.check_and_add("test_user", ip2).await.is_err());
|
||||
|
||||
// Existing session remains active; only new unique IP is denied.
|
||||
assert!(tracker.is_ip_active("test_user", ip1).await);
|
||||
assert_eq!(tracker.get_active_ip_count("test_user").await, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reconnection_from_same_ip() {
|
||||
let tracker = UserIpTracker::new();
|
||||
@@ -281,6 +343,24 @@ mod tests {
|
||||
assert_eq!(tracker.get_active_ip_count("test_user").await, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_same_ip_disconnect_keeps_active_while_other_session_alive() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("test_user", 2).await;
|
||||
|
||||
let ip1 = test_ipv4(192, 168, 1, 1);
|
||||
|
||||
assert!(tracker.check_and_add("test_user", ip1).await.is_ok());
|
||||
assert!(tracker.check_and_add("test_user", ip1).await.is_ok());
|
||||
assert_eq!(tracker.get_active_ip_count("test_user").await, 1);
|
||||
|
||||
tracker.remove_ip("test_user", ip1).await;
|
||||
assert_eq!(tracker.get_active_ip_count("test_user").await, 1);
|
||||
|
||||
tracker.remove_ip("test_user", ip1).await;
|
||||
assert_eq!(tracker.get_active_ip_count("test_user").await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ip_removal() {
|
||||
let tracker = UserIpTracker::new();
|
||||
|
||||
Reference in New Issue
Block a user