diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index cdfd844..f719349 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -55,6 +55,7 @@ const STICKY_HINT_MAX_ENTRIES: usize = 65_536; const CANDIDATE_HINT_TRACK_CAP: usize = 64; const OVERLOAD_CANDIDATE_BUDGET_HINTED: usize = 16; const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8; +const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64; const RECENT_USER_RING_SCAN_LIMIT: usize = 32; type HmacSha256 = Hmac<Sha256>; @@ -551,6 +552,19 @@ fn auth_probe_note_saturation_in(shared: &ProxySharedState, now: Instant) { } } +fn auth_probe_note_expensive_invalid_scan_in( + shared: &ProxySharedState, + now: Instant, + validation_checks: usize, + overload: bool, +) { + if overload || validation_checks < EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD { + return; + } + + auth_probe_note_saturation_in(shared, now); +} + fn auth_probe_record_failure_in(shared: &ProxySharedState, peer_ip: IpAddr, now: Instant) { let peer_ip = normalize_auth_probe_ip(peer_ip); let state = &shared.handshake.auth_probe; @@ -1378,7 +1392,14 @@ where } if !matched { - auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); + let failure_now = Instant::now(); + auth_probe_note_expensive_invalid_scan_in( + shared, + failure_now, + validation_checks, + overload, + ); + auth_probe_record_failure_in(shared, peer.ip(), failure_now); maybe_apply_server_hello_delay(config).await; debug!( peer = %peer, @@ -1753,7 +1774,14 @@ where } if !matched { - auth_probe_record_failure_in(shared, peer.ip(), Instant::now()); + let failure_now = Instant::now(); + auth_probe_note_expensive_invalid_scan_in( + shared, + failure_now, + validation_checks, + overload, + ); + auth_probe_record_failure_in(shared, peer.ip(), failure_now); maybe_apply_server_hello_delay(config).await; debug!( peer = %peer, diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index f1f6584..b0ddb8f 100644 --- a/src/proxy/middle_relay.rs +++
b/src/proxy/middle_relay.rs @@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, Instant}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot, watch}; use tokio::time::timeout; use tracing::{debug, info, trace, warn}; @@ -36,7+36,11 @@ use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; enum C2MeCommand { - Data { payload: PooledBuffer, flags: u32 }, + Data { + payload: PooledBuffer, + flags: u32, + _permit: OwnedSemaphorePermit, + }, Close, } @@ -47,6 +51,8 @@ const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; +const C2ME_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024; +const C2ME_QUEUED_PERMITS_PER_SLOT: usize = 4; const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1); const TINY_FRAME_DEBT_PER_TINY: u32 = 8; const TINY_FRAME_DEBT_LIMIT: u32 = 512; @@ -571,6 +577,43 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET } +fn c2me_payload_permits(payload_len: usize) -> u32 { + payload_len + .max(1) + .div_ceil(C2ME_QUEUED_BYTE_PERMIT_UNIT) + .min(u32::MAX as usize) as u32 +} + +fn c2me_queued_permit_budget(channel_capacity: usize, frame_limit: usize) -> usize { + channel_capacity + .saturating_mul(C2ME_QUEUED_PERMITS_PER_SLOT) + .max(c2me_payload_permits(frame_limit) as usize) + .max(1) +} + +async fn acquire_c2me_payload_permit( + semaphore: &Arc<Semaphore>, + payload_len: usize, + send_timeout: Option<Duration>, + stats: &Stats, +) -> Result<OwnedSemaphorePermit, ProxyError> { + let permits = c2me_payload_permits(payload_len); + let acquire = semaphore.clone().acquire_many_owned(permits); + match
send_timeout { + Some(send_timeout) => match timeout(send_timeout, acquire).await { + Ok(Ok(permit)) => Ok(permit), + Ok(Err(_)) => Err(ProxyError::Proxy("ME sender byte budget closed".into())), + Err(_) => { + stats.increment_me_c2me_send_timeout_total(); + Err(ProxyError::Proxy("ME sender byte budget timeout".into())) + } + }, + None => acquire + .await + .map_err(|_| ProxyError::Proxy("ME sender byte budget closed".into())), + } +} + fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { limit.saturating_add(overshoot) } @@ -1122,13 +1165,19 @@ where 0 => None, timeout_ms => Some(Duration::from_millis(timeout_ms)), }; + let c2me_byte_budget = c2me_queued_permit_budget(c2me_channel_capacity, frame_limit); + let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget)); let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity); let me_pool_c2me = me_pool.clone(); let c2me_sender = tokio::spawn(async move { let mut sent_since_yield = 0usize; while let Some(cmd) = c2me_rx.recv().await { match cmd { - C2MeCommand::Data { payload, flags } => { + C2MeCommand::Data { + payload, + flags, + _permit, + } => { me_pool_c2me .send_proxy_req( conn_id, @@ -1624,11 +1673,29 @@ where if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) { flags |= RPC_FLAG_NOT_ENCRYPTED; } + let payload_permit = match acquire_c2me_payload_permit( + &c2me_byte_semaphore, + payload.len(), + c2me_send_timeout, + stats.as_ref(), + ) + .await + { + Ok(permit) => permit, + Err(e) => { + main_result = Err(e); + break; + } + }; // Keep client read loop lightweight: route heavy ME send path via a dedicated task.
if enqueue_c2me_command_in( shared.as_ref(), &c2me_tx, - C2MeCommand::Data { payload, flags }, + C2MeCommand::Data { + payload, + flags, + _permit: payload_permit, + }, c2me_send_timeout, stats.as_ref(), ) diff --git a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs index df91cac..dd7ad08 100644 --- a/src/proxy/tests/handshake_security_tests.rs +++ b/src/proxy/tests/handshake_security_tests.rs @@ -1252,6 +1252,97 @@ async fn tls_overload_budget_limits_candidate_scan_depth() { ); } +#[tokio::test] +async fn tls_expensive_invalid_scan_activates_saturation_budget() { + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.ignore_time_skew = true; + for idx in 0..80u8 { + config.access.users.insert( + format!("user-{idx}"), + format!("{:032x}", u128::from(idx) + 1), + ); + } + config.rebuild_runtime_user_auth().unwrap(); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let shared = ProxySharedState::new(); + let attacker_secret = [0xEFu8; 16]; + let handshake = make_valid_tls_handshake(&attacker_secret, 0); + + let first_peer: SocketAddr = "198.51.100.214:44326".parse().unwrap(); + let first = handle_tls_handshake_with_shared( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + first_peer, + &config, + &replay_checker, + &rng, + None, + shared.as_ref(), + ) + .await; + + assert!(matches!(first, HandshakeResult::BadClient { .. 
})); + assert!( + auth_probe_saturation_state_for_testing_in_shared(shared.as_ref()) + .lock() + .unwrap() + .is_some(), + "expensive invalid scan must activate global saturation" + ); + assert_eq!( + shared + .handshake + .auth_expensive_checks_total + .load(Ordering::Relaxed), + 80, + "first invalid probe preserves full first-hit compatibility before enabling saturation" + ); + + { + let mut saturation = auth_probe_saturation_state_for_testing_in_shared(shared.as_ref()) + .lock() + .unwrap(); + let state = saturation.as_mut().expect("saturation must be present"); + state.blocked_until = Instant::now() + Duration::from_millis(200); + } + + let second_peer: SocketAddr = "198.51.100.215:44326".parse().unwrap(); + let second = handle_tls_handshake_with_shared( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + second_peer, + &config, + &replay_checker, + &rng, + None, + shared.as_ref(), + ) + .await; + + assert!(matches!(second, HandshakeResult::BadClient { .. })); + assert_eq!( + shared + .handshake + .auth_budget_exhausted_total + .load(Ordering::Relaxed), + 1, + "second invalid probe must be capped by overload budget" + ); + assert_eq!( + shared + .handshake + .auth_expensive_checks_total + .load(Ordering::Relaxed), + 80 + OVERLOAD_CANDIDATE_BUDGET_UNHINTED as u64, + "saturation budget must bound follow-up invalid scans" + ); +} + #[tokio::test] async fn mtproto_runtime_snapshot_prefers_preferred_user_hint() { let mut config = ProxyConfig::default(); diff --git a/src/proxy/tests/middle_relay_stub_completion_security_tests.rs b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs index 54eb784..6d398c8 100644 --- a/src/proxy/tests/middle_relay_stub_completion_security_tests.rs +++ b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs @@ -12,6 +12,12 @@ fn make_pooled_payload(data: &[u8]) -> PooledBuffer { payload } +fn make_c2me_permit() -> tokio::sync::OwnedSemaphorePermit { + Arc::new(tokio::sync::Semaphore::new(1)) + 
.try_acquire_many_owned(1) + .expect("test permit must be available") +} + #[test] #[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"] fn should_emit_full_desync_filters_duplicates() { @@ -107,6 +113,7 @@ async fn c2me_channel_full_path_yields_then_sends() { tx.send(C2MeCommand::Data { payload: make_pooled_payload(&[0xAA]), flags: 1, + _permit: make_c2me_permit(), }) .await .expect("priming queue with one frame must succeed"); @@ -119,6 +126,7 @@ async fn c2me_channel_full_path_yields_then_sends() { C2MeCommand::Data { payload: make_pooled_payload(&[0xBB, 0xCC]), flags: 2, + _permit: make_c2me_permit(), }, None, &stats, @@ -138,7 +146,7 @@ async fn c2me_channel_full_path_yields_then_sends() { .expect("receiver should observe primed frame") .expect("first queued command must exist"); match first { - C2MeCommand::Data { payload, flags } => { + C2MeCommand::Data { payload, flags, .. } => { assert_eq!(payload.as_ref(), &[0xAA]); assert_eq!(flags, 1); } @@ -155,7 +163,7 @@ async fn c2me_channel_full_path_yields_then_sends() { .expect("receiver should observe backpressure-resumed frame") .expect("second queued command must exist"); match second { - C2MeCommand::Data { payload, flags } => { + C2MeCommand::Data { payload, flags, .. } => { assert_eq!(payload.as_ref(), &[0xBB, 0xCC]); assert_eq!(flags, 2); }