Rustks CryptoProvider fixes + Rustfmt

This commit is contained in:
Alexey 2026-03-24 10:33:06 +03:00
parent f7868aa00f
commit 8b92b80b4a
No known key found for this signature in database
4 changed files with 93 additions and 72 deletions

View File

@ -29,5 +29,6 @@ mod util;
#[tokio::main] #[tokio::main]
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let _ = rustls::crypto::ring::default_provider().install_default();
maestro::run().await maestro::run().await
} }

View File

@ -282,8 +282,25 @@ fn auth_probe_record_failure_with_state(
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
let state_len = state.len(); let state_len = state.len();
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
if state_len <= AUTH_PROBE_PRUNE_SCAN_LIMIT {
for entry in state.iter() {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
}
} else {
let start_offset =
auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
let mut scanned = 0usize; let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) { for entry in state.iter().skip(start_offset) {
let key = *entry.key(); let key = *entry.key();
@ -312,7 +329,8 @@ fn auth_probe_record_failure_with_state(
match eviction_candidate { match eviction_candidate {
Some((_, current_fail, current_seen)) Some((_, current_fail, current_seen))
if fail_streak > current_fail if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {} || (fail_streak == current_fail
&& last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)), _ => eviction_candidate = Some((key, fail_streak, last_seen)),
} }
if auth_probe_state_expired(entry.value(), now) { if auth_probe_state_expired(entry.value(), now) {
@ -320,6 +338,7 @@ fn auth_probe_record_failure_with_state(
} }
} }
} }
}
for stale_key in stale_keys { for stale_key in stale_keys {
state.remove(&stale_key); state.remove(&stale_key);
@ -608,11 +627,35 @@ where
} }
let client_sni = tls::extract_sni_from_client_hello(handshake); let client_sni = tls::extract_sni_from_client_hello(handshake);
let preferred_user_hint = client_sni
.as_deref()
.filter(|sni| config.access.users.contains_key(*sni));
let matched_tls_domain = client_sni let matched_tls_domain = client_sni
.as_deref() .as_deref()
.and_then(|sni| find_matching_tls_domain(config, sni)); .and_then(|sni| find_matching_tls_domain(config, sni));
if client_sni.is_some() && matched_tls_domain.is_none() { let alpn_list = if config.censorship.alpn_enforce {
tls::extract_alpn_from_client_hello(handshake)
} else {
Vec::new()
};
let selected_alpn = if config.censorship.alpn_enforce {
if alpn_list.iter().any(|p| p == b"h2") {
Some(b"h2".to_vec())
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
Some(b"http/1.1".to_vec())
} else if !alpn_list.is_empty() {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
return HandshakeResult::BadClient { reader, writer };
} else {
None
}
} else {
None
};
if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() {
auth_probe_record_failure(peer.ip(), Instant::now()); auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await; maybe_apply_server_hello_delay(config).await;
debug!( debug!(
@ -627,7 +670,7 @@ where
}; };
} }
let secrets = decode_user_secrets(config, client_sni.as_deref()); let secrets = decode_user_secrets(config, preferred_user_hint);
let validation = match tls::validate_tls_handshake_with_replay_window( let validation = match tls::validate_tls_handshake_with_replay_window(
handshake, handshake,
@ -684,27 +727,6 @@ where
None None
}; };
let alpn_list = if config.censorship.alpn_enforce {
tls::extract_alpn_from_client_hello(handshake)
} else {
Vec::new()
};
let selected_alpn = if config.censorship.alpn_enforce {
if alpn_list.iter().any(|p| p == b"h2") {
Some(b"h2".to_vec())
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
Some(b"http/1.1".to_vec())
} else if !alpn_list.is_empty() {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
return HandshakeResult::BadClient { reader, writer };
} else {
None
}
} else {
None
};
// Add replay digest only for policy-valid handshakes. // Add replay digest only for policy-valid handshakes.
replay_checker.add_tls_digest(digest_half); replay_checker.add_tls_digest(digest_half);

