diff --git a/src/crypto/random.rs b/src/crypto/random.rs index 99aa5f3..f3432e0 100644 --- a/src/crypto/random.rs +++ b/src/crypto/random.rs @@ -49,19 +49,32 @@ impl SecureRandom { } } - /// Generate random bytes - pub fn bytes(&self, len: usize) -> Vec { + /// Fill a caller-provided buffer with random bytes. + pub fn fill(&self, out: &mut [u8]) { let mut inner = self.inner.lock(); const CHUNK_SIZE: usize = 512; - - while inner.buffer.len() < len { - let mut chunk = vec![0u8; CHUNK_SIZE]; - inner.rng.fill_bytes(&mut chunk); - inner.cipher.apply(&mut chunk); - inner.buffer.extend_from_slice(&chunk); + + let mut written = 0usize; + while written < out.len() { + if inner.buffer.is_empty() { + let mut chunk = vec![0u8; CHUNK_SIZE]; + inner.rng.fill_bytes(&mut chunk); + inner.cipher.apply(&mut chunk); + inner.buffer.extend_from_slice(&chunk); + } + + let take = (out.len() - written).min(inner.buffer.len()); + out[written..written + take].copy_from_slice(&inner.buffer[..take]); + inner.buffer.drain(..take); + written += take; } - - inner.buffer.drain(..len).collect() + } + + /// Generate random bytes + pub fn bytes(&self, len: usize) -> Vec { + let mut out = vec![0u8; len]; + self.fill(&mut out); + out } /// Generate random number in range [0, max) diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 7b97049..3b98112 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -95,6 +95,7 @@ where let user_clone = user.clone(); let me_writer = tokio::spawn(async move { let mut writer = crypto_writer; + let mut frame_buf = Vec::with_capacity(16 * 1024); loop { tokio::select! { msg = me_rx_task.recv() => { @@ -102,7 +103,15 @@ where Some(MeResponse::Data { flags, data }) => { trace!(conn_id, bytes = data.len(), flags, "ME->C data"); stats_clone.add_user_octets_to(&user_clone, data.len() as u64); - write_client_payload(&mut writer, proto_tag, flags, &data, rng_clone.as_ref()).await?; + write_client_payload( + &mut writer, + proto_tag, + flags, + &data, + rng_clone.as_ref(), + &mut frame_buf, + ) + .await?; // Drain all immediately queued ME responses and flush once. while let Ok(next) = me_rx_task.try_recv() { @@ -116,6 +125,7 @@ where flags, &data, rng_clone.as_ref(), + &mut frame_buf, ).await?; } MeResponse::Ack(confirm) => { @@ -363,6 +373,7 @@ async fn write_client_payload( flags: u32, data: &[u8], rng: &SecureRandom, + frame_buf: &mut Vec, ) -> Result<()> where W: AsyncWrite + Unpin + Send + 'static, @@ -384,7 +395,8 @@ where if quickack { first |= 0x80; } - let mut frame_buf = Vec::with_capacity(1 + data.len()); + frame_buf.clear(); + frame_buf.reserve(1 + data.len()); frame_buf.push(first); frame_buf.extend_from_slice(data); client_writer @@ -397,7 +409,8 @@ where first |= 0x80; } let lw = (len_words as u32).to_le_bytes(); - let mut frame_buf = Vec::with_capacity(4 + data.len()); + frame_buf.clear(); + frame_buf.reserve(4 + data.len()); frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]); frame_buf.extend_from_slice(data); client_writer @@ -428,11 +441,14 @@ where len_val |= 0x8000_0000; } let total = 4 + data.len() + padding_len; - let mut frame_buf = Vec::with_capacity(total); + frame_buf.clear(); + frame_buf.reserve(total); frame_buf.extend_from_slice(&len_val.to_le_bytes()); frame_buf.extend_from_slice(data); if padding_len > 0 { - frame_buf.extend_from_slice(&rng.bytes(padding_len)); + let start = frame_buf.len(); + frame_buf.resize(start + padding_len, 0); + rng.fill(&mut frame_buf[start..]); } client_writer .write_all(&frame_buf) diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index dd9589e..6d83761 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -223,7 +223,7 @@ pub(crate) struct RpcWriter { impl RpcWriter { pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> { let frame = build_rpc_frame(self.seq_no, payload, self.crc_mode); - self.seq_no += 1; + self.seq_no = self.seq_no.wrapping_add(1); let pad = (16 - (frame.len() % 16)) % 16; let mut buf = frame;