fixes UB for recvmmsg, simplifies function signature of recvmmsg and sendmmsg (#2120)

This commit is contained in:
Jan Adä 2023-10-01 20:59:36 +02:00 committed by GitHub
parent ca62a55a79
commit 9f4e87764f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 23 deletions

View File

@ -9,6 +9,9 @@ This project adheres to [Semantic Versioning](https://semver.org/).
- Fix `SigSet` incorrect implementation of `Eq`, `PartialEq` and `Hash`
([#1946](https://github.com/nix-rust/nix/pull/1946))
- Fixed the function signature of `recvmmsg`, potentially causing UB
([#2119](https://github.com/nix-rust/nix/issues/2119))
### Changed
- The following APIs now take an implementation of `AsFd` rather than a
@ -33,6 +36,8 @@ This project adheres to [Semantic Versioning](https://semver.org/).
relaxed lifetime requirements relative to 0.27.1.
([#2136](https://github.com/nix-rust/nix/pull/2136))
- Simplified the function signatures of `recvmmsg` and `sendmmsg`
## [0.27.1] - 2023-08-28
### Fixed

View File

@ -1552,11 +1552,11 @@ pub fn sendmmsg<'a, XS, AS, C, I, S>(
flags: MsgFlags
) -> crate::Result<MultiResults<'a, S>>
where
XS: IntoIterator<Item = &'a I>,
XS: IntoIterator<Item = I>,
AS: AsRef<[Option<S>]>,
I: AsRef<[IoSlice<'a>]> + 'a,
C: AsRef<[ControlMessage<'a>]> + 'a,
S: SockaddrLike + 'a
I: AsRef<[IoSlice<'a>]>,
C: AsRef<[ControlMessage<'a>]>,
S: SockaddrLike,
{
let mut count = 0;
@ -1564,11 +1564,11 @@ pub fn sendmmsg<'a, XS, AS, C, I, S>(
for (i, ((slice, addr), mmsghdr)) in slices.into_iter().zip(addrs.as_ref()).zip(data.items.iter_mut() ).enumerate() {
let p = &mut mmsghdr.msg_hdr;
p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec;
p.msg_iov = slice.as_ref().as_ptr().cast_mut().cast();
p.msg_iovlen = slice.as_ref().len() as _;
p.msg_namelen = addr.as_ref().map_or(0, S::len);
p.msg_name = addr.as_ref().map_or(ptr::null(), S::as_ptr) as _;
p.msg_name = addr.as_ref().map_or(ptr::null(), S::as_ptr).cast_mut().cast();
// Encode each cmsg. This must happen after initializing the header because
// CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields.
@ -1583,9 +1583,16 @@ pub fn sendmmsg<'a, XS, AS, C, I, S>(
pmhdr = unsafe { CMSG_NXTHDR(p, pmhdr) };
}
count = i+1;
// Doing an unchecked addition is alright here, as the only way to obtain an instance of `MultiHeaders`
// is through the `preallocate` function, which takes an `usize` as an argument to define its size,
// which also provides an upper bound for the size of this zipped iterator. Thus, `i < usize::MAX` or in
// other words: `count` doesn't overflow
count = i + 1;
}
// SAFETY: all pointers are guaranteed to be valid for the scope of this function. `count` does represent the
// maximum number of messages that can be sent safely (i.e. `count` is the minimum of the sizes of `slices`,
// `data.items` and `addrs`)
let sent = Errno::result(unsafe {
libc::sendmmsg(
fd,
@ -1711,14 +1718,19 @@ pub fn recvmmsg<'a, XS, S, I>(
mut timeout: Option<crate::sys::time::TimeSpec>,
) -> crate::Result<MultiResults<'a, S>>
where
XS: IntoIterator<Item = &'a I>,
I: AsRef<[IoSliceMut<'a>]> + 'a,
XS: IntoIterator<Item = I>,
I: AsMut<[IoSliceMut<'a>]>,
{
let mut count = 0;
for (i, (slice, mmsghdr)) in slices.into_iter().zip(data.items.iter_mut()).enumerate() {
for (i, (mut slice, mmsghdr)) in slices.into_iter().zip(data.items.iter_mut()).enumerate() {
let p = &mut mmsghdr.msg_hdr;
p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec;
p.msg_iovlen = slice.as_ref().len() as _;
p.msg_iov = slice.as_mut().as_mut_ptr().cast();
p.msg_iovlen = slice.as_mut().len() as _;
// Doing an unchecked addition is alright here, as the only way to obtain an instance of `MultiHeaders`
// is through the `preallocate` function, which takes an `usize` as an argument to define its size,
// which also provides an upper bound for the size of this zipped iterator. Thus, `i < usize::MAX` or in
// other words: `count` doesn't overflow
count = i + 1;
}
@ -1726,6 +1738,8 @@ where
.as_mut()
.map_or_else(std::ptr::null_mut, |t| t as *mut _ as *mut libc::timespec);
// SAFETY: all pointers are guaranteed to be valid for the scope of this function. `count` does represent the
// maximum number of messages that can be received safely (i.e. `count` is the minimum of the sizes of `slices` and `data.items`)
let received = Errno::result(unsafe {
libc::recvmmsg(
fd,
@ -1743,6 +1757,7 @@ where
})
}
/// Iterator over results of [`recvmmsg`]/[`sendmmsg`]
#[cfg(any(
target_os = "linux",
target_os = "android",
@ -1750,9 +1765,6 @@ where
target_os = "netbsd",
))]
#[derive(Debug)]
/// Iterator over results of [`recvmmsg`]/[`sendmmsg`]
///
///
pub struct MultiResults<'a, S> {
// preallocated structures
rmm: &'a MultiHeaders<S>,
@ -1903,7 +1915,7 @@ mod test {
let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10));
let recv = super::recvmmsg(rsock.as_raw_fd(), &mut data, recv_iovs.iter(), flags, Some(t))?;
let recv = super::recvmmsg(rsock.as_raw_fd(), &mut data, recv_iovs.iter_mut(), flags, Some(t))?;
for rmsg in recv {
#[cfg(not(any(qemu, target_arch = "aarch64")))]

View File

@ -564,7 +564,7 @@ mod recvfrom {
let res: Vec<RecvMsg<SockaddrIn>> = recvmmsg(
rsock.as_raw_fd(),
&mut data,
msgs.iter(),
msgs.iter_mut(),
MsgFlags::empty(),
None,
)
@ -652,7 +652,7 @@ mod recvfrom {
let res: Vec<RecvMsg<SockaddrIn>> = recvmmsg(
rsock.as_raw_fd(),
&mut data,
msgs.iter(),
msgs.iter_mut(),
MsgFlags::MSG_DONTWAIT,
None,
)
@ -2324,12 +2324,17 @@ fn test_recvmmsg_timestampns() {
// Receive the message
let mut buffer = vec![0u8; message.len()];
let cmsgspace = nix::cmsg_space!(TimeSpec);
let iov = vec![[IoSliceMut::new(&mut buffer)]];
let mut iov = vec![[IoSliceMut::new(&mut buffer)]];
let mut data = MultiHeaders::preallocate(1, Some(cmsgspace));
let r: Vec<RecvMsg<()>> =
recvmmsg(in_socket.as_raw_fd(), &mut data, iov.iter(), flags, None)
.unwrap()
.collect();
let r: Vec<RecvMsg<()>> = recvmmsg(
in_socket.as_raw_fd(),
&mut data,
iov.iter_mut(),
flags,
None,
)
.unwrap()
.collect();
let rtime = match r[0].cmsgs().next() {
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
Some(_) => panic!("Unexpected control message"),