fix(socket): validate ack_timeout_secs and check setsockopt rc

This commit is contained in:
David Osipov 2026-03-11 21:10:58 +04:00
parent 8b5cbb7b4b
commit 40dc6a39c1
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
1 changed files with 41 additions and 4 deletions

View File

@ -68,17 +68,27 @@ pub fn configure_client_socket(
// is implemented in relay_bidirectional instead
#[cfg(target_os = "linux")]
{
use std::io::{Error, ErrorKind};
use std::os::unix::io::AsRawFd;
let fd = stream.as_raw_fd();
let timeout_ms = (ack_timeout_secs * 1000) as libc::c_int;
unsafe {
let timeout_ms_u64 = ack_timeout_secs
.checked_mul(1000)
.ok_or_else(|| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs is too large"))?;
let timeout_ms = i32::try_from(timeout_ms_u64)
.map_err(|_| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs exceeds TCP_USER_TIMEOUT range"))?;
let rc = unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_USER_TIMEOUT,
&timeout_ms as *const _ as *const libc::c_void,
&timeout_ms as *const libc::c_int as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
);
)
};
if rc != 0 {
return Err(Error::last_os_error());
}
}
@ -509,6 +519,33 @@ mod tests {
};
assert_eq!(&server_seen, b"ping");
}
#[cfg(target_os = "linux")]
#[tokio::test]
async fn test_configure_client_socket_ack_timeout_overflow_rejected() {
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(l) => l,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("bind failed: {e}"),
};
let addr = match listener.local_addr() {
Ok(addr) => addr,
Err(e) => panic!("local_addr failed: {e}"),
};
let stream = match TcpStream::connect(addr).await {
Ok(s) => s,
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
Err(e) => panic!("connect failed: {e}"),
};
let too_large_secs = (i32::MAX as u64 / 1000) + 1;
let err = match configure_client_socket(&stream, 30, too_large_secs) {
Ok(()) => panic!("expected overflow validation error"),
Err(e) => e,
};
assert_eq!(err.kind(), ErrorKind::InvalidInput);
}
#[test]
fn test_normalize_ip() {