diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 251c911..1016c6b 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -38,6 +38,22 @@ use super::MePool; const ME_KDF_DRIFT_STRICT: bool = false; +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +enum KdfClientPortSource { + LocalSocket = 0, + SocksBound = 1, +} + +impl KdfClientPortSource { + fn from_socks_bound_port(socks_bound_port: Option) -> Self { + if socks_bound_port.is_some() { + Self::SocksBound + } else { + Self::LocalSocket + } + } +} + /// Result of a successful ME handshake with timings. pub(crate) struct HandshakeOutput { pub rd: ReadHalf, @@ -52,18 +68,18 @@ pub(crate) struct HandshakeOutput { impl MePool { fn kdf_material_fingerprint( - local_addr_nat: SocketAddr, + local_ip_nat: IpAddr, peer_addr_nat: SocketAddr, - client_port_for_kdf: u16, - reflected: Option, - socks_bound_addr: Option, + reflected_ip: Option, + socks_bound_ip: Option, + client_port_source: KdfClientPortSource, ) -> u64 { let mut hasher = DefaultHasher::new(); - local_addr_nat.hash(&mut hasher); + local_ip_nat.hash(&mut hasher); peer_addr_nat.hash(&mut hasher); - client_port_for_kdf.hash(&mut hasher); - reflected.hash(&mut hasher); - socks_bound_addr.hash(&mut hasher); + reflected_ip.hash(&mut hasher); + socks_bound_ip.hash(&mut hasher); + client_port_source.hash(&mut hasher); hasher.finish() } @@ -359,35 +375,48 @@ impl MePool { let ts_bytes = crypto_ts.to_le_bytes(); let server_port_bytes = peer_addr_nat.port().to_le_bytes(); - let client_port_for_kdf = socks_bound_addr + let socks_bound_port = socks_bound_addr .map(|bound| bound.port()) - .filter(|port| *port != 0) - .unwrap_or(local_addr_nat.port()); + .filter(|port| *port != 0); + let client_port_for_kdf = socks_bound_port.unwrap_or(local_addr_nat.port()); + let client_port_source = KdfClientPortSource::from_socks_bound_port(socks_bound_port); let kdf_fingerprint = Self::kdf_material_fingerprint( - local_addr_nat, + local_addr_nat.ip(), peer_addr_nat, - client_port_for_kdf, - reflected, - socks_bound_addr, + reflected.map(|value| value.ip()), + socks_bound_addr.map(|value| value.ip()), + client_port_source, ); let mut kdf_fingerprint_guard = self.kdf_material_fingerprint.lock().await; - if let Some(prev_fingerprint) = kdf_fingerprint_guard.get(&peer_addr_nat).copied() - && prev_fingerprint != kdf_fingerprint + if let Some((prev_fingerprint, prev_client_port)) = + kdf_fingerprint_guard.get(&peer_addr_nat).copied() { - self.stats.increment_me_kdf_drift_total(); - warn!( - %peer_addr_nat, - %local_addr_nat, - client_port_for_kdf, - "ME KDF input drift detected for endpoint" - ); - if ME_KDF_DRIFT_STRICT { - return Err(ProxyError::InvalidHandshake( - "ME KDF input drift detected (strict mode)".to_string(), - )); + if prev_fingerprint != kdf_fingerprint { + self.stats.increment_me_kdf_drift_total(); + warn!( + %peer_addr_nat, + %local_addr_nat, + client_port_for_kdf, + client_port_source = ?client_port_source, + "ME KDF material drift detected for endpoint" + ); + if ME_KDF_DRIFT_STRICT { + return Err(ProxyError::InvalidHandshake( + "ME KDF material drift detected (strict mode)".to_string(), + )); + } + } else if prev_client_port != client_port_for_kdf { + self.stats.increment_me_kdf_port_only_drift_total(); + debug!( + %peer_addr_nat, + previous_client_port_for_kdf = prev_client_port, + client_port_for_kdf, + client_port_source = ?client_port_source, + "ME KDF client port changed with stable material" + ); } } - kdf_fingerprint_guard.insert(peer_addr_nat, kdf_fingerprint); + kdf_fingerprint_guard.insert(peer_addr_nat, (kdf_fingerprint, client_port_for_kdf)); drop(kdf_fingerprint_guard); let client_port_bytes = client_port_for_kdf.to_le_bytes(); diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index c07ec4f..192bf1b 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -395,6 +395,19 @@ async fn maybe_rotate_single_endpoint_shadow( } let endpoint = endpoints[0]; + if pool.is_endpoint_quarantined(endpoint).await { + pool.stats + .increment_me_single_endpoint_shadow_rotate_skipped_quarantine_total(); + shadow_rotate_deadline.insert(key, now + Duration::from_secs(SHADOW_ROTATE_RETRY_SECS)); + debug!( + dc = %dc, + ?family, + %endpoint, + "Single-endpoint shadow rotation skipped: endpoint is quarantined" + ); + return; + } + let Some(writer_ids) = live_writer_ids_by_addr.get(&endpoint) else { shadow_rotate_deadline.insert(key, now + Duration::from_secs(SHADOW_ROTATE_RETRY_SECS)); return; diff --git a/src/transport/middle_proxy/pool_refill.rs b/src/transport/middle_proxy/pool_refill.rs index 92071bd..6e14617 100644 --- a/src/transport/middle_proxy/pool_refill.rs +++ b/src/transport/middle_proxy/pool_refill.rs @@ -37,7 +37,7 @@ impl MePool { ); } - async fn is_endpoint_quarantined(&self, addr: SocketAddr) -> bool { + pub(super) async fn is_endpoint_quarantined(&self, addr: SocketAddr) -> bool { let mut guard = self.endpoint_quarantine.lock().await; let now = Instant::now(); guard.retain(|_, expiry| *expiry > now);