Time-To-Life for TLS Full Certificate

This commit is contained in:
Alexey
2026-02-23 05:47:44 +03:00
parent cfe8fc72a5
commit b5d0564f2a
5 changed files with 107 additions and 8 deletions

View File

@@ -1,7 +1,8 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::net::IpAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{SystemTime, Duration};
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::RwLock;
use tokio::time::sleep;
@@ -14,7 +15,7 @@ use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsFetchResult};
pub struct TlsFrontCache {
memory: RwLock<HashMap<String, Arc<CachedTlsData>>>,
default: Arc<CachedTlsData>,
full_cert_sent: RwLock<HashSet<String>>,
full_cert_sent: RwLock<HashMap<(String, IpAddr), Instant>>,
disk_path: PathBuf,
}
@@ -47,7 +48,7 @@ impl TlsFrontCache {
Self {
memory: RwLock::new(map),
default,
full_cert_sent: RwLock::new(HashSet::new()),
full_cert_sent: RwLock::new(HashMap::new()),
disk_path: disk_path.as_ref().to_path_buf(),
}
}
@@ -61,9 +62,41 @@ impl TlsFrontCache {
self.memory.read().await.contains_key(domain)
}
/// Returns true only on first request for a domain after process start.
pub async fn take_full_cert_budget(&self, domain: &str) -> bool {
self.full_cert_sent.write().await.insert(domain.to_string())
/// Returns true when full cert payload should be sent for (domain, client_ip)
/// according to TTL policy.
pub async fn take_full_cert_budget_for_ip(
&self,
domain: &str,
client_ip: IpAddr,
ttl: Duration,
) -> bool {
if ttl.is_zero() {
self.full_cert_sent
.write()
.await
.insert((domain.to_string(), client_ip), Instant::now());
return true;
}
let now = Instant::now();
let mut guard = self.full_cert_sent.write().await;
guard.retain(|_, seen_at| now.duration_since(*seen_at) < ttl);
let key = (domain.to_string(), client_ip);
match guard.get_mut(&key) {
Some(seen_at) => {
if now.duration_since(*seen_at) >= ttl {
*seen_at = now;
true
} else {
false
}
}
None => {
guard.insert(key, now);
true
}
}
}
pub async fn set(&self, domain: &str, data: CachedTlsData) {
@@ -174,3 +207,50 @@ impl TlsFrontCache {
&self.disk_path
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_take_full_cert_budget_for_ip_uses_ttl() {
let cache = TlsFrontCache::new(
&["example.com".to_string()],
1024,
"tlsfront-test-cache",
);
let ip: IpAddr = "127.0.0.1".parse().expect("ip");
let ttl = Duration::from_millis(80);
assert!(cache
.take_full_cert_budget_for_ip("example.com", ip, ttl)
.await);
assert!(!cache
.take_full_cert_budget_for_ip("example.com", ip, ttl)
.await);
tokio::time::sleep(Duration::from_millis(90)).await;
assert!(cache
.take_full_cert_budget_for_ip("example.com", ip, ttl)
.await);
}
#[tokio::test]
async fn test_take_full_cert_budget_for_ip_zero_ttl_always_allows_full_payload() {
let cache = TlsFrontCache::new(
&["example.com".to_string()],
1024,
"tlsfront-test-cache",
);
let ip: IpAddr = "127.0.0.1".parse().expect("ip");
let ttl = Duration::ZERO;
assert!(cache
.take_full_cert_budget_for_ip("example.com", ip, ttl)
.await);
assert!(cache
.take_full_cert_budget_for_ip("example.com", ip, ttl)
.await);
}
}