mirror of https://github.com/telemt/telemt.git
commit
5ee4556cea
|
|
@ -24,6 +24,7 @@ zeroize = { version = "1.8", features = ["derive"] }
|
|||
|
||||
# Network
|
||||
socket2 = { version = "0.5", features = ["all"] }
|
||||
nix = { version = "0.28", default-features = false, features = ["net"] }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
|
|
@ -47,6 +48,7 @@ regex = "1.11"
|
|||
crossbeam-queue = "0.3"
|
||||
num-bigint = "0.4"
|
||||
num-traits = "0.2"
|
||||
anyhow = "1.0"
|
||||
|
||||
# HTTP
|
||||
reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false }
|
||||
|
|
@ -54,6 +56,9 @@ hyper = { version = "1", features = ["server", "http1"] }
|
|||
hyper-util = { version = "0.1", features = ["tokio", "server-auto"] }
|
||||
http-body-util = "0.1"
|
||||
httpdate = "1.0"
|
||||
tokio-rustls = { version = "0.26", default-features = false, features = ["tls12"] }
|
||||
rustls = { version = "0.23", default-features = false, features = ["std", "tls12", "ring"] }
|
||||
webpki-roots = "0.26"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4"
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ pub(crate) fn default_fake_cert_len() -> usize {
|
|||
2048
|
||||
}
|
||||
|
||||
pub(crate) fn default_tls_front_dir() -> String {
|
||||
"tlsfront".to_string()
|
||||
}
|
||||
|
||||
pub(crate) fn default_replay_check_len() -> usize {
|
||||
65_536
|
||||
}
|
||||
|
|
|
|||
|
|
@ -163,6 +163,21 @@ impl ProxyConfig {
|
|||
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
|
||||
}
|
||||
|
||||
// Merge primary + extra TLS domains, deduplicate (primary always first).
|
||||
if !config.censorship.tls_domains.is_empty() {
|
||||
let mut all = Vec::with_capacity(1 + config.censorship.tls_domains.len());
|
||||
all.push(config.censorship.tls_domain.clone());
|
||||
for d in std::mem::take(&mut config.censorship.tls_domains) {
|
||||
if !d.is_empty() && !all.contains(&d) {
|
||||
all.push(d);
|
||||
}
|
||||
}
|
||||
// keep primary as tls_domain; store remaining back to tls_domains
|
||||
if all.len() > 1 {
|
||||
config.censorship.tls_domains = all[1..].to_vec();
|
||||
}
|
||||
}
|
||||
|
||||
// Migration: prefer_ipv6 -> network.prefer.
|
||||
if config.general.prefer_ipv6 {
|
||||
if config.network.prefer == 4 {
|
||||
|
|
@ -180,7 +195,7 @@ impl ProxyConfig {
|
|||
validate_network_cfg(&mut config.network)?;
|
||||
|
||||
// Random fake_cert_len only when default is in use.
|
||||
if config.censorship.fake_cert_len == default_fake_cert_len() {
|
||||
if !config.censorship.tls_emulation && config.censorship.fake_cert_len == default_fake_cert_len() {
|
||||
config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
|
||||
}
|
||||
|
||||
|
|
@ -235,7 +250,7 @@ impl ProxyConfig {
|
|||
// Migration: Populate upstreams if empty (Default Direct).
|
||||
if config.upstreams.is_empty() {
|
||||
config.upstreams.push(UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct { interface: None },
|
||||
upstream_type: UpstreamType::Direct { interface: None, bind_addresses: None },
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
|
|
|
|||
|
|
@ -295,6 +295,11 @@ pub struct ServerConfig {
|
|||
#[serde(default)]
|
||||
pub listen_tcp: Option<bool>,
|
||||
|
||||
/// Accept HAProxy PROXY protocol headers on incoming connections.
|
||||
/// When enabled, real client IPs are extracted from PROXY v1/v2 headers.
|
||||
#[serde(default)]
|
||||
pub proxy_protocol: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub metrics_port: Option<u16>,
|
||||
|
||||
|
|
@ -314,6 +319,7 @@ impl Default for ServerConfig {
|
|||
listen_unix_sock: None,
|
||||
listen_unix_sock_perm: None,
|
||||
listen_tcp: None,
|
||||
proxy_protocol: false,
|
||||
metrics_port: None,
|
||||
metrics_whitelist: default_metrics_whitelist(),
|
||||
listeners: Vec::new(),
|
||||
|
|
@ -362,6 +368,10 @@ pub struct AntiCensorshipConfig {
|
|||
#[serde(default = "default_tls_domain")]
|
||||
pub tls_domain: String,
|
||||
|
||||
/// Additional TLS domains for generating multiple proxy links.
|
||||
#[serde(default)]
|
||||
pub tls_domains: Vec<String>,
|
||||
|
||||
#[serde(default = "default_true")]
|
||||
pub mask: bool,
|
||||
|
||||
|
|
@ -376,17 +386,28 @@ pub struct AntiCensorshipConfig {
|
|||
|
||||
#[serde(default = "default_fake_cert_len")]
|
||||
pub fake_cert_len: usize,
|
||||
|
||||
/// Enable TLS certificate emulation using cached real certificates.
|
||||
#[serde(default)]
|
||||
pub tls_emulation: bool,
|
||||
|
||||
/// Directory to store TLS front cache (on disk).
|
||||
#[serde(default = "default_tls_front_dir")]
|
||||
pub tls_front_dir: String,
|
||||
}
|
||||
|
||||
impl Default for AntiCensorshipConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tls_domain: default_tls_domain(),
|
||||
tls_domains: Vec::new(),
|
||||
mask: true,
|
||||
mask_host: None,
|
||||
mask_port: default_mask_port(),
|
||||
mask_unix_sock: None,
|
||||
fake_cert_len: default_fake_cert_len(),
|
||||
tls_emulation: false,
|
||||
tls_front_dir: default_tls_front_dir(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -446,6 +467,8 @@ pub enum UpstreamType {
|
|||
Direct {
|
||||
#[serde(default)]
|
||||
interface: Option<String>,
|
||||
#[serde(default)]
|
||||
bind_addresses: Option<Vec<String>>,
|
||||
},
|
||||
Socks4 {
|
||||
address: String,
|
||||
|
|
|
|||
71
src/main.rs
71
src/main.rs
|
|
@ -23,6 +23,7 @@ mod proxy;
|
|||
mod stats;
|
||||
mod stream;
|
||||
mod transport;
|
||||
mod tls_front;
|
||||
mod util;
|
||||
|
||||
use crate::config::{LogLevel, ProxyConfig};
|
||||
|
|
@ -36,6 +37,7 @@ use crate::transport::middle_proxy::{
|
|||
MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, format_sample_line,
|
||||
};
|
||||
use crate::transport::{ListenOptions, UpstreamManager, create_listener};
|
||||
use crate::tls_front::TlsFrontCache;
|
||||
|
||||
fn parse_cli() -> (String, bool, Option<String>) {
|
||||
let mut config_path = "config.toml".to_string();
|
||||
|
|
@ -129,12 +131,22 @@ fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
|
|||
);
|
||||
}
|
||||
if config.general.modes.tls {
|
||||
let domain_hex = hex::encode(&config.censorship.tls_domain);
|
||||
info!(
|
||||
target: "telemt::links",
|
||||
" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
|
||||
host, port, secret, domain_hex
|
||||
);
|
||||
let mut domains = Vec::with_capacity(1 + config.censorship.tls_domains.len());
|
||||
domains.push(config.censorship.tls_domain.clone());
|
||||
for d in &config.censorship.tls_domains {
|
||||
if !domains.contains(d) {
|
||||
domains.push(d.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for domain in domains {
|
||||
let domain_hex = hex::encode(&domain);
|
||||
info!(
|
||||
target: "telemt::links",
|
||||
" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
|
||||
host, port, secret, domain_hex
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!(target: "telemt::links", "User '{}' in show_link not found", user_name);
|
||||
|
|
@ -247,6 +259,46 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|||
info!("IP limits configured for {} users", config.access.user_max_unique_ips.len());
|
||||
}
|
||||
|
||||
// TLS front cache (optional emulation)
|
||||
let mut tls_domains = Vec::with_capacity(1 + config.censorship.tls_domains.len());
|
||||
tls_domains.push(config.censorship.tls_domain.clone());
|
||||
for d in &config.censorship.tls_domains {
|
||||
if !tls_domains.contains(d) {
|
||||
tls_domains.push(d.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let tls_cache: Option<Arc<TlsFrontCache>> = if config.censorship.tls_emulation {
|
||||
let cache = Arc::new(TlsFrontCache::new(
|
||||
&tls_domains,
|
||||
config.censorship.fake_cert_len,
|
||||
&config.censorship.tls_front_dir,
|
||||
));
|
||||
|
||||
let cache_clone = cache.clone();
|
||||
let domains = tls_domains.clone();
|
||||
let port = config.censorship.mask_port;
|
||||
tokio::spawn(async move {
|
||||
for domain in domains {
|
||||
match crate::tls_front::fetcher::fetch_real_tls(
|
||||
&domain,
|
||||
port,
|
||||
&domain,
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(res) => cache_clone.update_from_fetch(&domain, res).await,
|
||||
Err(e) => warn!(domain = %domain, error = %e, "TLS emulation fetch failed"),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Some(cache)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Connection concurrency limit
|
||||
let _max_connections = Arc::new(Semaphore::new(10_000));
|
||||
|
||||
|
|
@ -715,6 +767,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
|
|||
let buffer_pool = buffer_pool.clone();
|
||||
let rng = rng.clone();
|
||||
let me_pool = me_pool.clone();
|
||||
let tls_cache = tls_cache.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
|
|
@ -733,13 +786,14 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
|
|||
let buffer_pool = buffer_pool.clone();
|
||||
let rng = rng.clone();
|
||||
let me_pool = me_pool.clone();
|
||||
let tls_cache = tls_cache.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = crate::proxy::client::handle_client_stream(
|
||||
stream, fake_peer, config, stats,
|
||||
upstream_manager, replay_checker, buffer_pool, rng,
|
||||
me_pool, ip_tracker,
|
||||
me_pool, tls_cache, ip_tracker,
|
||||
).await {
|
||||
debug!(error = %e, "Unix socket connection error");
|
||||
}
|
||||
|
|
@ -787,6 +841,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
|
|||
let buffer_pool = buffer_pool.clone();
|
||||
let rng = rng.clone();
|
||||
let me_pool = me_pool.clone();
|
||||
let tls_cache = tls_cache.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
|
|
@ -800,6 +855,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
|
|||
let buffer_pool = buffer_pool.clone();
|
||||
let rng = rng.clone();
|
||||
let me_pool = me_pool.clone();
|
||||
let tls_cache = tls_cache.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
|
|
@ -813,6 +869,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
|
|||
buffer_pool,
|
||||
rng,
|
||||
me_pool,
|
||||
tls_cache,
|
||||
ip_tracker,
|
||||
)
|
||||
.run()
|
||||
|
|
|
|||
|
|
@ -397,6 +397,84 @@ pub fn build_server_hello(
|
|||
response
|
||||
}
|
||||
|
||||
/// Extract SNI (server_name) from a TLS ClientHello.
|
||||
pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
|
||||
if handshake.len() < 43 || handshake[0] != TLS_RECORD_HANDSHAKE {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut pos = 5; // after record header
|
||||
if handshake.get(pos).copied()? != 0x01 {
|
||||
return None; // not ClientHello
|
||||
}
|
||||
|
||||
// Handshake length bytes
|
||||
pos += 4; // type + len (3)
|
||||
|
||||
// version (2) + random (32)
|
||||
pos += 2 + 32;
|
||||
if pos + 1 > handshake.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let session_id_len = *handshake.get(pos)? as usize;
|
||||
pos += 1 + session_id_len;
|
||||
if pos + 2 > handshake.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let cipher_suites_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
|
||||
pos += 2 + cipher_suites_len;
|
||||
if pos + 1 > handshake.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let comp_len = *handshake.get(pos)? as usize;
|
||||
pos += 1 + comp_len;
|
||||
if pos + 2 > handshake.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
|
||||
pos += 2;
|
||||
let ext_end = pos + ext_len;
|
||||
if ext_end > handshake.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
while pos + 4 <= ext_end {
|
||||
let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]);
|
||||
let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize;
|
||||
pos += 4;
|
||||
if pos + elen > ext_end {
|
||||
break;
|
||||
}
|
||||
if etype == 0x0000 && elen >= 5 {
|
||||
// server_name extension
|
||||
let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
|
||||
let mut sn_pos = pos + 2;
|
||||
let sn_end = std::cmp::min(sn_pos + list_len, pos + elen);
|
||||
while sn_pos + 3 <= sn_end {
|
||||
let name_type = handshake[sn_pos];
|
||||
let name_len = u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize;
|
||||
sn_pos += 3;
|
||||
if sn_pos + name_len > sn_end {
|
||||
break;
|
||||
}
|
||||
if name_type == 0 && name_len > 0 {
|
||||
if let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) {
|
||||
return Some(host.to_string());
|
||||
}
|
||||
}
|
||||
sn_pos += name_len;
|
||||
}
|
||||
}
|
||||
pos += elen;
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if bytes look like a TLS ClientHello
|
||||
pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
|
||||
if first_bytes.len() < 3 {
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ use crate::protocol::tls;
|
|||
use crate::stats::{ReplayChecker, Stats};
|
||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||
use crate::transport::middle_proxy::MePool;
|
||||
use crate::transport::{UpstreamManager, configure_client_socket};
|
||||
use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol};
|
||||
use crate::tls_front::TlsFrontCache;
|
||||
|
||||
use crate::proxy::direct_relay::handle_via_direct;
|
||||
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
|
||||
|
|
@ -47,13 +48,35 @@ pub async fn handle_client_stream<S>(
|
|||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
) -> Result<()>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
stats.increment_connects_all();
|
||||
debug!(peer = %peer, "New connection (generic stream)");
|
||||
let mut real_peer = peer;
|
||||
|
||||
if config.server.proxy_protocol {
|
||||
match parse_proxy_protocol(&mut stream, peer).await {
|
||||
Ok(info) => {
|
||||
debug!(
|
||||
peer = %peer,
|
||||
client = %info.src_addr,
|
||||
version = info.version,
|
||||
"PROXY protocol header parsed"
|
||||
);
|
||||
real_peer = info.src_addr;
|
||||
}
|
||||
Err(e) => {
|
||||
stats.increment_connects_bad();
|
||||
warn!(peer = %peer, error = %e, "Invalid PROXY protocol header");
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!(peer = %real_peer, "New connection (generic stream)");
|
||||
|
||||
let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake);
|
||||
let stats_for_timeout = stats.clone();
|
||||
|
|
@ -69,13 +92,13 @@ where
|
|||
stream.read_exact(&mut first_bytes).await?;
|
||||
|
||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
||||
debug!(peer = %real_peer, is_tls = is_tls, "Handshake type detected");
|
||||
|
||||
if is_tls {
|
||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||
|
||||
if tls_len < 512 {
|
||||
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
||||
debug!(peer = %real_peer, tls_len = tls_len, "TLS handshake too short");
|
||||
stats.increment_connects_bad();
|
||||
let (reader, writer) = tokio::io::split(stream);
|
||||
handle_bad_client(reader, writer, &first_bytes, &config).await;
|
||||
|
|
@ -89,8 +112,8 @@ where
|
|||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
|
||||
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
||||
&handshake, read_half, write_half, peer,
|
||||
&config, &replay_checker, &rng,
|
||||
&handshake, read_half, write_half, real_peer,
|
||||
&config, &replay_checker, &rng, tls_cache.clone(),
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient { reader, writer } => {
|
||||
|
|
@ -107,7 +130,7 @@ where
|
|||
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
||||
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||
&mtproto_handshake, tls_reader, tls_writer, peer,
|
||||
&mtproto_handshake, tls_reader, tls_writer, real_peer,
|
||||
&config, &replay_checker, true,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
|
|
@ -123,12 +146,12 @@ where
|
|||
RunningClientHandler::handle_authenticated_static(
|
||||
crypto_reader, crypto_writer, success,
|
||||
upstream_manager, stats, config, buffer_pool, rng, me_pool,
|
||||
local_addr, peer, ip_tracker.clone(),
|
||||
local_addr, real_peer, ip_tracker.clone(),
|
||||
),
|
||||
)))
|
||||
} else {
|
||||
if !config.general.modes.classic && !config.general.modes.secure {
|
||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
||||
debug!(peer = %real_peer, "Non-TLS modes disabled");
|
||||
stats.increment_connects_bad();
|
||||
let (reader, writer) = tokio::io::split(stream);
|
||||
handle_bad_client(reader, writer, &first_bytes, &config).await;
|
||||
|
|
@ -142,7 +165,7 @@ where
|
|||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||
&handshake, read_half, write_half, peer,
|
||||
&handshake, read_half, write_half, real_peer,
|
||||
&config, &replay_checker, false,
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
|
|
@ -166,7 +189,7 @@ where
|
|||
rng,
|
||||
me_pool,
|
||||
local_addr,
|
||||
peer,
|
||||
real_peer,
|
||||
ip_tracker.clone(),
|
||||
)
|
||||
)))
|
||||
|
|
@ -203,6 +226,7 @@ pub struct RunningClientHandler {
|
|||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
}
|
||||
|
||||
|
|
@ -217,6 +241,7 @@ impl ClientHandler {
|
|||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
) -> RunningClientHandler {
|
||||
RunningClientHandler {
|
||||
|
|
@ -229,6 +254,7 @@ impl ClientHandler {
|
|||
buffer_pool,
|
||||
rng,
|
||||
me_pool,
|
||||
tls_cache,
|
||||
ip_tracker,
|
||||
}
|
||||
}
|
||||
|
|
@ -275,6 +301,25 @@ impl RunningClientHandler {
|
|||
}
|
||||
|
||||
async fn do_handshake(mut self) -> Result<HandshakeOutcome> {
|
||||
if self.config.server.proxy_protocol {
|
||||
match parse_proxy_protocol(&mut self.stream, self.peer).await {
|
||||
Ok(info) => {
|
||||
debug!(
|
||||
peer = %self.peer,
|
||||
client = %info.src_addr,
|
||||
version = info.version,
|
||||
"PROXY protocol header parsed"
|
||||
);
|
||||
self.peer = info.src_addr;
|
||||
}
|
||||
Err(e) => {
|
||||
self.stats.increment_connects_bad();
|
||||
warn!(peer = %self.peer, error = %e, "Invalid PROXY protocol header");
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut first_bytes = [0u8; 5];
|
||||
self.stream.read_exact(&mut first_bytes).await?;
|
||||
|
||||
|
|
@ -327,6 +372,7 @@ impl RunningClientHandler {
|
|||
&config,
|
||||
&replay_checker,
|
||||
&self.rng,
|
||||
self.tls_cache.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
//! MTProto Handshake
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tracing::{debug, warn, trace, info};
|
||||
use zeroize::Zeroize;
|
||||
|
|
@ -12,6 +13,7 @@ use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
|
|||
use crate::error::{ProxyError, HandshakeResult};
|
||||
use crate::stats::ReplayChecker;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::tls_front::{TlsFrontCache, emulator};
|
||||
|
||||
/// Result of successful handshake
|
||||
///
|
||||
|
|
@ -55,6 +57,7 @@ pub async fn handle_tls_handshake<R, W>(
|
|||
config: &ProxyConfig,
|
||||
replay_checker: &ReplayChecker,
|
||||
rng: &SecureRandom,
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
|
|
@ -102,13 +105,37 @@ where
|
|||
None => return HandshakeResult::BadClient { reader, writer },
|
||||
};
|
||||
|
||||
let response = tls::build_server_hello(
|
||||
secret,
|
||||
&validation.digest,
|
||||
&validation.session_id,
|
||||
config.censorship.fake_cert_len,
|
||||
rng,
|
||||
);
|
||||
let cached = if config.censorship.tls_emulation {
|
||||
if let Some(cache) = tls_cache.as_ref() {
|
||||
if let Some(sni) = tls::extract_sni_from_client_hello(handshake) {
|
||||
Some(cache.get(&sni).await)
|
||||
} else {
|
||||
Some(cache.get(&config.censorship.tls_domain).await)
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let response = if let Some(cached_entry) = cached {
|
||||
emulator::build_emulated_server_hello(
|
||||
secret,
|
||||
&validation.digest,
|
||||
&validation.session_id,
|
||||
&cached_entry,
|
||||
rng,
|
||||
)
|
||||
} else {
|
||||
tls::build_server_hello(
|
||||
secret,
|
||||
&validation.digest,
|
||||
&validation.session_id,
|
||||
config.censorship.fake_cert_len,
|
||||
rng,
|
||||
)
|
||||
};
|
||||
|
||||
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,103 @@
|
|||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, Duration};
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::sleep;
|
||||
use tracing::{debug, warn, info};
|
||||
|
||||
use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsFetchResult};
|
||||
|
||||
/// Lightweight in-memory + optional on-disk cache for TLS fronting data.
|
||||
#[derive(Debug)]
|
||||
pub struct TlsFrontCache {
|
||||
memory: RwLock<HashMap<String, Arc<CachedTlsData>>>,
|
||||
default: Arc<CachedTlsData>,
|
||||
disk_path: PathBuf,
|
||||
}
|
||||
|
||||
impl TlsFrontCache {
|
||||
pub fn new(domains: &[String], default_len: usize, disk_path: impl AsRef<Path>) -> Self {
|
||||
let default_template = ParsedServerHello {
|
||||
version: [0x03, 0x03],
|
||||
random: [0u8; 32],
|
||||
session_id: Vec::new(),
|
||||
cipher_suite: [0x13, 0x01],
|
||||
compression: 0,
|
||||
extensions: Vec::new(),
|
||||
};
|
||||
|
||||
let default = Arc::new(CachedTlsData {
|
||||
server_hello_template: default_template,
|
||||
cert_info: None,
|
||||
app_data_records_sizes: vec![default_len],
|
||||
total_app_data_len: default_len,
|
||||
fetched_at: SystemTime::now(),
|
||||
domain: "default".to_string(),
|
||||
});
|
||||
|
||||
let mut map = HashMap::new();
|
||||
for d in domains {
|
||||
map.insert(d.clone(), default.clone());
|
||||
}
|
||||
|
||||
Self {
|
||||
memory: RwLock::new(map),
|
||||
default,
|
||||
disk_path: disk_path.as_ref().to_path_buf(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get(&self, sni: &str) -> Arc<CachedTlsData> {
|
||||
let guard = self.memory.read().await;
|
||||
guard.get(sni).cloned().unwrap_or_else(|| self.default.clone())
|
||||
}
|
||||
|
||||
pub async fn set(&self, domain: &str, data: CachedTlsData) {
|
||||
let mut guard = self.memory.write().await;
|
||||
guard.insert(domain.to_string(), Arc::new(data));
|
||||
}
|
||||
|
||||
/// Spawn background updater that periodically refreshes cached domains using provided fetcher.
|
||||
pub fn spawn_updater<F>(
|
||||
self: Arc<Self>,
|
||||
domains: Vec<String>,
|
||||
interval: Duration,
|
||||
fetcher: F,
|
||||
) where
|
||||
F: Fn(String) -> tokio::task::JoinHandle<()> + Send + Sync + 'static,
|
||||
{
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
for domain in &domains {
|
||||
fetcher(domain.clone()).await;
|
||||
}
|
||||
sleep(interval).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Replace cached entry from a fetch result.
|
||||
pub async fn update_from_fetch(&self, domain: &str, fetched: TlsFetchResult) {
|
||||
let data = CachedTlsData {
|
||||
server_hello_template: fetched.server_hello_parsed,
|
||||
cert_info: None,
|
||||
app_data_records_sizes: fetched.app_data_records_sizes.clone(),
|
||||
total_app_data_len: fetched.total_app_data_len,
|
||||
fetched_at: SystemTime::now(),
|
||||
domain: domain.to_string(),
|
||||
};
|
||||
|
||||
self.set(domain, data).await;
|
||||
debug!(domain = %domain, len = fetched.total_app_data_len, "TLS cache updated");
|
||||
}
|
||||
|
||||
pub fn default_entry(&self) -> Arc<CachedTlsData> {
|
||||
self.default.clone()
|
||||
}
|
||||
|
||||
pub fn disk_path(&self) -> &Path {
|
||||
&self.disk_path
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
use crate::crypto::{sha256_hmac, SecureRandom};
|
||||
use crate::protocol::constants::{
|
||||
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION,
|
||||
};
|
||||
use crate::protocol::tls::{TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key};
|
||||
use crate::tls_front::types::CachedTlsData;
|
||||
|
||||
/// Build a ServerHello + CCS + ApplicationData sequence using cached TLS metadata.
|
||||
pub fn build_emulated_server_hello(
|
||||
secret: &[u8],
|
||||
client_digest: &[u8; TLS_DIGEST_LEN],
|
||||
session_id: &[u8],
|
||||
cached: &CachedTlsData,
|
||||
rng: &SecureRandom,
|
||||
) -> Vec<u8> {
|
||||
// --- ServerHello ---
|
||||
let mut extensions = Vec::new();
|
||||
// KeyShare (x25519)
|
||||
let key = gen_fake_x25519_key(rng);
|
||||
extensions.extend_from_slice(&0x0033u16.to_be_bytes()); // key_share
|
||||
extensions.extend_from_slice(&(2 + 2 + 32u16).to_be_bytes()); // len
|
||||
extensions.extend_from_slice(&0x001du16.to_be_bytes()); // X25519
|
||||
extensions.extend_from_slice(&(32u16).to_be_bytes());
|
||||
extensions.extend_from_slice(&key);
|
||||
// supported_versions (TLS1.3)
|
||||
extensions.extend_from_slice(&0x002bu16.to_be_bytes());
|
||||
extensions.extend_from_slice(&(2u16).to_be_bytes());
|
||||
extensions.extend_from_slice(&0x0304u16.to_be_bytes());
|
||||
|
||||
let extensions_len = extensions.len() as u16;
|
||||
|
||||
let body_len = 2 + // version
|
||||
32 + // random
|
||||
1 + session_id.len() + // session id
|
||||
2 + // cipher
|
||||
1 + // compression
|
||||
2 + extensions.len(); // extensions
|
||||
|
||||
let mut message = Vec::with_capacity(4 + body_len);
|
||||
message.push(0x02); // ServerHello
|
||||
let len_bytes = (body_len as u32).to_be_bytes();
|
||||
message.extend_from_slice(&len_bytes[1..4]);
|
||||
message.extend_from_slice(&cached.server_hello_template.version); // 0x0303
|
||||
message.extend_from_slice(&[0u8; 32]); // random placeholder
|
||||
message.push(session_id.len() as u8);
|
||||
message.extend_from_slice(session_id);
|
||||
message.extend_from_slice(&cached.server_hello_template.cipher_suite);
|
||||
message.push(cached.server_hello_template.compression);
|
||||
message.extend_from_slice(&extensions_len.to_be_bytes());
|
||||
message.extend_from_slice(&extensions);
|
||||
|
||||
let mut server_hello = Vec::with_capacity(5 + message.len());
|
||||
server_hello.push(TLS_RECORD_HANDSHAKE);
|
||||
server_hello.extend_from_slice(&TLS_VERSION);
|
||||
server_hello.extend_from_slice(&(message.len() as u16).to_be_bytes());
|
||||
server_hello.extend_from_slice(&message);
|
||||
|
||||
// --- ChangeCipherSpec ---
|
||||
let change_cipher_spec = [
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
TLS_VERSION[0],
|
||||
TLS_VERSION[1],
|
||||
0x00,
|
||||
0x01,
|
||||
0x01,
|
||||
];
|
||||
|
||||
// --- ApplicationData (fake encrypted records) ---
|
||||
let sizes = if cached.app_data_records_sizes.is_empty() {
|
||||
vec![cached.total_app_data_len.max(1024)]
|
||||
} else {
|
||||
cached.app_data_records_sizes.clone()
|
||||
};
|
||||
|
||||
let mut app_data = Vec::new();
|
||||
for size in sizes {
|
||||
let mut rec = Vec::with_capacity(5 + size);
|
||||
rec.push(TLS_RECORD_APPLICATION);
|
||||
rec.extend_from_slice(&TLS_VERSION);
|
||||
rec.extend_from_slice(&(size as u16).to_be_bytes());
|
||||
rec.extend_from_slice(&rng.bytes(size));
|
||||
app_data.extend_from_slice(&rec);
|
||||
}
|
||||
|
||||
// --- Combine ---
|
||||
let mut response = Vec::with_capacity(server_hello.len() + change_cipher_spec.len() + app_data.len());
|
||||
response.extend_from_slice(&server_hello);
|
||||
response.extend_from_slice(&change_cipher_spec);
|
||||
response.extend_from_slice(&app_data);
|
||||
|
||||
// --- HMAC ---
|
||||
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len());
|
||||
hmac_input.extend_from_slice(client_digest);
|
||||
hmac_input.extend_from_slice(&response);
|
||||
let digest = sha256_hmac(secret, &hmac_input);
|
||||
response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
|
||||
|
||||
response
|
||||
}
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::timeout;
|
||||
use tokio_rustls::client::TlsStream;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tracing::debug;
|
||||
|
||||
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::client::ClientConfig;
|
||||
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
|
||||
use rustls::{DigitallySignedStruct, Error as RustlsError};
|
||||
|
||||
use crate::tls_front::types::{ParsedServerHello, TlsFetchResult};
|
||||
|
||||
/// No-op verifier: accept any certificate (we only need lengths and metadata).
|
||||
#[derive(Debug)]
|
||||
struct NoVerify;
|
||||
|
||||
impl ServerCertVerifier for NoVerify {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &ServerName<'_>,
|
||||
_ocsp: &[u8],
|
||||
_now: UnixTime,
|
||||
) -> Result<ServerCertVerified, RustlsError> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &DigitallySignedStruct,
|
||||
) -> Result<HandshakeSignatureValid, RustlsError> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &DigitallySignedStruct,
|
||||
) -> Result<HandshakeSignatureValid, RustlsError> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
use rustls::SignatureScheme::*;
|
||||
vec![
|
||||
RSA_PKCS1_SHA256,
|
||||
RSA_PSS_SHA256,
|
||||
ECDSA_NISTP256_SHA256,
|
||||
ECDSA_NISTP384_SHA384,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
fn build_client_config() -> Arc<ClientConfig> {
|
||||
let root = rustls::RootCertStore::empty();
|
||||
|
||||
let provider = rustls::crypto::ring::default_provider();
|
||||
let mut config = ClientConfig::builder_with_provider(Arc::new(provider))
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
||||
.expect("protocol versions")
|
||||
.with_root_certificates(root)
|
||||
.with_no_client_auth();
|
||||
|
||||
config
|
||||
.dangerous()
|
||||
.set_certificate_verifier(Arc::new(NoVerify));
|
||||
|
||||
Arc::new(config)
|
||||
}
|
||||
|
||||
/// Fetch real TLS metadata for the given SNI: negotiated cipher and cert lengths.
|
||||
pub async fn fetch_real_tls(
|
||||
host: &str,
|
||||
port: u16,
|
||||
sni: &str,
|
||||
connect_timeout: Duration,
|
||||
) -> Result<TlsFetchResult> {
|
||||
let addr = format!("{host}:{port}");
|
||||
let stream = timeout(connect_timeout, TcpStream::connect(addr)).await??;
|
||||
|
||||
let config = build_client_config();
|
||||
let connector = TlsConnector::from(config);
|
||||
|
||||
let server_name = ServerName::try_from(sni.to_owned())
|
||||
.or_else(|_| ServerName::try_from(host.to_owned()))
|
||||
.map_err(|_| RustlsError::General("invalid SNI".into()))?;
|
||||
|
||||
let tls_stream: TlsStream<TcpStream> = connector.connect(server_name, stream).await?;
|
||||
|
||||
// Extract negotiated parameters and certificates
|
||||
let (_io, session) = tls_stream.get_ref();
|
||||
let cipher_suite = session
|
||||
.negotiated_cipher_suite()
|
||||
.map(|s| u16::from(s.suite()).to_be_bytes())
|
||||
.unwrap_or([0x13, 0x01]);
|
||||
|
||||
let certs: Vec<CertificateDer<'static>> = session
|
||||
.peer_certificates()
|
||||
.map(|slice| slice.to_vec())
|
||||
.unwrap_or_default();
|
||||
|
||||
let total_cert_len: usize = certs.iter().map(|c| c.len()).sum::<usize>().max(1024);
|
||||
|
||||
// Heuristic: split across two records if large to mimic real servers a bit.
|
||||
let app_data_records_sizes = if total_cert_len > 3000 {
|
||||
vec![total_cert_len / 2, total_cert_len - total_cert_len / 2]
|
||||
} else {
|
||||
vec![total_cert_len]
|
||||
};
|
||||
|
||||
let parsed = ParsedServerHello {
|
||||
version: [0x03, 0x03],
|
||||
random: [0u8; 32],
|
||||
session_id: Vec::new(),
|
||||
cipher_suite,
|
||||
compression: 0,
|
||||
extensions: Vec::new(),
|
||||
};
|
||||
|
||||
debug!(
|
||||
sni = %sni,
|
||||
len = total_cert_len,
|
||||
cipher = format!("0x{:04x}", u16::from_be_bytes(cipher_suite)),
|
||||
"Fetched TLS metadata"
|
||||
);
|
||||
|
||||
Ok(TlsFetchResult {
|
||||
server_hello_parsed: parsed,
|
||||
app_data_records_sizes: app_data_records_sizes.clone(),
|
||||
total_app_data_len: app_data_records_sizes.iter().sum(),
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
pub mod types;
|
||||
pub mod cache;
|
||||
pub mod fetcher;
|
||||
pub mod emulator;
|
||||
|
||||
pub use cache::TlsFrontCache;
|
||||
pub use types::{CachedTlsData, TlsFetchResult};
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
use std::time::SystemTime;
|
||||
|
||||
/// Parsed representation of an unencrypted TLS ServerHello.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ParsedServerHello {
|
||||
pub version: [u8; 2],
|
||||
pub random: [u8; 32],
|
||||
pub session_id: Vec<u8>,
|
||||
pub cipher_suite: [u8; 2],
|
||||
pub compression: u8,
|
||||
pub extensions: Vec<TlsExtension>,
|
||||
}
|
||||
|
||||
/// Generic TLS extension container.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TlsExtension {
|
||||
pub ext_type: u16,
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Basic certificate metadata (optional, informative).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ParsedCertificateInfo {
|
||||
pub not_after_unix: Option<i64>,
|
||||
pub not_before_unix: Option<i64>,
|
||||
pub issuer_cn: Option<String>,
|
||||
pub subject_cn: Option<String>,
|
||||
pub san_names: Vec<String>,
|
||||
}
|
||||
|
||||
/// Cached data per SNI used by the emulator.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CachedTlsData {
|
||||
pub server_hello_template: ParsedServerHello,
|
||||
pub cert_info: Option<ParsedCertificateInfo>,
|
||||
pub app_data_records_sizes: Vec<usize>,
|
||||
pub total_app_data_len: usize,
|
||||
pub fetched_at: SystemTime,
|
||||
pub domain: String,
|
||||
}
|
||||
|
||||
/// Result of attempting to fetch real TLS artifacts.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TlsFetchResult {
|
||||
pub server_hello_parsed: ParsedServerHello,
|
||||
pub app_data_records_sizes: Vec<usize>,
|
||||
pub total_app_data_len: usize,
|
||||
}
|
||||
|
|
@ -122,6 +122,38 @@ pub fn get_local_addr(stream: &TcpStream) -> Option<SocketAddr> {
|
|||
stream.local_addr().ok()
|
||||
}
|
||||
|
||||
/// Resolve primary IP address of a network interface by name.
|
||||
/// Returns the first address matching the requested family (IPv4/IPv6).
|
||||
#[cfg(unix)]
|
||||
pub fn resolve_interface_ip(name: &str, want_ipv6: bool) -> Option<IpAddr> {
|
||||
use nix::ifaddrs::getifaddrs;
|
||||
|
||||
if let Ok(addrs) = getifaddrs() {
|
||||
for iface in addrs {
|
||||
if iface.interface_name == name {
|
||||
if let Some(address) = iface.address {
|
||||
if let Some(v4) = address.as_sockaddr_in() {
|
||||
if !want_ipv6 {
|
||||
return Some(IpAddr::V4(v4.ip()));
|
||||
}
|
||||
} else if let Some(v6) = address.as_sockaddr_in6() {
|
||||
if want_ipv6 {
|
||||
return Some(IpAddr::V6(v6.ip().clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Stub for non-Unix platforms: interface name resolution unsupported.
|
||||
#[cfg(not(unix))]
|
||||
pub fn resolve_interface_ip(_name: &str, _want_ipv6: bool) -> Option<IpAddr> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Get peer address of a socket
|
||||
pub fn get_peer_addr(stream: &TcpStream) -> Option<SocketAddr> {
|
||||
stream.peer_addr().ok()
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
use std::collections::HashMap;
|
||||
use std::net::{SocketAddr, IpAddr};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::RwLock;
|
||||
|
|
@ -15,7 +16,7 @@ use tracing::{debug, warn, info, trace};
|
|||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use crate::error::{Result, ProxyError};
|
||||
use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
|
||||
use crate::transport::socket::create_outgoing_socket_bound;
|
||||
use crate::transport::socket::{create_outgoing_socket_bound, resolve_interface_ip};
|
||||
use crate::transport::socks::{connect_socks4, connect_socks5};
|
||||
|
||||
/// Number of Telegram datacenters
|
||||
|
|
@ -84,6 +85,8 @@ struct UpstreamState {
|
|||
dc_latency: [LatencyEma; NUM_DCS],
|
||||
/// Per-DC IP version preference (learned from connectivity tests)
|
||||
dc_ip_pref: [IpPreference; NUM_DCS],
|
||||
/// Round-robin counter for bind_addresses selection
|
||||
bind_rr: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl UpstreamState {
|
||||
|
|
@ -95,6 +98,7 @@ impl UpstreamState {
|
|||
last_check: std::time::Instant::now(),
|
||||
dc_latency: [LatencyEma::new(0.3); NUM_DCS],
|
||||
dc_ip_pref: [IpPreference::Unknown; NUM_DCS],
|
||||
bind_rr: Arc::new(AtomicUsize::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -166,6 +170,46 @@ impl UpstreamManager {
|
|||
}
|
||||
}
|
||||
|
||||
fn resolve_bind_address(
|
||||
interface: &Option<String>,
|
||||
bind_addresses: &Option<Vec<String>>,
|
||||
target: SocketAddr,
|
||||
rr: Option<&AtomicUsize>,
|
||||
) -> Option<IpAddr> {
|
||||
let want_ipv6 = target.is_ipv6();
|
||||
|
||||
if let Some(addrs) = bind_addresses {
|
||||
let candidates: Vec<IpAddr> = addrs
|
||||
.iter()
|
||||
.filter_map(|s| s.parse::<IpAddr>().ok())
|
||||
.filter(|ip| ip.is_ipv6() == want_ipv6)
|
||||
.collect();
|
||||
|
||||
if !candidates.is_empty() {
|
||||
if let Some(counter) = rr {
|
||||
let idx = counter.fetch_add(1, Ordering::Relaxed) % candidates.len();
|
||||
return Some(candidates[idx]);
|
||||
}
|
||||
return candidates.first().copied();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(iface) = interface {
|
||||
if let Ok(ip) = iface.parse::<IpAddr>() {
|
||||
if ip.is_ipv6() == want_ipv6 {
|
||||
return Some(ip);
|
||||
}
|
||||
} else {
|
||||
#[cfg(unix)]
|
||||
if let Some(ip) = resolve_interface_ip(iface, want_ipv6) {
|
||||
return Some(ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Select upstream using latency-weighted random selection.
|
||||
async fn select_upstream(&self, dc_idx: Option<i16>, scope: Option<&str>) -> Option<usize> {
|
||||
let upstreams = self.upstreams.read().await;
|
||||
|
|
@ -262,7 +306,12 @@ impl UpstreamManager {
|
|||
|
||||
let start = Instant::now();
|
||||
|
||||
match self.connect_via_upstream(&upstream, target).await {
|
||||
let bind_rr = {
|
||||
let guard = self.upstreams.read().await;
|
||||
guard.get(idx).map(|u| u.bind_rr.clone())
|
||||
};
|
||||
|
||||
match self.connect_via_upstream(&upstream, target, bind_rr).await {
|
||||
Ok(stream) => {
|
||||
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||
let mut guard = self.upstreams.write().await;
|
||||
|
|
@ -294,13 +343,27 @@ impl UpstreamManager {
|
|||
}
|
||||
}
|
||||
|
||||
async fn connect_via_upstream(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<TcpStream> {
|
||||
async fn connect_via_upstream(
|
||||
&self,
|
||||
config: &UpstreamConfig,
|
||||
target: SocketAddr,
|
||||
bind_rr: Option<Arc<AtomicUsize>>,
|
||||
) -> Result<TcpStream> {
|
||||
match &config.upstream_type {
|
||||
UpstreamType::Direct { interface } => {
|
||||
let bind_ip = interface.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
UpstreamType::Direct { interface, bind_addresses } => {
|
||||
let bind_ip = Self::resolve_bind_address(
|
||||
interface,
|
||||
bind_addresses,
|
||||
target,
|
||||
bind_rr.as_deref(),
|
||||
);
|
||||
|
||||
let socket = create_outgoing_socket_bound(target, bind_ip)?;
|
||||
if let Some(ip) = bind_ip {
|
||||
debug!(bind = %ip, target = %target, "Bound outgoing socket");
|
||||
} else if interface.is_some() || bind_addresses.is_some() {
|
||||
debug!(target = %target, "No matching bind address for target family");
|
||||
}
|
||||
|
||||
socket.set_nonblocking(true)?;
|
||||
match socket.connect(&target.into()) {
|
||||
|
|
@ -323,8 +386,12 @@ impl UpstreamManager {
|
|||
let proxy_addr: SocketAddr = address.parse()
|
||||
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
|
||||
|
||||
let bind_ip = interface.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
let bind_ip = Self::resolve_bind_address(
|
||||
interface,
|
||||
&None,
|
||||
proxy_addr,
|
||||
bind_rr.as_deref(),
|
||||
);
|
||||
|
||||
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
||||
|
||||
|
|
@ -354,8 +421,12 @@ impl UpstreamManager {
|
|||
let proxy_addr: SocketAddr = address.parse()
|
||||
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
|
||||
|
||||
let bind_ip = interface.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok());
|
||||
let bind_ip = Self::resolve_bind_address(
|
||||
interface,
|
||||
&None,
|
||||
proxy_addr,
|
||||
bind_rr.as_deref(),
|
||||
);
|
||||
|
||||
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
|
||||
|
||||
|
|
@ -398,18 +469,18 @@ impl UpstreamManager {
|
|||
ipv4_enabled: bool,
|
||||
ipv6_enabled: bool,
|
||||
) -> Vec<StartupPingResult> {
|
||||
let upstreams: Vec<(usize, UpstreamConfig)> = {
|
||||
let upstreams: Vec<(usize, UpstreamConfig, Arc<AtomicUsize>)> = {
|
||||
let guard = self.upstreams.read().await;
|
||||
guard.iter().enumerate()
|
||||
.map(|(i, u)| (i, u.config.clone()))
|
||||
.map(|(i, u)| (i, u.config.clone(), u.bind_rr.clone()))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let mut all_results = Vec::new();
|
||||
|
||||
for (upstream_idx, upstream_config) in &upstreams {
|
||||
for (upstream_idx, upstream_config, bind_rr) in &upstreams {
|
||||
let upstream_name = match &upstream_config.upstream_type {
|
||||
UpstreamType::Direct { interface } => {
|
||||
UpstreamType::Direct { interface, .. } => {
|
||||
format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default())
|
||||
}
|
||||
UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
|
||||
|
|
@ -424,7 +495,7 @@ impl UpstreamManager {
|
|||
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(DC_PING_TIMEOUT_SECS),
|
||||
self.ping_single_dc(&upstream_config, addr_v6)
|
||||
self.ping_single_dc(&upstream_config, Some(bind_rr.clone()), addr_v6)
|
||||
).await;
|
||||
|
||||
let ping_result = match result {
|
||||
|
|
@ -475,7 +546,7 @@ impl UpstreamManager {
|
|||
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(DC_PING_TIMEOUT_SECS),
|
||||
self.ping_single_dc(&upstream_config, addr_v4)
|
||||
self.ping_single_dc(&upstream_config, Some(bind_rr.clone()), addr_v4)
|
||||
).await;
|
||||
|
||||
let ping_result = match result {
|
||||
|
|
@ -538,7 +609,7 @@ impl UpstreamManager {
|
|||
}
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(DC_PING_TIMEOUT_SECS),
|
||||
self.ping_single_dc(&upstream_config, addr)
|
||||
self.ping_single_dc(&upstream_config, Some(bind_rr.clone()), addr)
|
||||
).await;
|
||||
|
||||
let ping_result = match result {
|
||||
|
|
@ -607,9 +678,14 @@ impl UpstreamManager {
|
|||
all_results
|
||||
}
|
||||
|
||||
async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
|
||||
async fn ping_single_dc(
|
||||
&self,
|
||||
config: &UpstreamConfig,
|
||||
bind_rr: Option<Arc<AtomicUsize>>,
|
||||
target: SocketAddr,
|
||||
) -> Result<f64> {
|
||||
let start = Instant::now();
|
||||
let _stream = self.connect_via_upstream(config, target).await?;
|
||||
let _stream = self.connect_via_upstream(config, target, bind_rr).await?;
|
||||
Ok(start.elapsed().as_secs_f64() * 1000.0)
|
||||
}
|
||||
|
||||
|
|
@ -649,15 +725,16 @@ impl UpstreamManager {
|
|||
let count = self.upstreams.read().await.len();
|
||||
|
||||
for i in 0..count {
|
||||
let config = {
|
||||
let (config, bind_rr) = {
|
||||
let guard = self.upstreams.read().await;
|
||||
guard[i].config.clone()
|
||||
let u = &guard[i];
|
||||
(u.config.clone(), u.bind_rr.clone())
|
||||
};
|
||||
|
||||
let start = Instant::now();
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
self.connect_via_upstream(&config, dc_addr)
|
||||
self.connect_via_upstream(&config, dc_addr, Some(bind_rr.clone()))
|
||||
).await;
|
||||
|
||||
match result {
|
||||
|
|
@ -686,7 +763,7 @@ impl UpstreamManager {
|
|||
let start2 = Instant::now();
|
||||
let result2 = tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
self.connect_via_upstream(&config, fallback_addr)
|
||||
self.connect_via_upstream(&config, fallback_addr, Some(bind_rr.clone()))
|
||||
).await;
|
||||
|
||||
let mut guard = self.upstreams.write().await;
|
||||
|
|
|
|||
Loading…
Reference in New Issue