diff --git a/src/config/defaults.rs b/src/config/defaults.rs index e885cbe..cb95637 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -24,6 +24,13 @@ const DEFAULT_ME_ADAPTIVE_FLOOR_MAX_WARM_WRITERS_GLOBAL: u32 = 256; const DEFAULT_ME_WRITER_CMD_CHANNEL_CAPACITY: usize = 4096; const DEFAULT_ME_ROUTE_CHANNEL_CAPACITY: usize = 768; const DEFAULT_ME_C2ME_CHANNEL_CAPACITY: usize = 1024; +const DEFAULT_ME_READER_ROUTE_DATA_WAIT_MS: u64 = 2; +const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_FRAMES: usize = 32; +const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_BYTES: usize = 128 * 1024; +const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_DELAY_US: u64 = 1500; +const DEFAULT_ME_D2C_ACK_FLUSH_IMMEDIATE: bool = false; +const DEFAULT_DIRECT_RELAY_COPY_BUF_C2S_BYTES: usize = 64 * 1024; +const DEFAULT_DIRECT_RELAY_COPY_BUF_S2C_BYTES: usize = 256 * 1024; const DEFAULT_ME_WRITER_PICK_SAMPLE_SIZE: u8 = 3; const DEFAULT_ME_HEALTH_INTERVAL_MS_UNHEALTHY: u64 = 1000; const DEFAULT_ME_HEALTH_INTERVAL_MS_HEALTHY: u64 = 3000; @@ -316,6 +323,34 @@ pub(crate) fn default_me_c2me_channel_capacity() -> usize { DEFAULT_ME_C2ME_CHANNEL_CAPACITY } +pub(crate) fn default_me_reader_route_data_wait_ms() -> u64 { + DEFAULT_ME_READER_ROUTE_DATA_WAIT_MS +} + +pub(crate) fn default_me_d2c_flush_batch_max_frames() -> usize { + DEFAULT_ME_D2C_FLUSH_BATCH_MAX_FRAMES +} + +pub(crate) fn default_me_d2c_flush_batch_max_bytes() -> usize { + DEFAULT_ME_D2C_FLUSH_BATCH_MAX_BYTES +} + +pub(crate) fn default_me_d2c_flush_batch_max_delay_us() -> u64 { + DEFAULT_ME_D2C_FLUSH_BATCH_MAX_DELAY_US +} + +pub(crate) fn default_me_d2c_ack_flush_immediate() -> bool { + DEFAULT_ME_D2C_ACK_FLUSH_IMMEDIATE +} + +pub(crate) fn default_direct_relay_copy_buf_c2s_bytes() -> usize { + DEFAULT_DIRECT_RELAY_COPY_BUF_C2S_BYTES +} + +pub(crate) fn default_direct_relay_copy_buf_s2c_bytes() -> usize { + DEFAULT_DIRECT_RELAY_COPY_BUF_S2C_BYTES +} + pub(crate) fn default_me_writer_pick_sample_size() -> u8 { DEFAULT_ME_WRITER_PICK_SAMPLE_SIZE } diff --git 
a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 34b2d76..632ca8c 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -96,6 +96,13 @@ pub struct HotFields { pub me_route_backpressure_base_timeout_ms: u64, pub me_route_backpressure_high_timeout_ms: u64, pub me_route_backpressure_high_watermark_pct: u8, + pub me_reader_route_data_wait_ms: u64, + pub me_d2c_flush_batch_max_frames: usize, + pub me_d2c_flush_batch_max_bytes: usize, + pub me_d2c_flush_batch_max_delay_us: u64, + pub me_d2c_ack_flush_immediate: bool, + pub direct_relay_copy_buf_c2s_bytes: usize, + pub direct_relay_copy_buf_s2c_bytes: usize, pub me_health_interval_ms_unhealthy: u64, pub me_health_interval_ms_healthy: u64, pub me_admission_poll_ms: u64, @@ -203,6 +210,13 @@ impl HotFields { me_route_backpressure_base_timeout_ms: cfg.general.me_route_backpressure_base_timeout_ms, me_route_backpressure_high_timeout_ms: cfg.general.me_route_backpressure_high_timeout_ms, me_route_backpressure_high_watermark_pct: cfg.general.me_route_backpressure_high_watermark_pct, + me_reader_route_data_wait_ms: cfg.general.me_reader_route_data_wait_ms, + me_d2c_flush_batch_max_frames: cfg.general.me_d2c_flush_batch_max_frames, + me_d2c_flush_batch_max_bytes: cfg.general.me_d2c_flush_batch_max_bytes, + me_d2c_flush_batch_max_delay_us: cfg.general.me_d2c_flush_batch_max_delay_us, + me_d2c_ack_flush_immediate: cfg.general.me_d2c_ack_flush_immediate, + direct_relay_copy_buf_c2s_bytes: cfg.general.direct_relay_copy_buf_c2s_bytes, + direct_relay_copy_buf_s2c_bytes: cfg.general.direct_relay_copy_buf_s2c_bytes, me_health_interval_ms_unhealthy: cfg.general.me_health_interval_ms_unhealthy, me_health_interval_ms_healthy: cfg.general.me_health_interval_ms_healthy, me_admission_poll_ms: cfg.general.me_admission_poll_ms, @@ -352,6 +366,13 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { new.general.me_route_backpressure_high_timeout_ms; 
cfg.general.me_route_backpressure_high_watermark_pct = new.general.me_route_backpressure_high_watermark_pct; + cfg.general.me_reader_route_data_wait_ms = new.general.me_reader_route_data_wait_ms; + cfg.general.me_d2c_flush_batch_max_frames = new.general.me_d2c_flush_batch_max_frames; + cfg.general.me_d2c_flush_batch_max_bytes = new.general.me_d2c_flush_batch_max_bytes; + cfg.general.me_d2c_flush_batch_max_delay_us = new.general.me_d2c_flush_batch_max_delay_us; + cfg.general.me_d2c_ack_flush_immediate = new.general.me_d2c_ack_flush_immediate; + cfg.general.direct_relay_copy_buf_c2s_bytes = new.general.direct_relay_copy_buf_c2s_bytes; + cfg.general.direct_relay_copy_buf_s2c_bytes = new.general.direct_relay_copy_buf_s2c_bytes; cfg.general.me_health_interval_ms_unhealthy = new.general.me_health_interval_ms_unhealthy; cfg.general.me_health_interval_ms_healthy = new.general.me_health_interval_ms_healthy; cfg.general.me_admission_poll_ms = new.general.me_admission_poll_ms; @@ -821,6 +842,7 @@ fn log_changes( != new_hot.me_route_backpressure_high_timeout_ms || old_hot.me_route_backpressure_high_watermark_pct != new_hot.me_route_backpressure_high_watermark_pct + || old_hot.me_reader_route_data_wait_ms != new_hot.me_reader_route_data_wait_ms || old_hot.me_health_interval_ms_unhealthy != new_hot.me_health_interval_ms_unhealthy || old_hot.me_health_interval_ms_healthy != new_hot.me_health_interval_ms_healthy @@ -828,10 +850,11 @@ fn log_changes( || old_hot.me_warn_rate_limit_ms != new_hot.me_warn_rate_limit_ms { info!( - "config reload: me_route_backpressure: base={}ms high={}ms watermark={}%; me_health_interval: unhealthy={}ms healthy={}ms; me_admission_poll={}ms; me_warn_rate_limit={}ms", + "config reload: me_route_backpressure: base={}ms high={}ms watermark={}%; me_reader_route_data_wait_ms={}; me_health_interval: unhealthy={}ms healthy={}ms; me_admission_poll={}ms; me_warn_rate_limit={}ms", new_hot.me_route_backpressure_base_timeout_ms, 
new_hot.me_route_backpressure_high_timeout_ms, new_hot.me_route_backpressure_high_watermark_pct, + new_hot.me_reader_route_data_wait_ms, new_hot.me_health_interval_ms_unhealthy, new_hot.me_health_interval_ms_healthy, new_hot.me_admission_poll_ms, @@ -839,6 +862,24 @@ fn log_changes( ); } + if old_hot.me_d2c_flush_batch_max_frames != new_hot.me_d2c_flush_batch_max_frames + || old_hot.me_d2c_flush_batch_max_bytes != new_hot.me_d2c_flush_batch_max_bytes + || old_hot.me_d2c_flush_batch_max_delay_us != new_hot.me_d2c_flush_batch_max_delay_us + || old_hot.me_d2c_ack_flush_immediate != new_hot.me_d2c_ack_flush_immediate + || old_hot.direct_relay_copy_buf_c2s_bytes != new_hot.direct_relay_copy_buf_c2s_bytes + || old_hot.direct_relay_copy_buf_s2c_bytes != new_hot.direct_relay_copy_buf_s2c_bytes + { + info!( + "config reload: relay_tuning: me_d2c_frames={} me_d2c_bytes={} me_d2c_delay_us={} me_ack_flush_immediate={} direct_buf_c2s={} direct_buf_s2c={}", + new_hot.me_d2c_flush_batch_max_frames, + new_hot.me_d2c_flush_batch_max_bytes, + new_hot.me_d2c_flush_batch_max_delay_us, + new_hot.me_d2c_ack_flush_immediate, + new_hot.direct_relay_copy_buf_c2s_bytes, + new_hot.direct_relay_copy_buf_s2c_bytes, + ); + } + if old_hot.users != new_hot.users { let mut added: Vec<&String> = new_hot.users.keys() .filter(|u| !old_hot.users.contains_key(*u)) diff --git a/src/config/load.rs b/src/config/load.rs index 623ec4d..3f1cd5c 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -303,6 +303,42 @@ impl ProxyConfig { )); } + if config.general.me_reader_route_data_wait_ms > 20 { + return Err(ProxyError::Config( + "general.me_reader_route_data_wait_ms must be within [0, 20]".to_string(), + )); + } + + if !(1..=512).contains(&config.general.me_d2c_flush_batch_max_frames) { + return Err(ProxyError::Config( + "general.me_d2c_flush_batch_max_frames must be within [1, 512]".to_string(), + )); + } + + if !(4096..=2 * 1024 * 1024).contains(&config.general.me_d2c_flush_batch_max_bytes) { + 
return Err(ProxyError::Config( + "general.me_d2c_flush_batch_max_bytes must be within [4096, 2097152]".to_string(), + )); + } + + if config.general.me_d2c_flush_batch_max_delay_us > 5000 { + return Err(ProxyError::Config( + "general.me_d2c_flush_batch_max_delay_us must be within [0, 5000]".to_string(), + )); + } + + if !(4096..=1024 * 1024).contains(&config.general.direct_relay_copy_buf_c2s_bytes) { + return Err(ProxyError::Config( + "general.direct_relay_copy_buf_c2s_bytes must be within [4096, 1048576]".to_string(), + )); + } + + if !(8192..=2 * 1024 * 1024).contains(&config.general.direct_relay_copy_buf_s2c_bytes) { + return Err(ProxyError::Config( + "general.direct_relay_copy_buf_s2c_bytes must be within [8192, 2097152]".to_string(), + )); + } + if config.general.me_health_interval_ms_unhealthy == 0 { return Err(ProxyError::Config( "general.me_health_interval_ms_unhealthy must be > 0".to_string(), diff --git a/src/config/types.rs b/src/config/types.rs index 588c82f..eeb8cfa 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -458,6 +458,36 @@ pub struct GeneralConfig { #[serde(default = "default_me_c2me_channel_capacity")] pub me_c2me_channel_capacity: usize, + /// Bounded wait in milliseconds for routing ME DATA to per-connection queue. + /// `0` keeps legacy no-wait behavior. + #[serde(default = "default_me_reader_route_data_wait_ms")] + pub me_reader_route_data_wait_ms: u64, + + /// Maximum number of ME->Client responses coalesced before flush. + #[serde(default = "default_me_d2c_flush_batch_max_frames")] + pub me_d2c_flush_batch_max_frames: usize, + + /// Maximum total payload bytes coalesced before flush. + #[serde(default = "default_me_d2c_flush_batch_max_bytes")] + pub me_d2c_flush_batch_max_bytes: usize, + + /// Maximum wait in microseconds to coalesce additional ME->Client responses. + /// `0` disables timed coalescing. 
+ #[serde(default = "default_me_d2c_flush_batch_max_delay_us")] + pub me_d2c_flush_batch_max_delay_us: u64, + + /// Flush client writer immediately after quick-ack write. + #[serde(default = "default_me_d2c_ack_flush_immediate")] + pub me_d2c_ack_flush_immediate: bool, + + /// Copy buffer size for client->DC direction in direct relay. + #[serde(default = "default_direct_relay_copy_buf_c2s_bytes")] + pub direct_relay_copy_buf_c2s_bytes: usize, + + /// Copy buffer size for DC->client direction in direct relay. + #[serde(default = "default_direct_relay_copy_buf_s2c_bytes")] + pub direct_relay_copy_buf_s2c_bytes: usize, + /// Max pending ciphertext buffer per client writer (bytes). /// Controls FakeTLS backpressure vs throughput. #[serde(default = "default_crypto_pending_buffer")] @@ -861,6 +891,13 @@ impl Default for GeneralConfig { me_writer_cmd_channel_capacity: default_me_writer_cmd_channel_capacity(), me_route_channel_capacity: default_me_route_channel_capacity(), me_c2me_channel_capacity: default_me_c2me_channel_capacity(), + me_reader_route_data_wait_ms: default_me_reader_route_data_wait_ms(), + me_d2c_flush_batch_max_frames: default_me_d2c_flush_batch_max_frames(), + me_d2c_flush_batch_max_bytes: default_me_d2c_flush_batch_max_bytes(), + me_d2c_flush_batch_max_delay_us: default_me_d2c_flush_batch_max_delay_us(), + me_d2c_ack_flush_immediate: default_me_d2c_ack_flush_immediate(), + direct_relay_copy_buf_c2s_bytes: default_direct_relay_copy_buf_c2s_bytes(), + direct_relay_copy_buf_s2c_bytes: default_direct_relay_copy_buf_s2c_bytes(), me_warmup_stagger_enabled: default_true(), me_warmup_step_delay_ms: default_warmup_step_delay_ms(), me_warmup_step_jitter_ms: default_warmup_step_jitter_ms(), diff --git a/src/main.rs b/src/main.rs index d28fabe..3b6a543 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1055,6 +1055,7 @@ async fn main() -> std::result::Result<(), Box> { config.general.me_route_backpressure_base_timeout_ms, 
config.general.me_route_backpressure_high_timeout_ms, config.general.me_route_backpressure_high_watermark_pct, + config.general.me_reader_route_data_wait_ms, config.general.me_health_interval_ms_unhealthy, config.general.me_health_interval_ms_healthy, config.general.me_warn_rate_limit_ms, @@ -1559,6 +1560,7 @@ async fn main() -> std::result::Result<(), Box> { cfg.general.me_route_backpressure_base_timeout_ms, cfg.general.me_route_backpressure_high_timeout_ms, cfg.general.me_route_backpressure_high_watermark_pct, + cfg.general.me_reader_route_data_wait_ms, ); } } diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index e39e446..d4b0f2e 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -64,6 +64,8 @@ where client_writer, tg_reader, tg_writer, + config.general.direct_relay_copy_buf_c2s_bytes, + config.general.direct_relay_copy_buf_s2c_bytes, user, Arc::clone(&stats), buffer_pool, diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index cae8273..0006914 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -30,6 +30,8 @@ const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; +const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; +const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; static DESYNC_DEDUP: OnceLock>> = OnceLock::new(); struct RelayForensicsState { @@ -44,6 +46,31 @@ struct RelayForensicsState { desync_all_full: bool, } +#[derive(Clone, Copy)] +struct MeD2cFlushPolicy { + max_frames: usize, + max_bytes: usize, + max_delay: Duration, + ack_flush_immediate: bool, +} + +impl MeD2cFlushPolicy { + fn from_config(config: &ProxyConfig) -> Self { + Self { + max_frames: config + .general + .me_d2c_flush_batch_max_frames + .max(ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN), + max_bytes: config + .general + .me_d2c_flush_batch_max_bytes + 
.max(ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN), + max_delay: Duration::from_micros(config.general.me_d2c_flush_batch_max_delay_us), + ack_flush_immediate: config.general.me_d2c_ack_flush_immediate, + } + } +} + fn hash_value<T: Hash>(value: &T) -> u64 { let mut hasher = DefaultHasher::new(); value.hash(&mut hasher); @@ -313,71 +340,152 @@ where let rng_clone = rng.clone(); let user_clone = user.clone(); let bytes_me2c_clone = bytes_me2c.clone(); + let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); let me_writer = tokio::spawn(async move { let mut writer = crypto_writer; let mut frame_buf = Vec::with_capacity(16 * 1024); loop { tokio::select! { msg = me_rx_task.recv() => { - match msg { - Some(MeResponse::Data { flags, data }) => { - trace!(conn_id, bytes = data.len(), flags, "ME->C data"); - bytes_me2c_clone.fetch_add(data.len() as u64, Ordering::Relaxed); - stats_clone.add_user_octets_to(&user_clone, data.len() as u64); - write_client_payload( - &mut writer, - proto_tag, - flags, - &data, - rng_clone.as_ref(), - &mut frame_buf, - ) - .await?; + let Some(first) = msg else { + debug!(conn_id, "ME channel closed"); + return Err(ProxyError::Proxy("ME connection lost".into())); + }; - // Drain all immediately queued ME responses and flush once.
- while let Ok(next) = me_rx_task.try_recv() { - match next { - MeResponse::Data { flags, data } => { - trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)"); - bytes_me2c_clone.fetch_add(data.len() as u64, Ordering::Relaxed); - stats_clone.add_user_octets_to(&user_clone, data.len() as u64); - write_client_payload( - &mut writer, - proto_tag, - flags, - &data, - rng_clone.as_ref(), - &mut frame_buf, - ).await?; + let mut batch_frames = 0usize; + let mut batch_bytes = 0usize; + let mut flush_immediately = false; + + match process_me_writer_response( + first, + &mut writer, + proto_tag, + rng_clone.as_ref(), + &mut frame_buf, + stats_clone.as_ref(), + &user_clone, + bytes_me2c_clone.as_ref(), + conn_id, + d2c_flush_policy.ack_flush_immediate, + false, + ).await? { + MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + batch_frames = batch_frames.saturating_add(frames); + batch_bytes = batch_bytes.saturating_add(bytes); + flush_immediately = immediate; + } + MeWriterResponseOutcome::Close => { + let _ = writer.flush().await; + return Ok(()); + } + } + + while !flush_immediately + && batch_frames < d2c_flush_policy.max_frames + && batch_bytes < d2c_flush_policy.max_bytes + { + let Ok(next) = me_rx_task.try_recv() else { + break; + }; + + match process_me_writer_response( + next, + &mut writer, + proto_tag, + rng_clone.as_ref(), + &mut frame_buf, + stats_clone.as_ref(), + &user_clone, + bytes_me2c_clone.as_ref(), + conn_id, + d2c_flush_policy.ack_flush_immediate, + true, + ).await? 
{ + MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + batch_frames = batch_frames.saturating_add(frames); + batch_bytes = batch_bytes.saturating_add(bytes); + flush_immediately |= immediate; + } + MeWriterResponseOutcome::Close => { + let _ = writer.flush().await; + return Ok(()); + } + } + } + + if !flush_immediately + && !d2c_flush_policy.max_delay.is_zero() + && batch_frames < d2c_flush_policy.max_frames + && batch_bytes < d2c_flush_policy.max_bytes + { + match tokio::time::timeout(d2c_flush_policy.max_delay, me_rx_task.recv()).await { + Ok(Some(next)) => { + match process_me_writer_response( + next, + &mut writer, + proto_tag, + rng_clone.as_ref(), + &mut frame_buf, + stats_clone.as_ref(), + &user_clone, + bytes_me2c_clone.as_ref(), + conn_id, + d2c_flush_policy.ack_flush_immediate, + true, + ).await? { + MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + batch_frames = batch_frames.saturating_add(frames); + batch_bytes = batch_bytes.saturating_add(bytes); + flush_immediately |= immediate; } - MeResponse::Ack(confirm) => { - trace!(conn_id, confirm, "ME->C quickack (batched)"); - write_client_ack(&mut writer, proto_tag, confirm).await?; - } - MeResponse::Close => { - debug!(conn_id, "ME sent close (batched)"); + MeWriterResponseOutcome::Close => { let _ = writer.flush().await; return Ok(()); } } - } - writer.flush().await.map_err(ProxyError::Io)?; - } - Some(MeResponse::Ack(confirm)) => { - trace!(conn_id, confirm, "ME->C quickack"); - write_client_ack(&mut writer, proto_tag, confirm).await?; - } - Some(MeResponse::Close) => { - debug!(conn_id, "ME sent close"); - let _ = writer.flush().await; - return Ok(()); - } - None => { - debug!(conn_id, "ME channel closed"); - return Err(ProxyError::Proxy("ME connection lost".into())); + while !flush_immediately + && batch_frames < d2c_flush_policy.max_frames + && batch_bytes < d2c_flush_policy.max_bytes + { + let Ok(extra) = 
me_rx_task.try_recv() else { + break; + }; + + match process_me_writer_response( + extra, + &mut writer, + proto_tag, + rng_clone.as_ref(), + &mut frame_buf, + stats_clone.as_ref(), + &user_clone, + bytes_me2c_clone.as_ref(), + conn_id, + d2c_flush_policy.ack_flush_immediate, + true, + ).await? { + MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + batch_frames = batch_frames.saturating_add(frames); + batch_bytes = batch_bytes.saturating_add(bytes); + flush_immediately |= immediate; + } + MeWriterResponseOutcome::Close => { + let _ = writer.flush().await; + return Ok(()); + } + } + } + } + Ok(None) => { + debug!(conn_id, "ME channel closed"); + return Err(ProxyError::Proxy("ME connection lost".into())); + } + Err(_) => {} } } + + writer.flush().await.map_err(ProxyError::Io)?; } _ = &mut stop_rx => { debug!(conn_id, "ME writer stop signal"); @@ -587,6 +695,81 @@ where } } +enum MeWriterResponseOutcome { + Continue { + frames: usize, + bytes: usize, + flush_immediately: bool, + }, + Close, +} + +async fn process_me_writer_response( + response: MeResponse, + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + rng: &SecureRandom, + frame_buf: &mut Vec, + stats: &Stats, + user: &str, + bytes_me2c: &AtomicU64, + conn_id: u64, + ack_flush_immediate: bool, + batched: bool, +) -> Result +where + W: AsyncWrite + Unpin + Send + 'static, +{ + match response { + MeResponse::Data { flags, data } => { + if batched { + trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)"); + } else { + trace!(conn_id, bytes = data.len(), flags, "ME->C data"); + } + bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); + stats.add_user_octets_to(user, data.len() as u64); + write_client_payload( + client_writer, + proto_tag, + flags, + &data, + rng, + frame_buf, + ) + .await?; + + Ok(MeWriterResponseOutcome::Continue { + frames: 1, + bytes: data.len(), + flush_immediately: false, + }) + } + MeResponse::Ack(confirm) => { + if batched { + 
trace!(conn_id, confirm, "ME->C quickack (batched)"); + } else { + trace!(conn_id, confirm, "ME->C quickack"); + } + write_client_ack(client_writer, proto_tag, confirm).await?; + + Ok(MeWriterResponseOutcome::Continue { + frames: 1, + bytes: 4, + flush_immediately: ack_flush_immediate, + }) + } + MeResponse::Close => { + if batched { + debug!(conn_id, "ME sent close (batched)"); + } else { + debug!(conn_id, "ME sent close"); + } + Ok(MeWriterResponseOutcome::Close) + } + } +} + async fn write_client_payload( client_writer: &mut CryptoWriter, proto_tag: ProtoTag, @@ -696,9 +879,7 @@ where client_writer .write_all(&bytes) .await - .map_err(ProxyError::Io)?; - // ACK should remain low-latency. - client_writer.flush().await.map_err(ProxyError::Io) + .map_err(ProxyError::Io) } #[cfg(test)] diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index a155945..06ce0d8 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -57,7 +57,9 @@ use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional}; +use tokio::io::{ + AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes, +}; use tokio::time::Instant; use tracing::{debug, trace, warn}; use crate::error::Result; @@ -296,9 +298,8 @@ impl AsyncWrite for StatsIo { /// /// ## API compatibility /// -/// Signature is identical to the previous implementation. The `_buffer_pool` -/// parameter is retained for call-site compatibility — `copy_bidirectional` -/// manages its own internal buffers (8 KB per direction). +/// The `_buffer_pool` parameter is retained for call-site compatibility. +/// Effective relay copy buffers are configured by `c2s_buf_size` / `s2c_buf_size`. 
/// /// ## Guarantees preserved /// @@ -312,6 +313,8 @@ pub async fn relay_bidirectional( client_writer: CW, server_reader: SR, server_writer: SW, + c2s_buf_size: usize, + s2c_buf_size: usize, user: &str, stats: Arc, _buffer_pool: Arc, @@ -402,7 +405,12 @@ where // When the watchdog fires, select! drops the copy future, // releasing the &mut borrows on client and server. let copy_result = tokio::select! { - result = copy_bidirectional(&mut client, &mut server) => Some(result), + result = copy_bidirectional_with_sizes( + &mut client, + &mut server, + c2s_buf_size.max(1), + s2c_buf_size.max(1), + ) => Some(result), _ = watchdog => None, // Activity timeout — cancel relay }; @@ -463,4 +471,4 @@ where Ok(()) } } -} \ No newline at end of file +} diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 07ad67b..8d5b110 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -183,6 +183,7 @@ pub struct MePool { pub(super) me_writer_pick_mode: AtomicU8, pub(super) me_writer_pick_sample_size: AtomicU8, pub(super) me_socks_kdf_policy: AtomicU8, + pub(super) me_reader_route_data_wait_ms: Arc, pub(super) me_route_no_writer_mode: AtomicU8, pub(super) me_route_no_writer_wait: Duration, pub(super) me_route_inline_recovery_attempts: u32, @@ -287,6 +288,7 @@ impl MePool { me_route_backpressure_base_timeout_ms: u64, me_route_backpressure_high_timeout_ms: u64, me_route_backpressure_high_watermark_pct: u8, + me_reader_route_data_wait_ms: u64, me_health_interval_ms_unhealthy: u64, me_health_interval_ms_healthy: u64, me_warn_rate_limit_ms: u64, @@ -460,6 +462,7 @@ impl MePool { me_writer_pick_mode: AtomicU8::new(me_writer_pick_mode.as_u8()), me_writer_pick_sample_size: AtomicU8::new(me_writer_pick_sample_size.clamp(2, 4)), me_socks_kdf_policy: AtomicU8::new(me_socks_kdf_policy.as_u8()), + me_reader_route_data_wait_ms: Arc::new(AtomicU64::new(me_reader_route_data_wait_ms)), me_route_no_writer_mode: 
AtomicU8::new(me_route_no_writer_mode.as_u8()), me_route_no_writer_wait: Duration::from_millis(me_route_no_writer_wait_ms), me_route_inline_recovery_attempts, @@ -650,9 +653,12 @@ impl MePool { route_backpressure_base_timeout_ms: u64, route_backpressure_high_timeout_ms: u64, route_backpressure_high_watermark_pct: u8, + reader_route_data_wait_ms: u64, ) { self.me_socks_kdf_policy .store(socks_kdf_policy.as_u8(), Ordering::Relaxed); + self.me_reader_route_data_wait_ms + .store(reader_route_data_wait_ms, Ordering::Relaxed); self.registry.update_route_backpressure_policy( route_backpressure_base_timeout_ms, route_backpressure_high_timeout_ms, diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 43abf0c..7e79f10 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -208,6 +208,7 @@ impl MePool { let keepalive_jitter_signal = self.me_keepalive_jitter; let cancel_reader_token = cancel.clone(); let cancel_ping_token = cancel_ping.clone(); + let reader_route_data_wait_ms = self.me_reader_route_data_wait_ms.clone(); tokio::spawn(async move { let res = reader_loop( @@ -225,6 +226,7 @@ writer_id, degraded.clone(), rtt_ema_ms_x10.clone(), + reader_route_data_wait_ms, cancel_reader_token.clone(), ) .await; diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 32de774..785bc2c 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::io::ErrorKind; use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering}; use std::time::Instant; use bytes::{Bytes, BytesMut}; @@ -35,6 +35,7 @@ pub(crate) async fn reader_loop( _writer_id: u64, degraded: Arc<AtomicBool>, writer_rtt_ema_ms_x10: Arc<AtomicU32>, + reader_route_data_wait_ms: Arc<AtomicU64>, cancel: CancellationToken, ) -> Result<()> {
let mut raw = enc_leftover; @@ -57,17 +58,14 @@ pub(crate) async fn reader_loop( let blocks = raw.len() / 16 * 16; if blocks > 0 { + let mut chunk = raw.split_to(blocks); let mut new_iv = [0u8; 16]; - new_iv.copy_from_slice(&raw[blocks - 16..blocks]); - - let mut chunk = vec![0u8; blocks]; - chunk.copy_from_slice(&raw[..blocks]); + new_iv.copy_from_slice(&chunk[blocks - 16..blocks]); AesCbc::new(dk, div) - .decrypt_in_place(&mut chunk) + .decrypt_in_place(&mut chunk[..]) .map_err(|e| ProxyError::Crypto(format!("{e}")))?; div = new_iv; dec.extend_from_slice(&chunk); - let _ = raw.split_to(blocks); } while dec.len() >= 12 { @@ -85,7 +83,7 @@ pub(crate) async fn reader_loop( break; } - let frame = dec.split_to(fl); + let frame = dec.split_to(fl).freeze(); let pe = fl - 4; let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); let actual_crc = rpc_crc(crc_mode, &frame[..pe]); @@ -111,21 +109,27 @@ pub(crate) async fn reader_loop( } expected_seq = expected_seq.wrapping_add(1); - let payload = &frame[8..pe]; + let payload = frame.slice(8..pe); if payload.len() < 4 { continue; } let pt = u32::from_le_bytes(payload[0..4].try_into().unwrap()); - let body = &payload[4..]; + let body = payload.slice(4..); if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 { let flags = u32::from_le_bytes(body[0..4].try_into().unwrap()); let cid = u64::from_le_bytes(body[4..12].try_into().unwrap()); - let data = Bytes::copy_from_slice(&body[12..]); + let data = body.slice(12..); trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); - let routed = reg.route_nowait(cid, MeResponse::Data { flags, data }).await; + let data_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed); + let routed = if data_wait_ms == 0 { + reg.route_nowait(cid, MeResponse::Data { flags, data }).await + } else { + reg.route_with_timeout(cid, MeResponse::Data { flags, data }, data_wait_ms) + .await + }; if !matches!(routed, RouteResult::Routed) { match routed { RouteResult::NoConn => 
stats.increment_me_route_drop_no_conn(), diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index f2682d5..ee04969 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -231,6 +231,57 @@ impl ConnRegistry { } } + pub async fn route_with_timeout( + &self, + id: u64, + resp: MeResponse, + timeout_ms: u64, + ) -> RouteResult { + if timeout_ms == 0 { + return self.route_nowait(id, resp).await; + } + + let tx = { + let inner = self.inner.read().await; + inner.map.get(&id).cloned() + }; + + let Some(tx) = tx else { + return RouteResult::NoConn; + }; + + match tx.try_send(resp) { + Ok(()) => RouteResult::Routed, + Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, + Err(TrySendError::Full(resp)) => { + let high_watermark_pct = self + .route_backpressure_high_watermark_pct + .load(Ordering::Relaxed) + .clamp(1, 100); + let used = self.route_channel_capacity.saturating_sub(tx.capacity()); + let used_pct = if self.route_channel_capacity == 0 { + 100 + } else { + (used.saturating_mul(100) / self.route_channel_capacity) as u8 + }; + let high_profile = used_pct >= high_watermark_pct; + let timeout_dur = Duration::from_millis(timeout_ms.max(1)); + + match tokio::time::timeout(timeout_dur, tx.send(resp)).await { + Ok(Ok(())) => RouteResult::Routed, + Ok(Err(_)) => RouteResult::ChannelClosed, + Err(_) => { + if high_profile { + RouteResult::QueueFullHigh + } else { + RouteResult::QueueFullBase + } + } + } + } + } + } + pub async fn bind_writer( &self, conn_id: u64,