diff --git a/src/proxy/client.rs b/src/proxy/client.rs
index 2d4dd42..2ab02ce 100644
--- a/src/proxy/client.rs
+++ b/src/proxy/client.rs
@@ -31,16 +31,24 @@ struct UserConnectionReservation {
     ip_tracker: Arc<UserIpTracker>,
     user: String,
     ip: IpAddr,
+    tracks_ip: bool,
     active: bool,
 }
 
 impl UserConnectionReservation {
-    fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
+    fn new(
+        stats: Arc<Stats>,
+        ip_tracker: Arc<UserIpTracker>,
+        user: String,
+        ip: IpAddr,
+        tracks_ip: bool,
+    ) -> Self {
         Self {
             stats,
             ip_tracker,
             user,
             ip,
+            tracks_ip,
             active: true,
         }
     }
@@ -49,7 +57,9 @@ impl UserConnectionReservation {
         if !self.active {
             return;
         }
-        self.ip_tracker.remove_ip(&self.user, self.ip).await;
+        if self.tracks_ip {
+            self.ip_tracker.remove_ip(&self.user, self.ip).await;
+        }
         self.active = false;
         self.stats.decrement_user_curr_connects(&self.user);
     }
@@ -62,7 +72,9 @@ impl Drop for UserConnectionReservation {
         }
         self.active = false;
         self.stats.decrement_user_curr_connects(&self.user);
-        self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
+        if self.tracks_ip {
+            self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
+        }
     }
 }
 
@@ -1600,19 +1612,22 @@ impl RunningClientHandler {
             });
         }
 
-        match ip_tracker.check_and_add(user, peer_addr.ip()).await {
-            Ok(()) => {}
-            Err(reason) => {
-                stats.decrement_user_curr_connects(user);
-                warn!(
-                    user = %user,
-                    ip = %peer_addr.ip(),
-                    reason = %reason,
-                    "IP limit exceeded"
-                );
-                return Err(ProxyError::ConnectionLimitExceeded {
-                    user: user.to_string(),
-                });
+        let tracks_ip = ip_tracker.get_user_limit(user).await.is_some();
+        if tracks_ip {
+            match ip_tracker.check_and_add(user, peer_addr.ip()).await {
+                Ok(()) => {}
+                Err(reason) => {
+                    stats.decrement_user_curr_connects(user);
+                    warn!(
+                        user = %user,
+                        ip = %peer_addr.ip(),
+                        reason = %reason,
+                        "IP limit exceeded"
+                    );
+                    return Err(ProxyError::ConnectionLimitExceeded {
+                        user: user.to_string(),
+                    });
+                }
             }
         }
 
@@ -1621,6 +1636,7 @@ impl RunningClientHandler {
             ip_tracker,
             user.to_string(),
             peer_addr.ip(),
+            tracks_ip,
         ))
     }
 
@@ -1663,25 +1679,27 @@ impl RunningClientHandler {
             });
         }
 
-        match ip_tracker.check_and_add(user, peer_addr.ip()).await {
-            Ok(()) => {
-                ip_tracker.remove_ip(user, peer_addr.ip()).await;
-                stats.decrement_user_curr_connects(user);
-            }
-            Err(reason) => {
-                stats.decrement_user_curr_connects(user);
-                warn!(
-                    user = %user,
-                    ip = %peer_addr.ip(),
-                    reason = %reason,
-                    "IP limit exceeded"
-                );
-                return Err(ProxyError::ConnectionLimitExceeded {
-                    user: user.to_string(),
-                });
+        if ip_tracker.get_user_limit(user).await.is_some() {
+            match ip_tracker.check_and_add(user, peer_addr.ip()).await {
+                Ok(()) => {
+                    ip_tracker.remove_ip(user, peer_addr.ip()).await;
+                }
+                Err(reason) => {
+                    stats.decrement_user_curr_connects(user);
+                    warn!(
+                        user = %user,
+                        ip = %peer_addr.ip(),
+                        reason = %reason,
+                        "IP limit exceeded"
+                    );
+                    return Err(ProxyError::ConnectionLimitExceeded {
+                        user: user.to_string(),
+                    });
+                }
             }
         }
 
+        stats.decrement_user_curr_connects(user);
         Ok(())
     }
 }
diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs
index cdfd844..f719349 100644
--- a/src/proxy/handshake.rs
+++ b/src/proxy/handshake.rs
@@ -55,6 +55,7 @@ const STICKY_HINT_MAX_ENTRIES: usize = 65_536;
 const CANDIDATE_HINT_TRACK_CAP: usize = 64;
 const OVERLOAD_CANDIDATE_BUDGET_HINTED: usize = 16;
 const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8;
+const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64;
 const RECENT_USER_RING_SCAN_LIMIT: usize = 32;
 
 type HmacSha256 = Hmac<Sha256>;
@@ -551,6 +552,19 @@ fn auth_probe_note_saturation_in(shared: &ProxySharedState, now: Instant) {
     }
 }
 
+fn auth_probe_note_expensive_invalid_scan_in(
+    shared: &ProxySharedState,
+    now: Instant,
+    validation_checks: usize,
+    overload: bool,
+) {
+    if overload || validation_checks < EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD {
+        return;
+    }
+
+    auth_probe_note_saturation_in(shared, now);
+}
+
 fn auth_probe_record_failure_in(shared: &ProxySharedState, peer_ip: IpAddr, now: Instant) {
     let peer_ip = normalize_auth_probe_ip(peer_ip);
     let state = &shared.handshake.auth_probe;
@@ -1378,7 +1392,14 @@ where
         }
 
         if !matched {
-            auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
+            let failure_now = Instant::now();
+            auth_probe_note_expensive_invalid_scan_in(
+                shared,
+                failure_now,
+                validation_checks,
+                overload,
+            );
+            auth_probe_record_failure_in(shared, peer.ip(), failure_now);
             maybe_apply_server_hello_delay(config).await;
             debug!(
                 peer = %peer,
@@ -1753,7 +1774,14 @@ where
         }
 
         if !matched {
-            auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
+            let failure_now = Instant::now();
+            auth_probe_note_expensive_invalid_scan_in(
+                shared,
+                failure_now,
+                validation_checks,
+                overload,
+            );
+            auth_probe_record_failure_in(shared, peer.ip(), failure_now);
             maybe_apply_server_hello_delay(config).await;
             debug!(
                 peer = %peer,
diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs
index f1f6584..b0ddb8f 100644
--- a/src/proxy/middle_relay.rs
+++ b/src/proxy/middle_relay.rs
@@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
 use std::time::{Duration, Instant};
 
 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
-use tokio::sync::{mpsc, oneshot, watch};
+use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot, watch};
 use tokio::time::timeout;
 use tracing::{debug, info, trace, warn};
 
@@ -36,7 +36,11 @@ use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
 use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
 
 enum C2MeCommand {
-    Data { payload: PooledBuffer, flags: u32 },
+    Data {
+        payload: PooledBuffer,
+        flags: u32,
+        _permit: OwnedSemaphorePermit,
+    },
     Close,
 }
 
@@ -47,6 +51,8 @@ const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
 const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
 const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
 const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
+const C2ME_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024;
+const C2ME_QUEUED_PERMITS_PER_SLOT: usize = 4;
 const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1);
 const TINY_FRAME_DEBT_PER_TINY: u32 = 8;
 const TINY_FRAME_DEBT_LIMIT: u32 = 512;
@@ -571,6 +577,43 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
     has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
 }
 
+fn c2me_payload_permits(payload_len: usize) -> u32 {
+    payload_len
+        .max(1)
+        .div_ceil(C2ME_QUEUED_BYTE_PERMIT_UNIT)
+        .min(u32::MAX as usize) as u32
+}
+
+fn c2me_queued_permit_budget(channel_capacity: usize, frame_limit: usize) -> usize {
+    channel_capacity
+        .saturating_mul(C2ME_QUEUED_PERMITS_PER_SLOT)
+        .max(c2me_payload_permits(frame_limit) as usize)
+        .max(1)
+}
+
+async fn acquire_c2me_payload_permit(
+    semaphore: &Arc<Semaphore>,
+    payload_len: usize,
+    send_timeout: Option<Duration>,
+    stats: &Stats,
+) -> Result<OwnedSemaphorePermit, ProxyError> {
+    let permits = c2me_payload_permits(payload_len);
+    let acquire = semaphore.clone().acquire_many_owned(permits);
+    match send_timeout {
+        Some(send_timeout) => match timeout(send_timeout, acquire).await {
+            Ok(Ok(permit)) => Ok(permit),
+            Ok(Err(_)) => Err(ProxyError::Proxy("ME sender byte budget closed".into())),
+            Err(_) => {
+                stats.increment_me_c2me_send_timeout_total();
+                Err(ProxyError::Proxy("ME sender byte budget timeout".into()))
+            }
+        },
+        None => acquire
+            .await
+            .map_err(|_| ProxyError::Proxy("ME sender byte budget closed".into())),
+    }
+}
+
 fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
     limit.saturating_add(overshoot)
 }
@@ -1122,13 +1165,19 @@ where
         0 => None,
         timeout_ms => Some(Duration::from_millis(timeout_ms)),
     };
+    let c2me_byte_budget = c2me_queued_permit_budget(c2me_channel_capacity, frame_limit);
+    let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget));
     let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity);
     let me_pool_c2me = me_pool.clone();
    let c2me_sender = tokio::spawn(async move {
         let mut sent_since_yield = 0usize;
         while let Some(cmd) = c2me_rx.recv().await {
             match cmd {
-                C2MeCommand::Data { payload, flags } => {
+                C2MeCommand::Data {
+                    payload,
+                    flags,
+                    _permit,
+                } => {
                     me_pool_c2me
                         .send_proxy_req(
                             conn_id,
@@ -1624,11 +1673,29 @@ where
             if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) {
                 flags |= RPC_FLAG_NOT_ENCRYPTED;
             }
+            let payload_permit = match acquire_c2me_payload_permit(
+                &c2me_byte_semaphore,
+                payload.len(),
+                c2me_send_timeout,
+                stats.as_ref(),
+            )
+            .await
+            {
+                Ok(permit) => permit,
+                Err(e) => {
+                    main_result = Err(e);
+                    break;
+                }
+            };
             // Keep client read loop lightweight: route heavy ME send path via a dedicated task.
             if enqueue_c2me_command_in(
                 shared.as_ref(),
                 &c2me_tx,
-                C2MeCommand::Data { payload, flags },
+                C2MeCommand::Data {
+                    payload,
+                    flags,
+                    _permit: payload_permit,
+                },
                 c2me_send_timeout,
                 stats.as_ref(),
             )
diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs
index 480b33d..4505e17 100644
--- a/src/proxy/tests/client_security_tests.rs
+++ b/src/proxy/tests/client_security_tests.rs
@@ -281,8 +281,13 @@ async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
     assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1);
     assert_eq!(stats.get_user_curr_connects(&user), 1);
 
-    let reservation =
-        UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip);
+    let reservation = UserConnectionReservation::new(
+        stats.clone(),
+        ip_tracker.clone(),
+        user.clone(),
+        ip,
+        true,
+    );
 
     // Drop the reservation synchronously without any tokio::spawn/await yielding!
     drop(reservation);
@@ -320,6 +325,7 @@ async fn relay_task_abort_releases_user_gate_and_ip_reservation() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 8).await;
 
     let mut cfg = ProxyConfig::default();
     cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
@@ -437,6 +443,7 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 8).await;
 
     let mut cfg = ProxyConfig::default();
     cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
@@ -2879,6 +2886,7 @@ async fn explicit_reservation_release_cleans_user_and_ip_immediately() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 4).await;
 
     let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
         user,
@@ -2917,6 +2925,7 @@ async fn explicit_reservation_release_does_not_double_decrement_on_drop() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 4).await;
 
     let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
         user,
@@ -2947,6 +2956,7 @@ async fn drop_fallback_eventually_cleans_user_and_ip_reservation() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 1).await;
 
     let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
         user,
@@ -3029,6 +3039,7 @@ async fn release_abort_storm_does_not_leak_user_or_ip_reservations() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, ATTEMPTS + 16).await;
 
     for idx in 0..ATTEMPTS {
         let peer = SocketAddr::new(
@@ -3079,6 +3090,7 @@ async fn release_abort_loop_preserves_immediate_same_ip_reacquire() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 1).await;
 
     for _ in 0..ITERATIONS {
         let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
@@ -3137,6 +3149,7 @@ async fn adversarial_mixed_release_drop_abort_wave_converges_to_zero() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, RESERVATIONS + 8).await;
 
     let mut reservations = Vec::with_capacity(RESERVATIONS);
     for idx in 0..RESERVATIONS {
@@ -3217,6 +3230,8 @@ async fn parallel_users_abort_release_isolation_preserves_independent_cleanup()
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user_a, 64).await;
+    ip_tracker.set_user_limit(user_b, 64).await;
 
     let mut tasks = tokio::task::JoinSet::new();
     for idx in 0..64usize {
@@ -3278,6 +3293,7 @@ async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, RESERVATIONS + 8).await;
 
     let mut reservations = Vec::with_capacity(RESERVATIONS);
     for idx in 0..RESERVATIONS {
@@ -3332,6 +3348,7 @@ async fn relay_connect_error_releases_user_and_ip_before_return() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 8).await;
 
     let mut config = ProxyConfig::default();
     config.access.user_max_tcp_conns.insert(user.to_string(), 1);
@@ -3427,6 +3444,7 @@ async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 1).await;
 
     let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static(
         user,
@@ -3487,6 +3505,7 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 1).await;
 
     let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static(
         user,
@@ -3696,6 +3715,7 @@ async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 8).await;
 
     let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
         user,
@@ -3740,6 +3760,7 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() {
 
     let stats = Arc::new(Stats::new());
     let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 1).await;
 
     let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
         user,
diff --git a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs
index df91cac..dd7ad08 100644
--- a/src/proxy/tests/handshake_security_tests.rs
+++ b/src/proxy/tests/handshake_security_tests.rs
@@ -1252,6 +1252,97 @@ async fn tls_overload_budget_limits_candidate_scan_depth() {
     );
 }
 
+#[tokio::test]
+async fn tls_expensive_invalid_scan_activates_saturation_budget() {
+    let mut config = ProxyConfig::default();
+    config.access.users.clear();
+    config.access.ignore_time_skew = true;
+    for idx in 0..80u8 {
+        config.access.users.insert(
+            format!("user-{idx}"),
+            format!("{:032x}", u128::from(idx) + 1),
+        );
+    }
+    config.rebuild_runtime_user_auth().unwrap();
+
+    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
+    let rng = SecureRandom::new();
+    let shared = ProxySharedState::new();
+    let attacker_secret = [0xEFu8; 16];
+    let handshake = make_valid_tls_handshake(&attacker_secret, 0);
+
+    let first_peer: SocketAddr = "198.51.100.214:44326".parse().unwrap();
+    let first = handle_tls_handshake_with_shared(
+        &handshake,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        first_peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+        shared.as_ref(),
+    )
+    .await;
+
+    assert!(matches!(first, HandshakeResult::BadClient { .. }));
+    assert!(
+        auth_probe_saturation_state_for_testing_in_shared(shared.as_ref())
+            .lock()
+            .unwrap()
+            .is_some(),
+        "expensive invalid scan must activate global saturation"
+    );
+    assert_eq!(
+        shared
+            .handshake
+            .auth_expensive_checks_total
+            .load(Ordering::Relaxed),
+        80,
+        "first invalid probe preserves full first-hit compatibility before enabling saturation"
+    );
+
+    {
+        let mut saturation = auth_probe_saturation_state_for_testing_in_shared(shared.as_ref())
+            .lock()
+            .unwrap();
+        let state = saturation.as_mut().expect("saturation must be present");
+        state.blocked_until = Instant::now() + Duration::from_millis(200);
+    }
+
+    let second_peer: SocketAddr = "198.51.100.215:44326".parse().unwrap();
+    let second = handle_tls_handshake_with_shared(
+        &handshake,
+        tokio::io::empty(),
+        tokio::io::sink(),
+        second_peer,
+        &config,
+        &replay_checker,
+        &rng,
+        None,
+        shared.as_ref(),
+    )
+    .await;
+
+    assert!(matches!(second, HandshakeResult::BadClient { .. }));
+    assert_eq!(
+        shared
+            .handshake
+            .auth_budget_exhausted_total
+            .load(Ordering::Relaxed),
+        1,
+        "second invalid probe must be capped by overload budget"
+    );
+    assert_eq!(
+        shared
+            .handshake
+            .auth_expensive_checks_total
+            .load(Ordering::Relaxed),
+        80 + OVERLOAD_CANDIDATE_BUDGET_UNHINTED as u64,
+        "saturation budget must bound follow-up invalid scans"
+    );
+}
+
 #[tokio::test]
 async fn mtproto_runtime_snapshot_prefers_preferred_user_hint() {
     let mut config = ProxyConfig::default();
diff --git a/src/proxy/tests/middle_relay_stub_completion_security_tests.rs b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs
index 54eb784..6d398c8 100644
--- a/src/proxy/tests/middle_relay_stub_completion_security_tests.rs
+++ b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs
@@ -12,6 +12,12 @@ fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
     payload
 }
 
+fn make_c2me_permit() -> tokio::sync::OwnedSemaphorePermit {
+    Arc::new(tokio::sync::Semaphore::new(1))
+        .try_acquire_many_owned(1)
+        .expect("test permit must be available")
+}
+
 #[test]
 #[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"]
 fn should_emit_full_desync_filters_duplicates() {
@@ -107,6 +113,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
     tx.send(C2MeCommand::Data {
         payload: make_pooled_payload(&[0xAA]),
         flags: 1,
+        _permit: make_c2me_permit(),
     })
     .await
     .expect("priming queue with one frame must succeed");
@@ -119,6 +126,7 @@
         C2MeCommand::Data {
             payload: make_pooled_payload(&[0xBB, 0xCC]),
             flags: 2,
+            _permit: make_c2me_permit(),
         },
         None,
         &stats,
@@ -138,7 +146,7 @@
         .expect("receiver should observe primed frame")
         .expect("first queued command must exist");
     match first {
-        C2MeCommand::Data { payload, flags } => {
+        C2MeCommand::Data { payload, flags, .. } => {
             assert_eq!(payload.as_ref(), &[0xAA]);
             assert_eq!(flags, 1);
         }
@@ -155,7 +163,7 @@
         .expect("receiver should observe backpressure-resumed frame")
         .expect("second queued command must exist");
     match second {
-        C2MeCommand::Data { payload, flags } => {
+        C2MeCommand::Data { payload, flags, .. } => {
             assert_eq!(payload.as_ref(), &[0xBB, 0xCC]);
             assert_eq!(flags, 2);
         }
diff --git a/src/stats/beobachten.rs b/src/stats/beobachten.rs
index 3d3a2da..79b2bcd 100644
--- a/src/stats/beobachten.rs
+++ b/src/stats/beobachten.rs
@@ -7,6 +7,7 @@ use std::time::{Duration, Instant};
 use parking_lot::Mutex;
 
 const CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
+const MAX_BEOBACHTEN_ENTRIES: usize = 65_536;
 
 #[derive(Default)]
 struct BeobachtenInner {
@@ -48,12 +49,23 @@
         Self::cleanup_if_needed(&mut guard, now, ttl);
 
         let key = (class.to_string(), ip);
-        let entry = guard.entries.entry(key).or_insert(BeobachtenEntry {
-            tries: 0,
-            last_seen: now,
-        });
-        entry.tries = entry.tries.saturating_add(1);
-        entry.last_seen = now;
+        if let Some(entry) = guard.entries.get_mut(&key) {
+            entry.tries = entry.tries.saturating_add(1);
+            entry.last_seen = now;
+            return;
+        }
+
+        if guard.entries.len() >= MAX_BEOBACHTEN_ENTRIES {
+            return;
+        }
+
+        guard.entries.insert(
+            key,
+            BeobachtenEntry {
+                tries: 1,
+                last_seen: now,
+            },
+        );
     }
 
     pub fn snapshot_text(&self, ttl: Duration) -> String {
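
Reviewer note on the byte-budget mechanism: the patch bounds queued C2ME bytes by attaching an `OwnedSemaphorePermit` to every queued frame, so a slow ME sender can hold at most `channel_capacity * C2ME_QUEUED_PERMITS_PER_SLOT` permit units of payload at once. Below is a minimal standalone sketch of that pattern under assumed names (`Frame`, `permits_for`, `PERMIT_UNIT`, and the budget numbers are illustrative, not this repository's API); it demonstrates tokio's `acquire_many_owned` semantics, not the patch itself.

```rust
use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};

const PERMIT_UNIT: usize = 16 * 1024; // one permit per 16 KiB, rounded up

struct Frame {
    payload: Vec<u8>,
    // Held for the lifetime of the queued frame; dropping the frame on the
    // receiver side releases the permits back to the budget automatically.
    _permit: OwnedSemaphorePermit,
}

fn permits_for(len: usize) -> u32 {
    len.max(1).div_ceil(PERMIT_UNIT).min(u32::MAX as usize) as u32
}

#[tokio::main]
async fn main() {
    // Budget: 128 channel slots * 4 permit units per slot, mirroring the patch.
    let budget = Arc::new(Semaphore::new(128 * 4));
    let (tx, mut rx) = mpsc::channel::<Frame>(128);

    let payload = vec![0u8; 40 * 1024]; // ceil(40 KiB / 16 KiB) = 3 permits
    let permit = budget
        .clone()
        .acquire_many_owned(permits_for(payload.len()))
        .await
        .expect("semaphore is never closed in this sketch");

    tx.send(Frame { payload, _permit: permit }).await.unwrap();

    let frame = rx.recv().await.unwrap();
    println!("dequeued {} bytes", frame.payload.len());
    drop(frame); // permits return to the budget here, unblocking producers
}
```

The key property is that the permit travels with the frame: backpressure is measured in bytes actually queued rather than channel slots alone, so many tiny frames and a few near-limit frames are charged proportionally.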
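A second note, on the `BeobachtenStore` change: the rewrite splits the old `entry().or_insert()` upsert so that updates to existing keys always succeed, while brand-new keys are refused once the map holds `MAX_BEOBACHTEN_ENTRIES`, capping memory under key-rotation floods. A self-contained sketch of that insert policy, with assumed stand-in types (`Entry`, a plain `String` key, and a tiny cap for demonstration):

```rust
use std::collections::HashMap;

const MAX_ENTRIES: usize = 4; // tiny cap for the demo; the patch uses 65_536

struct Entry {
    tries: u32,
}

fn record(map: &mut HashMap<String, Entry>, key: String) {
    if let Some(entry) = map.get_mut(&key) {
        // Existing keys never allocate a new slot, so hot entries keep counting.
        entry.tries = entry.tries.saturating_add(1);
        return;
    }
    if map.len() >= MAX_ENTRIES {
        // At capacity: silently drop new keys instead of evicting old ones.
        return;
    }
    map.insert(key, Entry { tries: 1 });
}

fn main() {
    let mut map = HashMap::new();
    for i in 0..10 {
        record(&mut map, format!("probe:198.51.100.{i}"));
    }
    record(&mut map, "probe:198.51.100.0".to_string());
    assert_eq!(map.len(), MAX_ENTRIES); // key flood capped at the limit
    assert_eq!(map["probe:198.51.100.0"].tries, 2); // existing key still updated
}
```

The trade-off is deliberate: under saturation the store prefers losing visibility into new keys over unbounded growth or eviction churn.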