View File

@ -241,7 +241,10 @@ fn order_profiles(
return ordered; return ordered;
} }
if let Some(pos) = ordered.iter().position(|profile| *profile == cached.profile) { if let Some(pos) = ordered
.iter()
.position(|profile| *profile == cached.profile)
{
if pos != 0 { if pos != 0 {
ordered.swap(0, pos); ordered.swap(0, pos);
} }
@ -951,14 +954,8 @@ async fn fetch_via_raw_tls(
#[cfg(not(unix))] #[cfg(not(unix))]
let _ = unix_sock; let _ = unix_sock;
let stream = connect_tcp_with_upstream( let stream =
host, connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route)
port,
connect_timeout,
upstream,
scope,
strict_route,
)
.await?; .await?;
fetch_via_raw_tls_stream( fetch_via_raw_tls_stream(
stream, stream,
@ -1109,14 +1106,8 @@ async fn fetch_via_rustls(
#[cfg(not(unix))] #[cfg(not(unix))]
let _ = unix_sock; let _ = unix_sock;
let stream = connect_tcp_with_upstream( let stream =
host, connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route)
port,
connect_timeout,
upstream,
scope,
strict_route,
)
.await?; .await?;
fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await
} }
@ -1215,7 +1206,9 @@ pub async fn fetch_real_tls_with_strategy(
if elapsed >= total_budget { if elapsed >= total_budget {
return match raw_result { return match raw_result {
Some(raw) => Ok(raw), Some(raw) => Ok(raw),
None => Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))), None => {
Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted")))
}
}; };
} }
@ -1250,9 +1243,7 @@ pub async fn fetch_real_tls_with_strategy(
warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only"); warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only");
Ok(raw) Ok(raw)
} else if let Some(raw_err) = raw_last_error { } else if let Some(raw_err) = raw_last_error {
Err(anyhow!( Err(anyhow!("TLS fetch failed (raw: {raw_err}; rustls: {err})"))
"TLS fetch failed (raw: {raw_err}; rustls: {err})"
))
} else { } else {
Err(err) Err(err)
} }
@ -1386,7 +1377,10 @@ mod tests {
#[test] #[test]
fn test_order_profiles_drops_expired_cached_winner() { fn test_order_profiles_drops_expired_cached_winner() {
let strategy = TlsFetchStrategy { let strategy = TlsFetchStrategy {
profiles: vec![TlsFetchProfile::ModernFirefoxLike, TlsFetchProfile::CompatTls12], profiles: vec![
TlsFetchProfile::ModernFirefoxLike,
TlsFetchProfile::CompatTls12,
],
strict_route: true, strict_route: true,
attempt_timeout: Duration::from_secs(1), attempt_timeout: Duration::from_secs(1),
total_budget: Duration::from_secs(2), total_budget: Duration::from_secs(2),
@ -1394,7 +1388,8 @@ mod tests {
deterministic: false, deterministic: false,
profile_cache_ttl: Duration::from_secs(5), profile_cache_ttl: Duration::from_secs(5),
}; };
let cache_key = profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None); let cache_key =
profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None);
profile_cache().remove(&cache_key); profile_cache().remove(&cache_key);
profile_cache().insert( profile_cache().insert(
cache_key.clone(), cache_key.clone(),

View File

@ -27,7 +27,10 @@ pub(crate) struct HttpsGetResponse {
fn build_tls_client_config() -> Arc<rustls::ClientConfig> { fn build_tls_client_config() -> Arc<rustls::ClientConfig> {
let mut root_store = rustls::RootCertStore::empty(); let mut root_store = rustls::RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = rustls::ClientConfig::builder() let provider = rustls::crypto::ring::default_provider();
let config = rustls::ClientConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.expect("HTTPS fetch rustls protocol versions must be valid")
.with_root_certificates(root_store) .with_root_certificates(root_store)
.with_no_client_auth(); .with_no_client_auth();
Arc::new(config) Arc::new(config)