//! TCP Socket Configuration #[cfg(target_os = "linux")] use std::collections::HashSet; #[cfg(target_os = "linux")] use std::fs; use std::io::Result; use std::net::{SocketAddr, IpAddr}; use std::time::Duration; use tokio::net::TcpStream; use socket2::{Socket, TcpKeepalive, Domain, Type, Protocol}; use tracing::debug; const DEFAULT_SOCKET_BUFFER_BYTES: usize = 256 * 1024; /// Configure TCP socket with recommended settings for proxy use #[allow(dead_code)] pub fn configure_tcp_socket( stream: &TcpStream, keepalive: bool, keepalive_interval: Duration, ) -> Result<()> { let socket = socket2::SockRef::from(stream); // Disable Nagle's algorithm for lower latency socket.set_nodelay(true)?; // Set keepalive if enabled if keepalive { let keepalive = TcpKeepalive::new() .with_time(keepalive_interval); // Platform-specific keepalive settings #[cfg(any(target_os = "linux", target_os = "macos", target_os = "ios"))] let keepalive = keepalive.with_interval(keepalive_interval); socket.set_tcp_keepalive(&keepalive)?; } // Use explicit baseline buffers to reduce slow-start stalls on high RTT links. socket.set_recv_buffer_size(DEFAULT_SOCKET_BUFFER_BYTES)?; socket.set_send_buffer_size(DEFAULT_SOCKET_BUFFER_BYTES)?; Ok(()) } /// Configure socket for accepting client connections pub fn configure_client_socket( stream: &TcpStream, keepalive_secs: u64, #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] ack_timeout_secs: u64, ) -> Result<()> { let socket = socket2::SockRef::from(stream); // Disable Nagle's algorithm socket.set_nodelay(true)?; // Set keepalive let keepalive = TcpKeepalive::new() .with_time(Duration::from_secs(keepalive_secs)); #[cfg(any(target_os = "linux", target_os = "macos", target_os = "ios"))] let keepalive = keepalive.with_interval(Duration::from_secs(keepalive_secs)); socket.set_tcp_keepalive(&keepalive)?; // Keep explicit baseline buffers for predictable throughput across busy hosts. socket.set_recv_buffer_size(DEFAULT_SOCKET_BUFFER_BYTES)?; socket.set_send_buffer_size(DEFAULT_SOCKET_BUFFER_BYTES)?; // Set TCP user timeout (Linux only) // NOTE: iOS does not support TCP_USER_TIMEOUT - application-level timeout // is implemented in relay_bidirectional instead #[cfg(target_os = "linux")] { use std::io::{Error, ErrorKind}; use std::os::unix::io::AsRawFd; let fd = stream.as_raw_fd(); let timeout_ms_u64 = ack_timeout_secs .checked_mul(1000) .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs is too large"))?; let timeout_ms = i32::try_from(timeout_ms_u64) .map_err(|_| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs exceeds TCP_USER_TIMEOUT range"))?; let rc = unsafe { libc::setsockopt( fd, libc::IPPROTO_TCP, libc::TCP_USER_TIMEOUT, &timeout_ms as *const libc::c_int as *const libc::c_void, std::mem::size_of::() as libc::socklen_t, ) }; if rc != 0 { return Err(Error::last_os_error()); } } Ok(()) } /// Set socket to send RST on close (for masking) #[allow(dead_code)] pub fn set_linger_zero(stream: &TcpStream) -> Result<()> { let socket = socket2::SockRef::from(stream); socket.set_linger(Some(Duration::ZERO))?; Ok(()) } /// Create a new TCP socket for outgoing connections #[allow(dead_code)] pub fn create_outgoing_socket(addr: SocketAddr) -> Result { create_outgoing_socket_bound(addr, None) } /// Create a new TCP socket for outgoing connections, optionally bound to a specific interface pub fn create_outgoing_socket_bound(addr: SocketAddr, bind_addr: Option) -> Result { let domain = if addr.is_ipv4() { Domain::IPV4 } else { Domain::IPV6 }; let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; // Set non-blocking socket.set_nonblocking(true)?; // Disable Nagle socket.set_nodelay(true)?; socket.set_recv_buffer_size(DEFAULT_SOCKET_BUFFER_BYTES)?; socket.set_send_buffer_size(DEFAULT_SOCKET_BUFFER_BYTES)?; if let Some(bind_ip) = bind_addr { let bind_sock_addr = SocketAddr::new(bind_ip, 0); socket.bind(&bind_sock_addr.into())?; debug!("Bound outgoing socket to {}", bind_ip); } Ok(socket) } /// Get local address of a socket #[allow(dead_code)] pub fn get_local_addr(stream: &TcpStream) -> Option { stream.local_addr().ok() } /// Resolve primary IP address of a network interface by name. /// Returns the first address matching the requested family (IPv4/IPv6). #[cfg(unix)] pub fn resolve_interface_ip(name: &str, want_ipv6: bool) -> Option { use nix::ifaddrs::getifaddrs; if let Ok(addrs) = getifaddrs() { for iface in addrs { if iface.interface_name == name && let Some(address) = iface.address { if let Some(v4) = address.as_sockaddr_in() { if !want_ipv6 { return Some(IpAddr::V4(v4.ip())); } } else if let Some(v6) = address.as_sockaddr_in6() && want_ipv6 { return Some(IpAddr::V6(v6.ip())); } } } } None } /// Stub for non-Unix platforms: interface name resolution unsupported. #[cfg(not(unix))] pub fn resolve_interface_ip(_name: &str, _want_ipv6: bool) -> Option { None } /// Get peer address of a socket #[allow(dead_code)] pub fn get_peer_addr(stream: &TcpStream) -> Option { stream.peer_addr().ok() } /// Check if address is IPv6 #[allow(dead_code)] pub fn is_ipv6(addr: &SocketAddr) -> bool { addr.is_ipv6() } /// Parse IPv4-mapped IPv6 address to IPv4 pub fn normalize_ip(addr: SocketAddr) -> SocketAddr { match addr { SocketAddr::V6(v6) => { if let Some(v4) = v6.ip().to_ipv4_mapped() { SocketAddr::new(std::net::IpAddr::V4(v4), v6.port()) } else { addr } } _ => addr, } } /// Socket options for server listening #[derive(Debug, Clone)] pub struct ListenOptions { /// Enable SO_REUSEADDR pub reuse_addr: bool, /// Enable SO_REUSEPORT (Linux/BSD) pub reuse_port: bool, /// Backlog size pub backlog: u32, /// IPv6 only (disable dual-stack) pub ipv6_only: bool, } impl Default for ListenOptions { fn default() -> Self { Self { reuse_addr: true, reuse_port: true, backlog: 1024, ipv6_only: false, } } } /// Create a listening socket with the specified options pub fn create_listener(addr: SocketAddr, options: &ListenOptions) -> Result { let domain = if addr.is_ipv4() { Domain::IPV4 } else { Domain::IPV6 }; let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; if options.reuse_addr { socket.set_reuse_address(true)?; } #[cfg(unix)] if options.reuse_port { socket.set_reuse_port(true)?; } if addr.is_ipv6() && options.ipv6_only { socket.set_only_v6(true)?; } socket.set_nonblocking(true)?; socket.bind(&addr.into())?; socket.listen(options.backlog as i32)?; debug!(addr = %addr, "Created listening socket"); Ok(socket) } /// Best-effort process list for listeners occupying the same local TCP port. #[derive(Debug, Clone)] pub struct ListenerProcessInfo { pub pid: u32, pub process: String, } /// Find processes currently listening on the local TCP port of `addr`. /// Returns an empty list when unsupported or when no owners can be resolved. pub fn find_listener_processes(addr: SocketAddr) -> Vec { #[cfg(target_os = "linux")] { find_listener_processes_linux(addr) } #[cfg(not(target_os = "linux"))] { let _ = addr; Vec::new() } } #[cfg(target_os = "linux")] fn find_listener_processes_linux(addr: SocketAddr) -> Vec { let inodes = listening_inodes_for_port(addr); if inodes.is_empty() { return Vec::new(); } let mut out = Vec::new(); let proc_entries = match fs::read_dir("/proc") { Ok(entries) => entries, Err(_) => return out, }; for entry in proc_entries.flatten() { let pid = match entry.file_name().to_string_lossy().parse::() { Ok(pid) => pid, Err(_) => continue, }; let fd_dir = entry.path().join("fd"); let fd_entries = match fs::read_dir(fd_dir) { Ok(entries) => entries, Err(_) => continue, }; let mut matched = false; for fd in fd_entries.flatten() { let link_target = match fs::read_link(fd.path()) { Ok(link) => link, Err(_) => continue, }; let link_str = link_target.to_string_lossy(); let Some(rest) = link_str.strip_prefix("socket:[") else { continue; }; let Some(inode_str) = rest.strip_suffix(']') else { continue; }; let Ok(inode) = inode_str.parse::() else { continue; }; if inodes.contains(&inode) { matched = true; break; } } if matched { let process = fs::read_to_string(entry.path().join("comm")) .ok() .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) .unwrap_or_else(|| "unknown".to_string()); out.push(ListenerProcessInfo { pid, process }); } } out.sort_by_key(|p| p.pid); out.dedup_by_key(|p| p.pid); out } #[cfg(target_os = "linux")] fn listening_inodes_for_port(addr: SocketAddr) -> HashSet { let path = match addr { SocketAddr::V4(_) => "/proc/net/tcp", SocketAddr::V6(_) => "/proc/net/tcp6", }; let mut inodes = HashSet::new(); let Ok(data) = fs::read_to_string(path) else { return inodes; }; for line in data.lines().skip(1) { let cols: Vec<&str> = line.split_whitespace().collect(); if cols.len() < 10 { continue; } // LISTEN state in /proc/net/tcp* if cols[3] != "0A" { continue; } let Some(port_hex) = cols[1].split(':').nth(1) else { continue; }; let Ok(port) = u16::from_str_radix(port_hex, 16) else { continue; }; if port != addr.port() { continue; } if let Ok(inode) = cols[9].parse::() { inodes.insert(inode); } } inodes } #[cfg(test)] mod tests { use super::*; use std::io::ErrorKind; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; #[tokio::test] async fn test_configure_socket() { let listener = match TcpListener::bind("127.0.0.1:0").await { Ok(l) => l, Err(e) if e.kind() == ErrorKind::PermissionDenied => return, Err(e) => panic!("bind failed: {e}"), }; let addr = listener.local_addr().unwrap(); let stream = match TcpStream::connect(addr).await { Ok(s) => s, Err(e) if e.kind() == ErrorKind::PermissionDenied => return, Err(e) => panic!("connect failed: {e}"), }; if let Err(e) = configure_tcp_socket(&stream, true, Duration::from_secs(30)) { if e.kind() == ErrorKind::PermissionDenied { return; } panic!("configure_tcp_socket failed: {e}"); } } #[tokio::test] async fn test_configure_client_socket() { let listener = match TcpListener::bind("127.0.0.1:0").await { Ok(l) => l, Err(e) if e.kind() == ErrorKind::PermissionDenied => return, Err(e) => panic!("bind failed: {e}"), }; let addr = match listener.local_addr() { Ok(addr) => addr, Err(e) => panic!("local_addr failed: {e}"), }; let stream = match TcpStream::connect(addr).await { Ok(s) => s, Err(e) if e.kind() == ErrorKind::PermissionDenied => return, Err(e) => panic!("connect failed: {e}"), }; if let Err(e) = configure_client_socket(&stream, 30, 30) { if e.kind() == ErrorKind::PermissionDenied { return; } panic!("configure_client_socket failed: {e}"); } } #[tokio::test] async fn test_configure_client_socket_zero_ack_timeout() { let listener = match TcpListener::bind("127.0.0.1:0").await { Ok(l) => l, Err(e) if e.kind() == ErrorKind::PermissionDenied => return, Err(e) => panic!("bind failed: {e}"), }; let addr = match listener.local_addr() { Ok(addr) => addr, Err(e) => panic!("local_addr failed: {e}"), }; let stream = match TcpStream::connect(addr).await { Ok(s) => s, Err(e) if e.kind() == ErrorKind::PermissionDenied => return, Err(e) => panic!("connect failed: {e}"), }; if let Err(e) = configure_client_socket(&stream, 30, 0) { if e.kind() == ErrorKind::PermissionDenied { return; } panic!("configure_client_socket with zero ack timeout failed: {e}"); } } #[tokio::test] async fn test_configure_client_socket_roundtrip_io() { let listener = match TcpListener::bind("127.0.0.1:0").await { Ok(l) => l, Err(e) if e.kind() == ErrorKind::PermissionDenied => return, Err(e) => panic!("bind failed: {e}"), }; let addr = match listener.local_addr() { Ok(addr) => addr, Err(e) => panic!("local_addr failed: {e}"), }; let server_task = tokio::spawn(async move { let (mut accepted, _) = match listener.accept().await { Ok(v) => v, Err(e) => panic!("accept failed: {e}"), }; let mut payload = [0u8; 4]; if let Err(e) = accepted.read_exact(&mut payload).await { panic!("server read_exact failed: {e}"); } if let Err(e) = accepted.write_all(b"pong").await { panic!("server write_all failed: {e}"); } payload }); let mut stream = match TcpStream::connect(addr).await { Ok(s) => s, Err(e) if e.kind() == ErrorKind::PermissionDenied => return, Err(e) => panic!("connect failed: {e}"), }; if let Err(e) = configure_client_socket(&stream, 30, 30) { if e.kind() == ErrorKind::PermissionDenied { return; } panic!("configure_client_socket failed: {e}"); } if let Err(e) = stream.write_all(b"ping").await { panic!("client write_all failed: {e}"); } let mut reply = [0u8; 4]; if let Err(e) = stream.read_exact(&mut reply).await { panic!("client read_exact failed: {e}"); } assert_eq!(&reply, b"pong"); let server_seen = match server_task.await { Ok(value) => value, Err(e) => panic!("server task join failed: {e}"), }; assert_eq!(&server_seen, b"ping"); } #[cfg(target_os = "linux")] #[tokio::test] async fn test_configure_client_socket_ack_timeout_overflow_rejected() { let listener = match TcpListener::bind("127.0.0.1:0").await { Ok(l) => l, Err(e) if e.kind() == ErrorKind::PermissionDenied => return, Err(e) => panic!("bind failed: {e}"), }; let addr = match listener.local_addr() { Ok(addr) => addr, Err(e) => panic!("local_addr failed: {e}"), }; let stream = match TcpStream::connect(addr).await { Ok(s) => s, Err(e) if e.kind() == ErrorKind::PermissionDenied => return, Err(e) => panic!("connect failed: {e}"), }; let too_large_secs = (i32::MAX as u64 / 1000) + 1; let err = match configure_client_socket(&stream, 30, too_large_secs) { Ok(()) => panic!("expected overflow validation error"), Err(e) => e, }; assert_eq!(err.kind(), ErrorKind::InvalidInput); } #[test] fn test_normalize_ip() { // IPv4 stays IPv4 let v4: SocketAddr = "192.168.1.1:8080".parse().unwrap(); assert_eq!(normalize_ip(v4), v4); // Pure IPv6 stays IPv6 let v6: SocketAddr = "[::1]:8080".parse().unwrap(); assert_eq!(normalize_ip(v6), v6); } #[test] fn test_listen_options_default() { let opts = ListenOptions::default(); assert!(opts.reuse_addr); assert!(opts.reuse_port); assert_eq!(opts.backlog, 1024); } }