Rustks CryptoProvider fixes + Rustfmt

This commit is contained in:
Alexey 2026-03-24 10:33:06 +03:00
parent f7868aa00f
commit 8b92b80b4a
No known key found for this signature in database
4 changed files with 93 additions and 72 deletions

View File

@ -29,5 +29,6 @@ mod util;
#[tokio::main]
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let _ = rustls::crypto::ring::default_provider().install_default();
maestro::run().await
}

View File

@ -282,8 +282,25 @@ fn auth_probe_record_failure_with_state(
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
let state_len = state.len();
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
if state_len <= AUTH_PROBE_PRUNE_SCAN_LIMIT {
for entry in state.iter() {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
}
} else {
let start_offset =
auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
let key = *entry.key();
@ -312,7 +329,8 @@ fn auth_probe_record_failure_with_state(
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
|| (fail_streak == current_fail
&& last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
@ -320,6 +338,7 @@ fn auth_probe_record_failure_with_state(
}
}
}
}
for stale_key in stale_keys {
state.remove(&stale_key);
@ -608,11 +627,35 @@ where
}
let client_sni = tls::extract_sni_from_client_hello(handshake);
let preferred_user_hint = client_sni
.as_deref()
.filter(|sni| config.access.users.contains_key(*sni));
let matched_tls_domain = client_sni
.as_deref()
.and_then(|sni| find_matching_tls_domain(config, sni));
if client_sni.is_some() && matched_tls_domain.is_none() {
let alpn_list = if config.censorship.alpn_enforce {
tls::extract_alpn_from_client_hello(handshake)
} else {
Vec::new()
};
let selected_alpn = if config.censorship.alpn_enforce {
if alpn_list.iter().any(|p| p == b"h2") {
Some(b"h2".to_vec())
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
Some(b"http/1.1".to_vec())
} else if !alpn_list.is_empty() {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
return HandshakeResult::BadClient { reader, writer };
} else {
None
}
} else {
None
};
if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
@ -627,7 +670,7 @@ where
};
}
let secrets = decode_user_secrets(config, client_sni.as_deref());
let secrets = decode_user_secrets(config, preferred_user_hint);
let validation = match tls::validate_tls_handshake_with_replay_window(
handshake,
@ -684,27 +727,6 @@ where
None
};
let alpn_list = if config.censorship.alpn_enforce {
tls::extract_alpn_from_client_hello(handshake)
} else {
Vec::new()
};
let selected_alpn = if config.censorship.alpn_enforce {
if alpn_list.iter().any(|p| p == b"h2") {
Some(b"h2".to_vec())
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
Some(b"http/1.1".to_vec())
} else if !alpn_list.is_empty() {
maybe_apply_server_hello_delay(config).await;
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
return HandshakeResult::BadClient { reader, writer };
} else {
None
}
} else {
None
};
// Add replay digest only for policy-valid handshakes.
replay_checker.add_tls_digest(digest_half);

View File

@ -241,7 +241,10 @@ fn order_profiles(
return ordered;
}
if let Some(pos) = ordered.iter().position(|profile| *profile == cached.profile) {
if let Some(pos) = ordered
.iter()
.position(|profile| *profile == cached.profile)
{
if pos != 0 {
ordered.swap(0, pos);
}
@ -951,14 +954,8 @@ async fn fetch_via_raw_tls(
#[cfg(not(unix))]
let _ = unix_sock;
let stream = connect_tcp_with_upstream(
host,
port,
connect_timeout,
upstream,
scope,
strict_route,
)
let stream =
connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route)
.await?;
fetch_via_raw_tls_stream(
stream,
@ -1109,14 +1106,8 @@ async fn fetch_via_rustls(
#[cfg(not(unix))]
let _ = unix_sock;
let stream = connect_tcp_with_upstream(
host,
port,
connect_timeout,
upstream,
scope,
strict_route,
)
let stream =
connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route)
.await?;
fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await
}
@ -1215,7 +1206,9 @@ pub async fn fetch_real_tls_with_strategy(
if elapsed >= total_budget {
return match raw_result {
Some(raw) => Ok(raw),
None => Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))),
None => {
Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted")))
}
};
}
@ -1250,9 +1243,7 @@ pub async fn fetch_real_tls_with_strategy(
warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only");
Ok(raw)
} else if let Some(raw_err) = raw_last_error {
Err(anyhow!(
"TLS fetch failed (raw: {raw_err}; rustls: {err})"
))
Err(anyhow!("TLS fetch failed (raw: {raw_err}; rustls: {err})"))
} else {
Err(err)
}
@ -1386,7 +1377,10 @@ mod tests {
#[test]
fn test_order_profiles_drops_expired_cached_winner() {
let strategy = TlsFetchStrategy {
profiles: vec![TlsFetchProfile::ModernFirefoxLike, TlsFetchProfile::CompatTls12],
profiles: vec![
TlsFetchProfile::ModernFirefoxLike,
TlsFetchProfile::CompatTls12,
],
strict_route: true,
attempt_timeout: Duration::from_secs(1),
total_budget: Duration::from_secs(2),
@ -1394,7 +1388,8 @@ mod tests {
deterministic: false,
profile_cache_ttl: Duration::from_secs(5),
};
let cache_key = profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None);
let cache_key =
profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None);
profile_cache().remove(&cache_key);
profile_cache().insert(
cache_key.clone(),

View File

@ -27,7 +27,10 @@ pub(crate) struct HttpsGetResponse {
fn build_tls_client_config() -> Arc<rustls::ClientConfig> {
let mut root_store = rustls::RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = rustls::ClientConfig::builder()
let provider = rustls::crypto::ring::default_provider();
let config = rustls::ClientConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.expect("HTTPS fetch rustls protocol versions must be valid")
.with_root_certificates(root_store)
.with_no_client_auth();
Arc::new(config)