mirror of https://github.com/telemt/telemt.git
fix(socket): validate ack_timeout_secs and check setsockopt rc
This commit is contained in:
parent
8b5cbb7b4b
commit
40dc6a39c1
|
|
@ -68,17 +68,27 @@ pub fn configure_client_socket(
|
|||
// is implemented in relay_bidirectional instead
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
use std::io::{Error, ErrorKind};
|
||||
use std::os::unix::io::AsRawFd;
|
||||
|
||||
let fd = stream.as_raw_fd();
|
||||
let timeout_ms = (ack_timeout_secs * 1000) as libc::c_int;
|
||||
unsafe {
|
||||
let timeout_ms_u64 = ack_timeout_secs
|
||||
.checked_mul(1000)
|
||||
.ok_or_else(|| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs is too large"))?;
|
||||
let timeout_ms = i32::try_from(timeout_ms_u64)
|
||||
.map_err(|_| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs exceeds TCP_USER_TIMEOUT range"))?;
|
||||
|
||||
let rc = unsafe {
|
||||
libc::setsockopt(
|
||||
fd,
|
||||
libc::IPPROTO_TCP,
|
||||
libc::TCP_USER_TIMEOUT,
|
||||
&timeout_ms as *const _ as *const libc::c_void,
|
||||
&timeout_ms as *const libc::c_int as *const libc::c_void,
|
||||
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
|
||||
);
|
||||
)
|
||||
};
|
||||
if rc != 0 {
|
||||
return Err(Error::last_os_error());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -509,6 +519,33 @@ mod tests {
|
|||
};
|
||||
assert_eq!(&server_seen, b"ping");
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
#[tokio::test]
|
||||
async fn test_configure_client_socket_ack_timeout_overflow_rejected() {
|
||||
let listener = match TcpListener::bind("127.0.0.1:0").await {
|
||||
Ok(l) => l,
|
||||
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||
Err(e) => panic!("bind failed: {e}"),
|
||||
};
|
||||
let addr = match listener.local_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(e) => panic!("local_addr failed: {e}"),
|
||||
};
|
||||
|
||||
let stream = match TcpStream::connect(addr).await {
|
||||
Ok(s) => s,
|
||||
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||
Err(e) => panic!("connect failed: {e}"),
|
||||
};
|
||||
|
||||
let too_large_secs = (i32::MAX as u64 / 1000) + 1;
|
||||
let err = match configure_client_socket(&stream, 30, too_large_secs) {
|
||||
Ok(()) => panic!("expected overflow validation error"),
|
||||
Err(e) => e,
|
||||
};
|
||||
assert_eq!(err.kind(), ErrorKind::InvalidInput);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_ip() {
|
||||
|
|
|
|||
Loading…
Reference in New Issue