use std::net::SocketAddr; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::{debug, info, trace}; use crate::config::ProxyConfig; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; use crate::proxy::handshake::HandshakeSuccess; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; pub(crate) async fn handle_via_middle_proxy( mut crypto_reader: CryptoReader, mut crypto_writer: CryptoWriter, success: HandshakeSuccess, me_pool: Arc, stats: Arc, _config: Arc, _buffer_pool: Arc, local_addr: SocketAddr, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, { let user = success.user.clone(); let peer = success.peer; let proto_tag = success.proto_tag; info!( user = %user, peer = %peer, dc = success.dc_idx, proto = ?proto_tag, mode = "middle_proxy", "Routing via Middle-End" ); let (conn_id, mut me_rx) = me_pool.registry().register().await; stats.increment_user_connects(&user); stats.increment_user_curr_connects(&user); let proto_flags = proto_flags_for_tag(proto_tag, me_pool.has_proxy_tag()); debug!( user = %user, conn_id, proto_flags = format_args!("0x{:08x}", proto_flags), "ME relay started" ); let translated_local_addr = me_pool.translate_our_addr(local_addr); let result: Result<()> = loop { tokio::select! { client_frame = read_client_payload(&mut crypto_reader, proto_tag) => { match client_frame { Ok(Some(payload)) => { trace!(conn_id, bytes = payload.len(), "C->ME frame"); stats.add_user_octets_from(&user, payload.len() as u64); me_pool.send_proxy_req( conn_id, success.dc_idx, peer, translated_local_addr, &payload, proto_flags, ).await?; } Ok(None) => { debug!(conn_id, "Client EOF"); let _ = me_pool.send_close(conn_id).await; break Ok(()); } Err(e) => break Err(e), } } me_msg = me_rx.recv() => { match me_msg { Some(MeResponse::Data { flags, data }) => { trace!(conn_id, bytes = data.len(), flags, "ME->C data"); stats.add_user_octets_to(&user, data.len() as u64); write_client_payload(&mut crypto_writer, proto_tag, flags, &data).await?; } Some(MeResponse::Ack(confirm)) => { trace!(conn_id, confirm, "ME->C quickack"); write_client_ack(&mut crypto_writer, proto_tag, confirm).await?; } Some(MeResponse::Close) => { debug!(conn_id, "ME sent close"); break Ok(()); } None => { debug!(conn_id, "ME channel closed"); break Err(ProxyError::Proxy("ME connection lost".into())); } } } } }; debug!(user = %user, conn_id, "ME relay cleanup"); me_pool.registry().unregister(conn_id).await; stats.decrement_user_curr_connects(&user); result } async fn read_client_payload( client_reader: &mut CryptoReader, proto_tag: ProtoTag, ) -> Result>> where R: AsyncRead + Unpin + Send + 'static, { let len = match proto_tag { ProtoTag::Abridged => { let mut first = [0u8; 1]; match client_reader.read_exact(&mut first).await { Ok(_) => {} Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), Err(e) => return Err(ProxyError::Io(e)), } let len_words = if (first[0] & 0x7f) == 0x7f { let mut ext = [0u8; 3]; client_reader .read_exact(&mut ext) .await .map_err(ProxyError::Io)?; u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize } else { (first[0] & 0x7f) as usize }; len_words .checked_mul(4) .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))? } ProtoTag::Intermediate | ProtoTag::Secure => { let mut len_buf = [0u8; 4]; match client_reader.read_exact(&mut len_buf).await { Ok(_) => {} Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), Err(e) => return Err(ProxyError::Io(e)), } (u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize } }; if len > 16 * 1024 * 1024 { return Err(ProxyError::Proxy(format!("Frame too large: {len}"))); } let mut payload = vec![0u8; len]; client_reader .read_exact(&mut payload) .await .map_err(ProxyError::Io)?; Ok(Some(payload)) } async fn write_client_payload( client_writer: &mut CryptoWriter, proto_tag: ProtoTag, flags: u32, data: &[u8], ) -> Result<()> where W: AsyncWrite + Unpin + Send + 'static, { let quickack = (flags & RPC_FLAG_QUICKACK) != 0; match proto_tag { ProtoTag::Abridged => { if data.len() % 4 != 0 { return Err(ProxyError::Proxy(format!( "Abridged payload must be 4-byte aligned, got {}", data.len() ))); } let len_words = data.len() / 4; if len_words < 0x7f { let mut first = len_words as u8; if quickack { first |= 0x80; } client_writer .write_all(&[first]) .await .map_err(ProxyError::Io)?; } else if len_words < (1 << 24) { let mut first = 0x7fu8; if quickack { first |= 0x80; } let lw = (len_words as u32).to_le_bytes(); client_writer .write_all(&[first, lw[0], lw[1], lw[2]]) .await .map_err(ProxyError::Io)?; } else { return Err(ProxyError::Proxy(format!( "Abridged frame too large: {}", data.len() ))); } client_writer .write_all(data) .await .map_err(ProxyError::Io)?; } ProtoTag::Intermediate | ProtoTag::Secure => { let mut len = data.len() as u32; if quickack { len |= 0x8000_0000; } client_writer .write_all(&len.to_le_bytes()) .await .map_err(ProxyError::Io)?; client_writer .write_all(data) .await .map_err(ProxyError::Io)?; } } client_writer.flush().await.map_err(ProxyError::Io) } async fn write_client_ack( client_writer: &mut CryptoWriter, proto_tag: ProtoTag, confirm: u32, ) -> Result<()> where W: AsyncWrite + Unpin + Send + 'static, { let bytes = if proto_tag == ProtoTag::Abridged { confirm.to_be_bytes() } else { confirm.to_le_bytes() }; client_writer .write_all(&bytes) .await .map_err(ProxyError::Io)?; client_writer.flush().await.map_err(ProxyError::Io) }