diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 8eebe6c..d1761d1 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -21,6 +21,8 @@ const DEFAULT_ME_ADAPTIVE_FLOOR_MAX_ACTIVE_WRITERS_PER_CORE: u16 = 64; const DEFAULT_ME_ADAPTIVE_FLOOR_MAX_WARM_WRITERS_PER_CORE: u16 = 64; const DEFAULT_ME_ADAPTIVE_FLOOR_MAX_ACTIVE_WRITERS_GLOBAL: u32 = 256; const DEFAULT_ME_ADAPTIVE_FLOOR_MAX_WARM_WRITERS_GLOBAL: u32 = 256; +const DEFAULT_ME_ROUTE_BACKPRESSURE_ENABLED: bool = false; +const DEFAULT_ME_ROUTE_FAIRSHARE_ENABLED: bool = false; const DEFAULT_ME_WRITER_CMD_CHANNEL_CAPACITY: usize = 4096; const DEFAULT_ME_ROUTE_CHANNEL_CAPACITY: usize = 768; const DEFAULT_ME_C2ME_CHANNEL_CAPACITY: usize = 1024; @@ -529,6 +531,14 @@ pub(crate) fn default_me_route_backpressure_base_timeout_ms() -> u64 { 25 } +pub(crate) fn default_me_route_backpressure_enabled() -> bool { + DEFAULT_ME_ROUTE_BACKPRESSURE_ENABLED +} + +pub(crate) fn default_me_route_fairshare_enabled() -> bool { + DEFAULT_ME_ROUTE_FAIRSHARE_ENABLED +} + pub(crate) fn default_me_route_backpressure_high_timeout_ms() -> u64 { 120 } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 48b56f8..6337a06 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -86,6 +86,8 @@ pub struct HotFields { pub telemetry_user_enabled: bool, pub telemetry_me_level: MeTelemetryLevel, pub me_socks_kdf_policy: MeSocksKdfPolicy, + pub me_route_backpressure_enabled: bool, + pub me_route_fairshare_enabled: bool, pub me_floor_mode: MeFloorMode, pub me_adaptive_floor_idle_secs: u64, pub me_adaptive_floor_min_writers_single_endpoint: u8, @@ -187,6 +189,8 @@ impl HotFields { telemetry_user_enabled: cfg.general.telemetry.user_enabled, telemetry_me_level: cfg.general.telemetry.me_level, me_socks_kdf_policy: cfg.general.me_socks_kdf_policy, + me_route_backpressure_enabled: cfg.general.me_route_backpressure_enabled, + me_route_fairshare_enabled: cfg.general.me_route_fairshare_enabled, me_floor_mode: cfg.general.me_floor_mode, me_adaptive_floor_idle_secs: cfg.general.me_adaptive_floor_idle_secs, me_adaptive_floor_min_writers_single_endpoint: cfg @@ -529,6 +533,8 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { new.general.me_route_backpressure_high_timeout_ms; cfg.general.me_route_backpressure_high_watermark_pct = new.general.me_route_backpressure_high_watermark_pct; + cfg.general.me_route_backpressure_enabled = new.general.me_route_backpressure_enabled; + cfg.general.me_route_fairshare_enabled = new.general.me_route_fairshare_enabled; cfg.general.me_reader_route_data_wait_ms = new.general.me_reader_route_data_wait_ms; cfg.general.me_d2c_flush_batch_max_frames = new.general.me_d2c_flush_batch_max_frames; cfg.general.me_d2c_flush_batch_max_bytes = new.general.me_d2c_flush_batch_max_bytes; @@ -1053,6 +1059,8 @@ fn log_changes( != new_hot.me_route_backpressure_high_timeout_ms || old_hot.me_route_backpressure_high_watermark_pct != new_hot.me_route_backpressure_high_watermark_pct + || old_hot.me_route_backpressure_enabled != new_hot.me_route_backpressure_enabled + || old_hot.me_route_fairshare_enabled != new_hot.me_route_fairshare_enabled || old_hot.me_reader_route_data_wait_ms != new_hot.me_reader_route_data_wait_ms || old_hot.me_health_interval_ms_unhealthy != new_hot.me_health_interval_ms_unhealthy || old_hot.me_health_interval_ms_healthy != new_hot.me_health_interval_ms_healthy @@ -1060,10 +1068,12 @@ fn log_changes( || old_hot.me_warn_rate_limit_ms != new_hot.me_warn_rate_limit_ms { info!( - "config reload: me_route_backpressure: base={}ms high={}ms watermark={}%; me_reader_route_data_wait_ms={}; me_health_interval: unhealthy={}ms healthy={}ms; me_admission_poll={}ms; me_warn_rate_limit={}ms", + "config reload: me_route_backpressure: enabled={} base={}ms high={}ms watermark={}%; me_route_fairshare_enabled={}; me_reader_route_data_wait_ms={}; me_health_interval: unhealthy={}ms healthy={}ms; me_admission_poll={}ms; me_warn_rate_limit={}ms", + new_hot.me_route_backpressure_enabled, new_hot.me_route_backpressure_base_timeout_ms, new_hot.me_route_backpressure_high_timeout_ms, new_hot.me_route_backpressure_high_watermark_pct, + new_hot.me_route_fairshare_enabled, new_hot.me_reader_route_data_wait_ms, new_hot.me_health_interval_ms_unhealthy, new_hot.me_health_interval_ms_healthy, diff --git a/src/config/load.rs b/src/config/load.rs index 55f38ca..1e455b8 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -640,12 +640,6 @@ impl ProxyConfig { )); } - if config.censorship.mask_relay_max_bytes == 0 { - return Err(ProxyError::Config( - "censorship.mask_relay_max_bytes must be > 0".to_string(), - )); - } - if config.censorship.mask_relay_max_bytes > 67_108_864 { return Err(ProxyError::Config( "censorship.mask_relay_max_bytes must be <= 67108864".to_string(), diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs index bccd36f..41b1f94 100644 --- a/src/config/tests/load_mask_shape_security_tests.rs +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -238,7 +238,7 @@ mask_shape_above_cap_blur_max_bytes = 8 } #[test] -fn load_rejects_zero_mask_relay_max_bytes() { +fn load_accepts_zero_mask_relay_max_bytes_as_unlimited() { let path = write_temp_config( r#" [censorship] @@ -246,12 +246,9 @@ mask_relay_max_bytes = 0 "#, ); - let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0"); - let msg = err.to_string(); - assert!( - msg.contains("censorship.mask_relay_max_bytes must be > 0"), - "error must explain non-zero relay cap invariant, got: {msg}" - ); + let cfg = ProxyConfig::load(&path) + .expect("mask_relay_max_bytes=0 must be accepted as unlimited relay cap"); + assert_eq!(cfg.censorship.mask_relay_max_bytes, 0); remove_temp_config(&path); } diff --git a/src/config/types.rs b/src/config/types.rs index 9914b63..f422e4e 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -729,6 +729,14 @@ pub struct GeneralConfig { #[serde(default)] pub me_socks_kdf_policy: MeSocksKdfPolicy, + /// Enable route-level ME backpressure controls in reader fairness path. + #[serde(default = "default_me_route_backpressure_enabled")] + pub me_route_backpressure_enabled: bool, + + /// Enable worker-local fairshare scheduler for ME reader routing. + #[serde(default = "default_me_route_fairshare_enabled")] + pub me_route_fairshare_enabled: bool, + /// Base backpressure timeout in milliseconds for ME route channel send. #[serde(default = "default_me_route_backpressure_base_timeout_ms")] pub me_route_backpressure_base_timeout_ms: u64, @@ -1059,6 +1067,8 @@ impl Default for GeneralConfig { disable_colors: false, telemetry: TelemetryConfig::default(), me_socks_kdf_policy: MeSocksKdfPolicy::Strict, + me_route_backpressure_enabled: default_me_route_backpressure_enabled(), + me_route_fairshare_enabled: default_me_route_fairshare_enabled(), me_route_backpressure_base_timeout_ms: default_me_route_backpressure_base_timeout_ms(), me_route_backpressure_high_timeout_ms: default_me_route_backpressure_high_timeout_ms(), me_route_backpressure_high_watermark_pct: @@ -1758,6 +1768,7 @@ pub struct AntiCensorshipConfig { pub mask_shape_above_cap_blur_max_bytes: usize, /// Maximum bytes relayed per direction on unauthenticated masking fallback paths. + /// Set to 0 to disable byte cap (unlimited within relay/idle timeouts). #[serde(default = "default_mask_relay_max_bytes")] pub mask_relay_max_bytes: usize, diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index b647915..7002cec 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -277,6 +277,8 @@ pub(crate) async fn initialize_me_pool( config.general.me_socks_kdf_policy, config.general.me_writer_cmd_channel_capacity, config.general.me_route_channel_capacity, + config.general.me_route_backpressure_enabled, + config.general.me_route_fairshare_enabled, config.general.me_route_backpressure_base_timeout_ms, config.general.me_route_backpressure_high_timeout_ms, config.general.me_route_backpressure_high_watermark_pct, diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index da059bd..cd5e545 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -122,6 +122,8 @@ pub(crate) async fn spawn_runtime_tasks( if let Some(pool) = &me_pool_for_policy { pool.update_runtime_transport_policy( cfg.general.me_socks_kdf_policy, + cfg.general.me_route_backpressure_enabled, + cfg.general.me_route_fairshare_enabled, cfg.general.me_route_backpressure_base_timeout_ms, cfg.general.me_route_backpressure_high_timeout_ms, cfg.general.me_route_backpressure_high_watermark_pct, diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index d49e4c3..c48ec9c 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -60,21 +60,18 @@ where let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); let mut total = 0usize; let mut ended_by_eof = false; - - if byte_cap == 0 { - return CopyOutcome { - total, - ended_by_eof, - }; - } + let unlimited = byte_cap == 0; loop { - let remaining_budget = byte_cap.saturating_sub(total); - if remaining_budget == 0 { - break; - } - - let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let read_len = if unlimited { + MASK_BUFFER_SIZE + } else { + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + remaining_budget.min(MASK_BUFFER_SIZE) + }; let read_res = timeout(idle_timeout, reader.read(&mut buf[..read_len])).await; let n = match read_res { Ok(Ok(n)) => n, @@ -930,21 +927,21 @@ async fn consume_client_data( byte_cap: usize, idle_timeout: Duration, ) { - if byte_cap == 0 { - return; - } - // Keep drain path fail-closed under slow-loris stalls. let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); let mut total = 0usize; + let unlimited = byte_cap == 0; loop { - let remaining_budget = byte_cap.saturating_sub(total); - if remaining_budget == 0 { - break; - } - - let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let read_len = if unlimited { + MASK_BUFFER_SIZE + } else { + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + remaining_budget.min(MASK_BUFFER_SIZE) + }; let n = match timeout(idle_timeout, reader.read(&mut buf[..read_len])).await { Ok(Ok(n)) => n, Ok(Err(_)) | Err(_) => break, @@ -955,7 +952,7 @@ async fn consume_client_data( } total = total.saturating_add(n); - if total >= byte_cap { + if !unlimited && total >= byte_cap { break; } } diff --git a/src/proxy/tests/masking_consume_stress_adversarial_tests.rs b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs index 7579a9c..efe9f49 100644 --- a/src/proxy/tests/masking_consume_stress_adversarial_tests.rs +++ b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs @@ -58,11 +58,22 @@ async fn consume_stall_stress_finishes_within_idle_budget() { } #[tokio::test] -async fn consume_zero_cap_returns_immediately() { +async fn consume_zero_cap_is_idle_bounded_on_stall() { let started = Instant::now(); - consume_client_data(tokio::io::empty(), 0, MASK_RELAY_IDLE_TIMEOUT).await; + tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(OneByteThenStall { sent: false }, 0, MASK_RELAY_IDLE_TIMEOUT), + ) + .await + .expect("zero-cap consume path must remain bounded by timeout guards"); + + let elapsed = started.elapsed(); assert!( - started.elapsed() < MASK_RELAY_IDLE_TIMEOUT, - "zero byte cap must return immediately" + elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2), + "zero cap must not short-circuit before idle timeout path, got {elapsed:?}" + ); + assert!( + elapsed < MASK_RELAY_TIMEOUT, + "zero-cap consume path must complete before relay timeout, got {elapsed:?}" ); } diff --git a/src/proxy/tests/masking_production_cap_regression_security_tests.rs b/src/proxy/tests/masking_production_cap_regression_security_tests.rs index c5d542e..84e0a86 100644 --- a/src/proxy/tests/masking_production_cap_regression_security_tests.rs +++ b/src/proxy/tests/masking_production_cap_regression_security_tests.rs @@ -148,9 +148,10 @@ async fn positive_copy_with_production_cap_stops_exactly_at_budget() { } #[tokio::test] -async fn negative_consume_with_zero_cap_performs_no_reads() { - let read_calls = Arc::new(AtomicUsize::new(0)); - let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls)); +async fn consume_with_zero_cap_drains_until_eof() { + let payload = 256 * 1024; + let total_read = Arc::new(AtomicUsize::new(0)); + let reader = BudgetProbeReader::new(payload, Arc::clone(&total_read)); consume_client_data_with_timeout_and_cap( reader, @@ -161,9 +162,27 @@ async fn negative_consume_with_zero_cap_performs_no_reads() { .await; assert_eq!( - read_calls.load(Ordering::Relaxed), - 0, - "zero cap must return before reading attacker-controlled bytes" + total_read.load(Ordering::Relaxed), + payload, + "zero cap must disable byte budget and drain finite payload to EOF" + ); +} + +#[tokio::test] +async fn copy_with_zero_cap_drains_until_eof() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let payload = 73 * 1024; + let mut reader = FinitePatternReader::new(payload, 3072, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = + copy_with_idle_timeout(&mut reader, &mut writer, 0, true, MASK_RELAY_IDLE_TIMEOUT).await; + + assert_eq!(outcome.total, payload); + assert_eq!(writer.written, payload); + assert!( + outcome.ended_by_eof, + "zero cap must not terminate relay early on byte budget" ); } diff --git a/src/transport/middle_proxy/fairness/pressure.rs b/src/transport/middle_proxy/fairness/pressure.rs index 6c84a4c..cd61024 100644 --- a/src/transport/middle_proxy/fairness/pressure.rs +++ b/src/transport/middle_proxy/fairness/pressure.rs @@ -12,6 +12,7 @@ pub(crate) struct PressureSignals { #[derive(Debug, Clone)] pub(crate) struct PressureConfig { + pub(crate) backpressure_enabled: bool, pub(crate) evaluate_every_rounds: u32, pub(crate) transition_hysteresis_rounds: u8, pub(crate) standing_ratio_pressured_pct: u8, @@ -32,6 +33,7 @@ pub(crate) struct PressureConfig { impl Default for PressureConfig { fn default() -> Self { Self { + backpressure_enabled: true, evaluate_every_rounds: 8, transition_hysteresis_rounds: 3, standing_ratio_pressured_pct: 20, @@ -99,6 +101,13 @@ impl PressureEvaluator { force: bool, ) -> PressureState { self.rotate_window_if_needed(now, cfg); + if !cfg.backpressure_enabled { + self.state = PressureState::Normal; + self.candidate_state = PressureState::Normal; + self.candidate_hits = 0; + self.rounds_since_eval = 0; + return self.state; + } self.rounds_since_eval = self.rounds_since_eval.saturating_add(1); if !force && self.rounds_since_eval < cfg.evaluate_every_rounds.max(1) { return self.state; @@ -133,6 +142,10 @@ impl PressureEvaluator { max_total_queued_bytes: u64, signals: PressureSignals, ) -> PressureState { + if !cfg.backpressure_enabled { + return PressureState::Normal; + } + let queue_ratio_pct = if max_total_queued_bytes == 0 { 100 } else { @@ -146,57 +159,59 @@ impl PressureEvaluator { ((signals.standing_flows.saturating_mul(100)) / signals.active_flows).min(100) as u8 }; - let mut pressured = false; - let mut saturated = false; + let mut pressure_score = 0u8; - let queue_saturated_pct = cfg - .queue_ratio_shedding_pct - .min(cfg.queue_ratio_saturated_pct); if queue_ratio_pct >= cfg.queue_ratio_pressured_pct { - pressured = true; + pressure_score = pressure_score.max(1); } - if queue_ratio_pct >= queue_saturated_pct { - saturated = true; + if queue_ratio_pct >= cfg.queue_ratio_shedding_pct { + pressure_score = pressure_score.max(2); + } + if queue_ratio_pct >= cfg.queue_ratio_saturated_pct { + pressure_score = pressure_score.max(3); } - let standing_saturated_pct = cfg - .standing_ratio_shedding_pct - .min(cfg.standing_ratio_saturated_pct); if standing_ratio_pct >= cfg.standing_ratio_pressured_pct { - pressured = true; + pressure_score = pressure_score.max(1); } - if standing_ratio_pct >= standing_saturated_pct { - saturated = true; + if standing_ratio_pct >= cfg.standing_ratio_shedding_pct { + pressure_score = pressure_score.max(2); + } + if standing_ratio_pct >= cfg.standing_ratio_saturated_pct { + pressure_score = pressure_score.max(3); } - let rejects_saturated = cfg.rejects_shedding.min(cfg.rejects_saturated); if self.admission_rejects_window >= cfg.rejects_pressured { - pressured = true; + pressure_score = pressure_score.max(1); } - if self.admission_rejects_window >= rejects_saturated { - saturated = true; + if self.admission_rejects_window >= cfg.rejects_shedding { + pressure_score = pressure_score.max(2); + } + if self.admission_rejects_window >= cfg.rejects_saturated { + pressure_score = pressure_score.max(3); } - let stalls_saturated = cfg.stalls_shedding.min(cfg.stalls_saturated); if self.route_stalls_window >= cfg.stalls_pressured { - pressured = true; + pressure_score = pressure_score.max(1); } - if self.route_stalls_window >= stalls_saturated { - saturated = true; + if self.route_stalls_window >= cfg.stalls_shedding { + pressure_score = pressure_score.max(2); + } + if self.route_stalls_window >= cfg.stalls_saturated { + pressure_score = pressure_score.max(3); } if signals.backpressured_flows > signals.active_flows.saturating_div(2) && signals.active_flows > 0 { - pressured = true; + pressure_score = pressure_score.max(2); } - if saturated { - PressureState::Saturated - } else if pressured { - PressureState::Pressured - } else { - PressureState::Normal + match pressure_score { + 0 => PressureState::Normal, + 1 => PressureState::Pressured, + 2 => PressureState::Shedding, + _ => PressureState::Saturated, } } diff --git a/src/transport/middle_proxy/fairness/scheduler.rs b/src/transport/middle_proxy/fairness/scheduler.rs index b8079ce..434bbcd 100644 --- a/src/transport/middle_proxy/fairness/scheduler.rs +++ b/src/transport/middle_proxy/fairness/scheduler.rs @@ -14,6 +14,7 @@ use super::pressure::{PressureConfig, PressureEvaluator, PressureSignals}; #[derive(Debug, Clone)] pub(crate) struct WorkerFairnessConfig { pub(crate) worker_id: u16, + pub(crate) backpressure_enabled: bool, pub(crate) max_active_flows: usize, pub(crate) max_total_queued_bytes: u64, pub(crate) max_flow_queued_bytes: u64, @@ -36,6 +37,7 @@ impl Default for WorkerFairnessConfig { fn default() -> Self { Self { worker_id: 0, + backpressure_enabled: true, max_active_flows: 4096, max_total_queued_bytes: 16 * 1024 * 1024, max_flow_queued_bytes: 512 * 1024, @@ -107,7 +109,8 @@ pub(crate) struct WorkerFairnessState { } impl WorkerFairnessState { - pub(crate) fn new(config: WorkerFairnessConfig, now: Instant) -> Self { + pub(crate) fn new(mut config: WorkerFairnessConfig, now: Instant) -> Self { + config.pressure.backpressure_enabled = config.backpressure_enabled; let bucket_count = config.soft_bucket_count.max(1); Self { config, @@ -134,6 +137,15 @@ impl WorkerFairnessState { self.pressure.state() } + pub(crate) fn set_backpressure_enabled(&mut self, enabled: bool) { + if self.config.backpressure_enabled == enabled { + return; + } + self.config.backpressure_enabled = enabled; + self.config.pressure.backpressure_enabled = enabled; + self.evaluate_pressure(Instant::now(), true); + } + pub(crate) fn snapshot(&self) -> WorkerFairnessSnapshot { WorkerFairnessSnapshot { pressure_state: self.pressure.state(), @@ -166,7 +178,7 @@ impl WorkerFairnessState { }; let frame_bytes = frame.queued_bytes(); - if self.pressure.state() == PressureState::Saturated { + if self.config.backpressure_enabled && self.pressure.state() == PressureState::Saturated { self.pressure .note_admission_reject(now, &self.config.pressure); self.enqueue_rejects = self.enqueue_rejects.saturating_add(1); @@ -231,7 +243,8 @@ impl WorkerFairnessState { return AdmissionDecision::RejectFlowCap; } - if self.pressure.state() >= PressureState::Shedding + if self.config.backpressure_enabled + && self.pressure.state() >= PressureState::Shedding && entry.fairness.standing_state == StandingQueueState::Standing { self.pressure @@ -422,8 +435,10 @@ impl WorkerFairnessState { DispatchAction::Continue } DispatchFeedback::QueueFull => { - self.pressure.note_route_stall(now, &self.config.pressure); - self.downstream_stalls = self.downstream_stalls.saturating_add(1); + if self.config.backpressure_enabled { + self.pressure.note_route_stall(now, &self.config.pressure); + self.downstream_stalls = self.downstream_stalls.saturating_add(1); + } let state = self.pressure.state(); let Some(flow) = self.flows.get_mut(&conn_id) else { self.evaluate_pressure(now, true); @@ -433,16 +448,19 @@ impl WorkerFairnessState { let before_membership = Self::flow_membership(&flow.fairness); let mut enqueue_active = false; - flow.fairness.consecutive_stalls = - flow.fairness.consecutive_stalls.saturating_add(1); - flow.fairness.scheduler_state = FlowSchedulerState::Backpressured; - flow.fairness.pressure_class = FlowPressureClass::Backpressured; + if self.config.backpressure_enabled { + flow.fairness.consecutive_stalls = + flow.fairness.consecutive_stalls.saturating_add(1); + flow.fairness.scheduler_state = FlowSchedulerState::Backpressured; + flow.fairness.pressure_class = FlowPressureClass::Backpressured; + } - let should_shed_frame = matches!(state, PressureState::Saturated) - || (matches!(state, PressureState::Shedding) - && flow.fairness.standing_state == StandingQueueState::Standing - && flow.fairness.consecutive_stalls - >= self.config.max_consecutive_stalls_before_shed); + let should_shed_frame = self.config.backpressure_enabled + && (matches!(state, PressureState::Saturated) + || (matches!(state, PressureState::Shedding) + && flow.fairness.standing_state == StandingQueueState::Standing + && flow.fairness.consecutive_stalls + >= self.config.max_consecutive_stalls_before_shed)); if should_shed_frame { self.shed_drops = self.shed_drops.saturating_add(1); @@ -467,8 +485,9 @@ impl WorkerFairnessState { Self::classify_flow(&self.config, state, now, &mut flow.fairness); let after_membership = Self::flow_membership(&flow.fairness); - let should_close_flow = flow.fairness.consecutive_stalls - >= self.config.max_consecutive_stalls_before_close + let should_close_flow = self.config.backpressure_enabled + && flow.fairness.consecutive_stalls + >= self.config.max_consecutive_stalls_before_close && self.pressure.state() == PressureState::Saturated; ( before_membership, diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index 000bca0..399fd13 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -1794,6 +1794,8 @@ mod tests { MeSocksKdfPolicy::default(), general.me_writer_cmd_channel_capacity, general.me_route_channel_capacity, + general.me_route_backpressure_enabled, + general.me_route_fairshare_enabled, general.me_route_backpressure_base_timeout_ms, general.me_route_backpressure_high_timeout_ms, general.me_route_backpressure_high_watermark_pct, diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index b89a844..404e864 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -396,6 +396,8 @@ pub(super) struct WriterSelectionPolicyCore { pub(super) struct TransportPolicyCore { pub(super) me_socks_kdf_policy: AtomicU8, + pub(super) me_route_backpressure_enabled: Arc, + pub(super) me_route_fairshare_enabled: Arc, pub(super) me_reader_route_data_wait_ms: Arc, } @@ -548,6 +550,8 @@ impl MePool { me_socks_kdf_policy: MeSocksKdfPolicy, me_writer_cmd_channel_capacity: usize, me_route_channel_capacity: usize, + me_route_backpressure_enabled: bool, + me_route_fairshare_enabled: bool, me_route_backpressure_base_timeout_ms: u64, me_route_backpressure_high_timeout_ms: u64, me_route_backpressure_high_watermark_pct: u8, @@ -783,6 +787,10 @@ impl MePool { }), transport_policy: Arc::new(TransportPolicyCore { me_socks_kdf_policy: AtomicU8::new(me_socks_kdf_policy.as_u8()), + me_route_backpressure_enabled: Arc::new(AtomicBool::new( + me_route_backpressure_enabled, + )), + me_route_fairshare_enabled: Arc::new(AtomicBool::new(me_route_fairshare_enabled)), me_reader_route_data_wait_ms: Arc::new(AtomicU64::new( me_reader_route_data_wait_ms, )), @@ -1245,6 +1253,8 @@ impl MePool { pub fn update_runtime_transport_policy( &self, socks_kdf_policy: MeSocksKdfPolicy, + route_backpressure_enabled: bool, + route_fairshare_enabled: bool, route_backpressure_base_timeout_ms: u64, route_backpressure_high_timeout_ms: u64, route_backpressure_high_watermark_pct: u8, @@ -1253,6 +1263,12 @@ impl MePool { self.transport_policy .me_socks_kdf_policy .store(socks_kdf_policy.as_u8(), Ordering::Relaxed); + self.transport_policy + .me_route_backpressure_enabled + .store(route_backpressure_enabled, Ordering::Relaxed); + self.transport_policy + .me_route_fairshare_enabled + .store(route_fairshare_enabled, Ordering::Relaxed); self.transport_policy .me_reader_route_data_wait_ms .store(reader_route_data_wait_ms, Ordering::Relaxed); diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 52c8fae..0644e8d 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -436,6 +436,9 @@ impl MePool { let cancel_signal = cancel.clone(); let cancel_select = cancel.clone(); let cancel_cleanup = cancel.clone(); + let route_backpressure_enabled = + self.transport_policy.me_route_backpressure_enabled.clone(); + let route_fairshare_enabled = self.transport_policy.me_route_fairshare_enabled.clone(); let reader_route_data_wait_ms = self.transport_policy.me_reader_route_data_wait_ms.clone(); tokio::spawn(async move { @@ -458,6 +461,8 @@ impl MePool { writer_id, degraded, rtt_ema_ms_x10, + route_backpressure_enabled, + route_fairshare_enabled, reader_route_data_wait_ms, cancel_reader, ) => WriterLifecycleExit::Reader(reader_res), diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index e1e919f..97fa329 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -45,7 +45,15 @@ fn is_data_route_queue_full(result: RouteResult) -> bool { ) } -fn should_close_on_queue_full_streak(streak: u8, pressure_state: PressureState) -> bool { +fn should_close_on_queue_full_streak_with_policy( + streak: u8, + pressure_state: PressureState, + backpressure_enabled: bool, +) -> bool { + if !backpressure_enabled { + return false; + } + if pressure_state < PressureState::Shedding { return false; } @@ -160,6 +168,7 @@ async fn drain_fairness_scheduler( reg: &ConnRegistry, tx: &mpsc::Sender, data_route_queue_full_streak: &mut HashMap, + backpressure_enabled: bool, route_wait_ms: u64, stats: &Stats, ) { @@ -188,7 +197,11 @@ async fn drain_fairness_scheduler( if is_data_route_queue_full(routed) { let streak = data_route_queue_full_streak.entry(cid).or_insert(0); *streak = streak.saturating_add(1); - if should_close_on_queue_full_streak(*streak, pressure_state) { + if should_close_on_queue_full_streak_with_policy( + *streak, + pressure_state, + backpressure_enabled, + ) { fairness.remove_flow(cid); data_route_queue_full_streak.remove(&cid); reg.unregister(cid).await; @@ -220,6 +233,8 @@ pub(crate) async fn reader_loop( writer_id: u64, degraded: Arc, writer_rtt_ema_ms_x10: Arc, + route_backpressure_enabled: Arc, + route_fairshare_enabled: Arc, reader_route_data_wait_ms: Arc, cancel: CancellationToken, ) -> Result<()> { @@ -236,14 +251,19 @@ pub(crate) async fn reader_loop( max_flow_queued_bytes: (reg.route_channel_capacity() as u64) .saturating_mul(2 * 1024) .clamp(64 * 1024, 2 * 1024 * 1024), + backpressure_enabled: route_backpressure_enabled.load(Ordering::Relaxed), ..WorkerFairnessConfig::default() }, Instant::now(), ); let mut fairness_snapshot = fairness.snapshot(); loop { + let backpressure_enabled = route_backpressure_enabled.load(Ordering::Relaxed); + let fairshare_enabled = route_fairshare_enabled.load(Ordering::Relaxed); + fairness.set_backpressure_enabled(backpressure_enabled); + let fairness_has_backlog = should_schedule_fairness_retry(&fairness_snapshot); let mut tmp = [0u8; 65_536]; - let backlog_retry_enabled = should_schedule_fairness_retry(&fairness_snapshot); + let backlog_retry_enabled = fairness_has_backlog; let backlog_retry_delay = fairness_retry_delay(reader_route_data_wait_ms.load(Ordering::Relaxed)); let mut retry_only = false; @@ -262,6 +282,7 @@ pub(crate) async fn reader_loop( reg.as_ref(), &tx, &mut data_route_queue_full_streak, + backpressure_enabled, route_wait_ms, stats.as_ref(), ) @@ -346,20 +367,56 @@ pub(crate) async fn reader_loop( let data = body.slice(12..); trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); - let admission = fairness.enqueue_data(cid, flags, data, Instant::now()); - if !matches!(admission, AdmissionDecision::Admit) { - stats.increment_me_route_drop_queue_full(); - stats.increment_me_route_drop_queue_full_high(); - let streak = data_route_queue_full_streak.entry(cid).or_insert(0); - *streak = streak.saturating_add(1); - let pressure_state = fairness.pressure_state(); - if should_close_on_queue_full_streak(*streak, pressure_state) - || matches!(admission, AdmissionDecision::RejectSaturated) - { + if fairshare_enabled { + let admission = fairness.enqueue_data(cid, flags, data, Instant::now()); + if !matches!(admission, AdmissionDecision::Admit) { + stats.increment_me_route_drop_queue_full(); + stats.increment_me_route_drop_queue_full_high(); + let streak = data_route_queue_full_streak.entry(cid).or_insert(0); + *streak = streak.saturating_add(1); + let pressure_state = fairness.pressure_state(); + if should_close_on_queue_full_streak_with_policy( + *streak, + pressure_state, + backpressure_enabled, + ) || (backpressure_enabled + && matches!(admission, AdmissionDecision::RejectSaturated)) + { + fairness.remove_flow(cid); + data_route_queue_full_streak.remove(&cid); + reg.unregister(cid).await; + send_close_conn(&tx, cid).await; + } + } + } else { + let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed); + let routed = + route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await; + if matches!(routed, RouteResult::Routed) { + data_route_queue_full_streak.remove(&cid); + continue; + } + report_route_drop(routed, stats.as_ref()); + if should_close_on_route_result_for_data(routed) { fairness.remove_flow(cid); data_route_queue_full_streak.remove(&cid); reg.unregister(cid).await; send_close_conn(&tx, cid).await; + continue; + } + if is_data_route_queue_full(routed) { + let streak = data_route_queue_full_streak.entry(cid).or_insert(0); + *streak = streak.saturating_add(1); + if should_close_on_queue_full_streak_with_policy( + *streak, + PressureState::Shedding, + backpressure_enabled, + ) { + fairness.remove_flow(cid); + data_route_queue_full_streak.remove(&cid); + reg.unregister(cid).await; + send_close_conn(&tx, cid).await; + } } } } else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 { @@ -465,6 +522,7 @@ pub(crate) async fn reader_loop( reg.as_ref(), &tx, &mut data_route_queue_full_streak, + backpressure_enabled, route_wait_ms, stats.as_ref(), ) @@ -486,9 +544,9 @@ mod tests { use super::{ MeResponse, RouteResult, WorkerFairnessSnapshot, fairness_retry_delay, - is_data_route_queue_full, route_data_with_retry, should_close_on_queue_full_streak, - should_close_on_route_result_for_ack, should_close_on_route_result_for_data, - should_schedule_fairness_retry, + is_data_route_queue_full, route_data_with_retry, + should_close_on_queue_full_streak_with_policy, should_close_on_route_result_for_ack, + should_close_on_route_result_for_data, should_schedule_fairness_retry, }; #[test] @@ -511,22 +569,35 @@ mod tests { assert!(is_data_route_queue_full(RouteResult::QueueFullBase)); assert!(is_data_route_queue_full(RouteResult::QueueFullHigh)); assert!(!is_data_route_queue_full(RouteResult::NoConn)); - assert!(!should_close_on_queue_full_streak(1, PressureState::Normal)); - assert!(!should_close_on_queue_full_streak( + assert!(!should_close_on_queue_full_streak_with_policy( + 1, + PressureState::Normal, + true + )); + assert!(!should_close_on_queue_full_streak_with_policy( 2, - PressureState::Pressured + PressureState::Pressured, + true )); - assert!(!should_close_on_queue_full_streak( + assert!(!should_close_on_queue_full_streak_with_policy( 3, - PressureState::Pressured + PressureState::Pressured, + true )); - assert!(should_close_on_queue_full_streak( + assert!(should_close_on_queue_full_streak_with_policy( 3, - PressureState::Shedding + PressureState::Shedding, + true )); - assert!(should_close_on_queue_full_streak( + assert!(should_close_on_queue_full_streak_with_policy( u8::MAX, - PressureState::Saturated + PressureState::Saturated, + true + )); + assert!(!should_close_on_queue_full_streak_with_policy( + u8::MAX, + PressureState::Saturated, + false )); } diff --git a/src/transport/middle_proxy/tests/health_adversarial_tests.rs b/src/transport/middle_proxy/tests/health_adversarial_tests.rs index 4bee91c..ea88c67 100644 --- a/src/transport/middle_proxy/tests/health_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/health_adversarial_tests.rs @@ -104,6 +104,8 @@ async fn make_pool( MeSocksKdfPolicy::default(), general.me_writer_cmd_channel_capacity, general.me_route_channel_capacity, + general.me_route_backpressure_enabled, + general.me_route_fairshare_enabled, general.me_route_backpressure_base_timeout_ms, general.me_route_backpressure_high_timeout_ms, general.me_route_backpressure_high_watermark_pct, diff --git a/src/transport/middle_proxy/tests/health_integration_tests.rs b/src/transport/middle_proxy/tests/health_integration_tests.rs index 0a6e110..9b3f93e 100644 --- a/src/transport/middle_proxy/tests/health_integration_tests.rs +++ b/src/transport/middle_proxy/tests/health_integration_tests.rs @@ -102,6 +102,8 @@ async fn make_pool( MeSocksKdfPolicy::default(), general.me_writer_cmd_channel_capacity, general.me_route_channel_capacity, + general.me_route_backpressure_enabled, + general.me_route_fairshare_enabled, general.me_route_backpressure_base_timeout_ms, general.me_route_backpressure_high_timeout_ms, general.me_route_backpressure_high_watermark_pct, diff --git a/src/transport/middle_proxy/tests/health_regression_tests.rs b/src/transport/middle_proxy/tests/health_regression_tests.rs index 92398b4..aa1f9ed 100644 --- a/src/transport/middle_proxy/tests/health_regression_tests.rs +++ b/src/transport/middle_proxy/tests/health_regression_tests.rs @@ -97,6 +97,8 @@ async fn make_pool(me_pool_drain_threshold: u64) -> Arc { MeSocksKdfPolicy::default(), general.me_writer_cmd_channel_capacity, general.me_route_channel_capacity, + general.me_route_backpressure_enabled, + general.me_route_fairshare_enabled, general.me_route_backpressure_base_timeout_ms, general.me_route_backpressure_high_timeout_ms, general.me_route_backpressure_high_watermark_pct, diff --git a/src/transport/middle_proxy/tests/pool_refill_security_tests.rs b/src/transport/middle_proxy/tests/pool_refill_security_tests.rs index 4463444..6519c05 100644 --- a/src/transport/middle_proxy/tests/pool_refill_security_tests.rs +++ b/src/transport/middle_proxy/tests/pool_refill_security_tests.rs @@ -86,6 +86,8 @@ async fn make_pool() -> Arc { MeSocksKdfPolicy::default(), general.me_writer_cmd_channel_capacity, general.me_route_channel_capacity, + general.me_route_backpressure_enabled, + general.me_route_fairshare_enabled, general.me_route_backpressure_base_timeout_ms, general.me_route_backpressure_high_timeout_ms, general.me_route_backpressure_high_watermark_pct, diff --git a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs index 0184e11..5f9f130 100644 --- a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs +++ b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs @@ -91,6 +91,8 @@ async fn make_pool() -> Arc { MeSocksKdfPolicy::default(), general.me_writer_cmd_channel_capacity, general.me_route_channel_capacity, + general.me_route_backpressure_enabled, + general.me_route_fairshare_enabled, general.me_route_backpressure_base_timeout_ms, general.me_route_backpressure_high_timeout_ms, general.me_route_backpressure_high_watermark_pct, diff --git a/src/transport/middle_proxy/tests/send_adversarial_tests.rs b/src/transport/middle_proxy/tests/send_adversarial_tests.rs index de52d18..6e556d4 100644 --- a/src/transport/middle_proxy/tests/send_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/send_adversarial_tests.rs @@ -97,6 +97,8 @@ async fn make_pool() -> (Arc, Arc) { MeSocksKdfPolicy::default(), general.me_writer_cmd_channel_capacity, general.me_route_channel_capacity, + general.me_route_backpressure_enabled, + general.me_route_fairshare_enabled, general.me_route_backpressure_base_timeout_ms, general.me_route_backpressure_high_timeout_ms, general.me_route_backpressure_high_watermark_pct,