From 8b92b80b4af7b1f19317b0e12cbd2dbc1aa418f7 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 24 Mar 2026 10:33:06 +0300 Subject: [PATCH] Rustks CryptoProvider fixes + Rustfmt --- src/main.rs | 1 + src/proxy/handshake.rs | 114 ++++++++++++++--------- src/tls_front/fetcher.rs | 45 ++++----- src/transport/middle_proxy/http_fetch.rs | 5 +- 4 files changed, 93 insertions(+), 72 deletions(-) diff --git a/src/main.rs b/src/main.rs index 406b321..e5d931f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,5 +29,6 @@ mod util; #[tokio::main] async fn main() -> std::result::Result<(), Box> { + let _ = rustls::crypto::ring::default_provider().install_default(); maestro::run().await } diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 9d48fe9..2ef8e1b 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -282,30 +282,9 @@ fn auth_probe_record_failure_with_state( let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; let state_len = state.len(); let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); - let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); - let mut scanned = 0usize; - for entry in state.iter().skip(start_offset) { - let key = *entry.key(); - let fail_streak = entry.value().fail_streak; - let last_seen = entry.value().last_seen; - match eviction_candidate { - Some((_, current_fail, current_seen)) - if fail_streak > current_fail - || (fail_streak == current_fail && last_seen >= current_seen) => {} - _ => eviction_candidate = Some((key, fail_streak, last_seen)), - } - if auth_probe_state_expired(entry.value(), now) { - stale_keys.push(key); - } - scanned += 1; - if scanned >= scan_limit { - break; - } - } - - if scanned < scan_limit { - for entry in state.iter().take(scan_limit - scanned) { + if state_len <= AUTH_PROBE_PRUNE_SCAN_LIMIT { + for entry in state.iter() { let key = *entry.key(); let fail_streak = entry.value().fail_streak; let last_seen = entry.value().last_seen; @@ -319,6 +298,46 @@ fn auth_probe_record_failure_with_state( stale_keys.push(key); } } + } else { + let start_offset = + auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); + let mut scanned = 0usize; + for entry in state.iter().skip(start_offset) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + scanned += 1; + if scanned >= scan_limit { + break; + } + } + + if scanned < scan_limit { + for entry in state.iter().take(scan_limit - scanned) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail + && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + } + } } for stale_key in stale_keys { @@ -608,11 +627,35 @@ where } let client_sni = tls::extract_sni_from_client_hello(handshake); + let preferred_user_hint = client_sni + .as_deref() + .filter(|sni| config.access.users.contains_key(*sni)); let matched_tls_domain = client_sni .as_deref() .and_then(|sni| find_matching_tls_domain(config, sni)); - if client_sni.is_some() && matched_tls_domain.is_none() { + let alpn_list = if config.censorship.alpn_enforce { + tls::extract_alpn_from_client_hello(handshake) + } else { + Vec::new() + }; + let selected_alpn = if config.censorship.alpn_enforce { + if alpn_list.iter().any(|p| p == b"h2") { + Some(b"h2".to_vec()) + } else if alpn_list.iter().any(|p| p == b"http/1.1") { + Some(b"http/1.1".to_vec()) + } else if !alpn_list.is_empty() { + maybe_apply_server_hello_delay(config).await; + debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback"); + return HandshakeResult::BadClient { reader, writer }; + } else { + None + } + } else { + None + }; + + if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() { auth_probe_record_failure(peer.ip(), Instant::now()); maybe_apply_server_hello_delay(config).await; debug!( @@ -627,7 +670,7 @@ where }; } - let secrets = decode_user_secrets(config, client_sni.as_deref()); + let secrets = decode_user_secrets(config, preferred_user_hint); let validation = match tls::validate_tls_handshake_with_replay_window( handshake, @@ -684,27 +727,6 @@ where None }; - let alpn_list = if config.censorship.alpn_enforce { - tls::extract_alpn_from_client_hello(handshake) - } else { - Vec::new() - }; - let selected_alpn = if config.censorship.alpn_enforce { - if alpn_list.iter().any(|p| p == b"h2") { - Some(b"h2".to_vec()) - } else if alpn_list.iter().any(|p| p == b"http/1.1") { - Some(b"http/1.1".to_vec()) - } else if !alpn_list.is_empty() { - maybe_apply_server_hello_delay(config).await; - debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback"); - return HandshakeResult::BadClient { reader, writer }; - } else { - None - } - } else { - None - }; - // Add replay digest only for policy-valid handshakes. replay_checker.add_tls_digest(digest_half); diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 503b79c..bbfc336 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -241,7 +241,10 @@ fn order_profiles( return ordered; } - if let Some(pos) = ordered.iter().position(|profile| *profile == cached.profile) { + if let Some(pos) = ordered + .iter() + .position(|profile| *profile == cached.profile) + { if pos != 0 { ordered.swap(0, pos); } @@ -951,15 +954,9 @@ async fn fetch_via_raw_tls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream( - host, - port, - connect_timeout, - upstream, - scope, - strict_route, - ) - .await?; + let stream = + connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) + .await?; fetch_via_raw_tls_stream( stream, sni, @@ -1109,15 +1106,9 @@ async fn fetch_via_rustls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream( - host, - port, - connect_timeout, - upstream, - scope, - strict_route, - ) - .await?; + let stream = + connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) + .await?; fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await } @@ -1215,7 +1206,9 @@ pub async fn fetch_real_tls_with_strategy( if elapsed >= total_budget { return match raw_result { Some(raw) => Ok(raw), - None => Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))), + None => { + Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))) + } }; } @@ -1250,9 +1243,7 @@ pub async fn fetch_real_tls_with_strategy( warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only"); Ok(raw) } else if let Some(raw_err) = raw_last_error { - Err(anyhow!( - "TLS fetch failed (raw: {raw_err}; rustls: {err})" - )) + Err(anyhow!("TLS fetch failed (raw: {raw_err}; rustls: {err})")) } else { Err(err) } @@ -1386,7 +1377,10 @@ mod tests { #[test] fn test_order_profiles_drops_expired_cached_winner() { let strategy = TlsFetchStrategy { - profiles: vec![TlsFetchProfile::ModernFirefoxLike, TlsFetchProfile::CompatTls12], + profiles: vec![ + TlsFetchProfile::ModernFirefoxLike, + TlsFetchProfile::CompatTls12, + ], strict_route: true, attempt_timeout: Duration::from_secs(1), total_budget: Duration::from_secs(2), @@ -1394,7 +1388,8 @@ mod tests { deterministic: false, profile_cache_ttl: Duration::from_secs(5), }; - let cache_key = profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None); + let cache_key = + profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None); profile_cache().remove(&cache_key); profile_cache().insert( cache_key.clone(), diff --git a/src/transport/middle_proxy/http_fetch.rs b/src/transport/middle_proxy/http_fetch.rs index 2f21934..5be601e 100644 --- a/src/transport/middle_proxy/http_fetch.rs +++ b/src/transport/middle_proxy/http_fetch.rs @@ -27,7 +27,10 @@ pub(crate) struct HttpsGetResponse { fn build_tls_client_config() -> Arc { let mut root_store = rustls::RootCertStore::empty(); root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let config = rustls::ClientConfig::builder() + let provider = rustls::crypto::ring::default_provider(); + let config = rustls::ClientConfig::builder_with_provider(Arc::new(provider)) + .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) + .expect("HTTPS fetch rustls protocol versions must be valid") .with_root_certificates(root_store) .with_no_client_auth(); Arc::new(config)