mirror of https://github.com/telemt/telemt.git
fix(socket): validate ack_timeout_secs and check setsockopt rc
This commit is contained in:
parent
8b5cbb7b4b
commit
40dc6a39c1
|
|
@ -68,17 +68,27 @@ pub fn configure_client_socket(
|
||||||
// is implemented in relay_bidirectional instead
|
// is implemented in relay_bidirectional instead
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
{
|
{
|
||||||
|
use std::io::{Error, ErrorKind};
|
||||||
use std::os::unix::io::AsRawFd;
|
use std::os::unix::io::AsRawFd;
|
||||||
|
|
||||||
let fd = stream.as_raw_fd();
|
let fd = stream.as_raw_fd();
|
||||||
let timeout_ms = (ack_timeout_secs * 1000) as libc::c_int;
|
let timeout_ms_u64 = ack_timeout_secs
|
||||||
unsafe {
|
.checked_mul(1000)
|
||||||
|
.ok_or_else(|| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs is too large"))?;
|
||||||
|
let timeout_ms = i32::try_from(timeout_ms_u64)
|
||||||
|
.map_err(|_| Error::new(ErrorKind::InvalidInput, "ack_timeout_secs exceeds TCP_USER_TIMEOUT range"))?;
|
||||||
|
|
||||||
|
let rc = unsafe {
|
||||||
libc::setsockopt(
|
libc::setsockopt(
|
||||||
fd,
|
fd,
|
||||||
libc::IPPROTO_TCP,
|
libc::IPPROTO_TCP,
|
||||||
libc::TCP_USER_TIMEOUT,
|
libc::TCP_USER_TIMEOUT,
|
||||||
&timeout_ms as *const _ as *const libc::c_void,
|
&timeout_ms as *const libc::c_int as *const libc::c_void,
|
||||||
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
|
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
|
||||||
);
|
)
|
||||||
|
};
|
||||||
|
if rc != 0 {
|
||||||
|
return Err(Error::last_os_error());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -510,6 +520,33 @@ mod tests {
|
||||||
assert_eq!(&server_seen, b"ping");
|
assert_eq!(&server_seen, b"ping");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(target_os = "linux")]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_configure_client_socket_ack_timeout_overflow_rejected() {
|
||||||
|
let listener = match TcpListener::bind("127.0.0.1:0").await {
|
||||||
|
Ok(l) => l,
|
||||||
|
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||||
|
Err(e) => panic!("bind failed: {e}"),
|
||||||
|
};
|
||||||
|
let addr = match listener.local_addr() {
|
||||||
|
Ok(addr) => addr,
|
||||||
|
Err(e) => panic!("local_addr failed: {e}"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream = match TcpStream::connect(addr).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) if e.kind() == ErrorKind::PermissionDenied => return,
|
||||||
|
Err(e) => panic!("connect failed: {e}"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let too_large_secs = (i32::MAX as u64 / 1000) + 1;
|
||||||
|
let err = match configure_client_socket(&stream, 30, too_large_secs) {
|
||||||
|
Ok(()) => panic!("expected overflow validation error"),
|
||||||
|
Err(e) => e,
|
||||||
|
};
|
||||||
|
assert_eq!(err.kind(), ErrorKind::InvalidInput);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_normalize_ip() {
|
fn test_normalize_ip() {
|
||||||
// IPv4 stays IPv4
|
// IPv4 stays IPv4
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue