From 40dc6a39c1d3b54fd000c943ac947d3037785737 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Wed, 11 Mar 2026 21:10:58 +0400 Subject: [PATCH] fix(socket): validate ack_timeout_secs and check setsockopt rc --- src/transport/socket.rs | 45 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/src/transport/socket.rs b/src/transport/socket.rs index 54eb143..aa4dc01 100644 --- a/src/transport/socket.rs +++ b/src/transport/socket.rs @@ -68,17 +68,27 @@ pub fn configure_client_socket( // is implemented in relay_bidirectional instead #[cfg(target_os = "linux")] { + use std::io::{Error, ErrorKind}; use std::os::unix::io::AsRawFd; + let fd = stream.as_raw_fd(); - let timeout_ms = (ack_timeout_secs * 1000) as libc::c_int; - unsafe { + let timeout_ms_u64 = ack_timeout_secs + .checked_mul(1000) + .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs is too large"))?; + let timeout_ms = i32::try_from(timeout_ms_u64) + .map_err(|_| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs exceeds TCP_USER_TIMEOUT range"))?; + + let rc = unsafe { libc::setsockopt( fd, libc::IPPROTO_TCP, libc::TCP_USER_TIMEOUT, - &timeout_ms as *const _ as *const libc::c_void, + &timeout_ms as *const libc::c_int as *const libc::c_void, std::mem::size_of::() as libc::socklen_t, - ); + ) + }; + if rc != 0 { + return Err(Error::last_os_error()); } } @@ -509,6 +519,33 @@ mod tests { }; assert_eq!(&server_seen, b"ping"); } + + #[cfg(target_os = "linux")] + #[tokio::test] + async fn test_configure_client_socket_ack_timeout_overflow_rejected() { + let listener = match TcpListener::bind("127.0.0.1:0").await { + Ok(l) => l, + Err(e) if e.kind() == ErrorKind::PermissionDenied => return, + Err(e) => panic!("bind failed: {e}"), + }; + let addr = match listener.local_addr() { + Ok(addr) => addr, + Err(e) => panic!("local_addr failed: {e}"), + }; + + let stream = match TcpStream::connect(addr).await { + Ok(s) => s, + Err(e) if e.kind() == ErrorKind::PermissionDenied => return, + Err(e) => panic!("connect failed: {e}"), + }; + + let too_large_secs = (i32::MAX as u64 / 1000) + 1; + let err = match configure_client_socket(&stream, 30, too_large_secs) { + Ok(()) => panic!("expected overflow validation error"), + Err(e) => e, + }; + assert_eq!(err.kind(), ErrorKind::InvalidInput); + } #[test] fn test_normalize_ip() {