From 2f3b0b1f383dc861ce4c4ac2219f8d171b77ea36 Mon Sep 17 00:00:00 2001 From: Matias Ezequiel Vara Larsen Date: Thu, 8 Feb 2024 12:14:43 +0100 Subject: [PATCH] sound: Use Reader/Writer for tx/rx queue This commit makes the handling of the tx/rx to rely on the Reader/Writer modules. The commit also fixes the tests at streams.rs and replaces the naming buffer for request. The coverage is also updated. Signed-off-by: Matias Ezequiel Vara Larsen --- coverage_config_x86_64.json | 2 +- vhost-device-sound/src/audio_backends/alsa.rs | 71 +++--- vhost-device-sound/src/audio_backends/null.rs | 8 +- .../src/audio_backends/pipewire.rs | 38 ++- vhost-device-sound/src/device.rs | 123 +++------- vhost-device-sound/src/lib.rs | 32 ++- vhost-device-sound/src/stream.rs | 223 ++++++++++-------- 7 files changed, 240 insertions(+), 257 deletions(-) diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json index 346636c..24d7585 100644 --- a/coverage_config_x86_64.json +++ b/coverage_config_x86_64.json @@ -1,5 +1,5 @@ { - "coverage_score": 77.55, + "coverage_score": 77.63, "exclude_path": "", "crate_features": "" } diff --git a/vhost-device-sound/src/audio_backends/alsa.rs b/vhost-device-sound/src/audio_backends/alsa.rs index f85fa42..6bfc2f8 100644 --- a/vhost-device-sound/src/audio_backends/alsa.rs +++ b/vhost-device-sound/src/audio_backends/alsa.rs @@ -196,21 +196,21 @@ fn write_samples_direct( mmap: &mut alsa::direct::pcm::MmapPlayback, ) -> AResult { while mmap.avail() > 0 { - let Some(buffer) = stream.buffers.front_mut() else { + let Some(request) = stream.requests.front_mut() else { return Ok(false); }; if !matches!(stream.state, PCMState::Start) { return Ok(false); } - let n_bytes = buffer.desc_len() as usize - buffer.pos; + let n_bytes = request.len() - request.pos; let mut buf = vec![0; n_bytes]; - let read_bytes = match buffer.read_output(&mut buf) { + let read_bytes = match request.read_output(&mut buf) { Err(err) => { log::error!( - "Could not read TX buffer from guest, dropping it immediately: {}", + "Could not read TX request from guest, dropping it immediately: {}", err ); - stream.buffers.pop_front(); + stream.requests.pop_front(); continue; } Ok(v) => v, @@ -220,10 +220,10 @@ fn write_samples_direct( let frames = mmap.write(&mut iter); let written_bytes = pcm.frames_to_bytes(frames); if let Ok(written_bytes) = usize::try_from(written_bytes) { - buffer.pos += written_bytes; + request.pos += written_bytes; } - if buffer.pos >= buffer.desc_len() as usize { - stream.buffers.pop_front(); + if request.pos >= request.len() { + stream.requests.pop_front(); } } match mmap.status().state() { @@ -241,7 +241,7 @@ fn read_samples_direct( mmap: &mut alsa::direct::pcm::MmapCapture, ) -> AResult { while mmap.avail() > 0 { - let Some(buffer) = stream.buffers.front_mut() else { + let Some(request) = stream.requests.front_mut() else { return Ok(false); }; @@ -253,12 +253,12 @@ fn read_samples_direct( // [`vm_memory::volatile_memory::VolatileSlice`]) and we can't use alsa's readi // without a slice: use an intermediate buffer and copy it to the // descriptor. - let mut intermediate_buf = vec![0; buffer.desc_len() as usize - buffer.pos]; + let mut intermediate_buf = vec![0; request.len() - request.pos]; for (sample, byte) in intermediate_buf.iter_mut().zip(&mut iter) { *sample = byte; n_bytes += 1; } - if buffer + if request .write_input(&intermediate_buf[0..n_bytes]) .expect("Could not write data to guest memory") == 0 @@ -267,8 +267,8 @@ fn read_samples_direct( } drop(iter); - if buffer.pos as u32 >= buffer.desc_len() || mmap.avail() == 0 { - stream.buffers.pop_front(); + if request.pos >= request.len() || mmap.avail() == 0 { + stream.requests.pop_front(); } } @@ -306,28 +306,31 @@ fn write_samples_io( if avail != 0 { io.mmap(avail as usize, |buf| { let stream = &mut streams.write().unwrap()[stream_id]; - let Some(buffer) = stream.buffers.front_mut() else { + let Some(request) = stream.requests.front_mut() else { return 0; }; if !matches!(stream.state, PCMState::Start) { - stream.buffers.pop_front(); + stream.requests.pop_front(); return 0; } - let n_bytes = std::cmp::min(buf.len(), buffer.desc_len() as usize - buffer.pos); - // read_output() always reads (buffer.desc_len() - buffer.pos) bytes - let read_bytes = match buffer.read_output(&mut buf[0..n_bytes]) { + let n_bytes = std::cmp::min(buf.len(), request.len() - request.pos); + // read_output() always reads (request.len() - request.pos) bytes + let read_bytes = match request.read_output(&mut buf[0..n_bytes]) { Ok(v) => v, Err(err) => { - log::error!("Could not read TX buffer, dropping it immediately: {}", err); - stream.buffers.pop_front(); + log::error!( + "Could not read TX request, dropping it immediately: {}", + err + ); + stream.requests.pop_front(); return 0; } }; - buffer.pos += read_bytes as usize; - if buffer.pos as u32 >= buffer.desc_len() { - stream.buffers.pop_front(); + request.pos += read_bytes as usize; + if request.pos >= request.len() { + stream.requests.pop_front(); } p.bytes_to_frames(isize::try_from(read_bytes).unwrap()) .try_into() @@ -372,11 +375,11 @@ fn read_samples_io( return Ok(false); } let stream = &mut streams.write().unwrap()[stream_id]; - let Some(buffer) = stream.buffers.front_mut() else { + let Some(request) = stream.requests.front_mut() else { return Ok(false); }; if !matches!(stream.state, PCMState::Start) { - stream.buffers.pop_front(); + stream.requests.pop_front(); return Ok(false); } let mut frames_read = 0; @@ -385,16 +388,16 @@ fn read_samples_io( // [`vm_memory::volatile_memory::VolatileSlice`]) and we can't use alsa's readi // without a slice: use an intermediate buffer and copy it to the // descriptor. - let mut intermediate_buf = vec![0; buffer.desc_len() as usize - buffer.pos]; + let mut intermediate_buf = vec![0; request.len() - request.pos]; while let Some(frames) = io - .readi(&mut intermediate_buf[0..(buffer.desc_len() as usize - buffer.pos)]) + .readi(&mut intermediate_buf[0..(request.len() - request.pos)]) .map(std::num::NonZeroUsize::new)? .map(std::num::NonZeroUsize::get) { frames_read += frames; let n_bytes = usize::try_from(p.frames_to_bytes(frames.try_into().unwrap())).unwrap_or_default(); - if buffer + if request .write_input(&intermediate_buf[0..n_bytes]) .expect("Could not write data to guest memory") == 0 @@ -404,8 +407,8 @@ fn read_samples_io( } let bytes_read = p.frames_to_bytes(frames_read.try_into().unwrap()); - if buffer.pos as u32 >= buffer.desc_len() || bytes_read == 0 { - stream.buffers.pop_front(); + if request.pos >= request.len() || bytes_read == 0 { + stream.requests.pop_front(); } match p.state() { @@ -432,7 +435,7 @@ fn alsa_worker( let has_buffers = || -> bool { // Hold `streams` lock as short as possible. let lck = streams.read().unwrap(); - !lck[stream_id].buffers.is_empty() + !lck[stream_id].requests.is_empty() && matches!(lck[stream_id].state, PCMState::Start) }; // Run this loop till the stream's buffer vector is empty: @@ -710,12 +713,12 @@ impl AudioBackend for AlsaBackend { } // Stop worker thread self.senders[stream_id as usize].send(false).unwrap(); - // Drop pending stream buffers to complete pending I/O messages + // Drop pending stream requests to complete pending I/O messages // - // This will release buffers even if state transition is invalid. If it is + // This will release requests even if state transition is invalid. If it is // invalid, we won't be in a valid device state anyway so better to get rid of // them and free the virt queue. - std::mem::take(&mut streams[stream_id as usize].buffers); + std::mem::take(&mut streams[stream_id as usize].requests); Ok(()) } diff --git a/vhost-device-sound/src/audio_backends/null.rs b/vhost-device-sound/src/audio_backends/null.rs index 05175f9..0eedbf1 100644 --- a/vhost-device-sound/src/audio_backends/null.rs +++ b/vhost-device-sound/src/audio_backends/null.rs @@ -18,7 +18,7 @@ impl NullBackend { impl AudioBackend for NullBackend { fn write(&self, stream_id: u32) -> Result<()> { log::trace!("NullBackend write stream_id {}", stream_id); - _ = std::mem::take(&mut self.streams.write().unwrap()[stream_id as usize].buffers); + _ = std::mem::take(&mut self.streams.write().unwrap()[stream_id as usize].requests); Ok(()) } @@ -46,7 +46,7 @@ mod tests { null_backend.write(0).unwrap(); let streams = streams.read().unwrap(); - assert_eq!(streams[0].buffers.len(), 0); + assert_eq!(streams[0].requests.len(), 0); } #[test] @@ -57,8 +57,8 @@ mod tests { null_backend.read(0).unwrap(); - // buffer lengths should remain unchanged + // requests lengths should remain unchanged let streams = streams.read().unwrap(); - assert_eq!(streams[0].buffers.len(), 0); + assert_eq!(streams[0].requests.len(), 0); } } diff --git a/vhost-device-sound/src/audio_backends/pipewire.rs b/vhost-device-sound/src/audio_backends/pipewire.rs index b07e529..000032b 100644 --- a/vhost-device-sound/src/audio_backends/pipewire.rs +++ b/vhost-device-sound/src/audio_backends/pipewire.rs @@ -367,10 +367,10 @@ impl AudioBackend for PwBackend { }) .process(move |stream, _data| match stream.dequeue_buffer() { None => debug!("No buffer recieved"), - Some(mut buf) => { + Some(mut req) => { match direction { Direction::Input => { - let datas = buf.datas_mut(); + let datas = req.datas_mut(); let data = &mut datas[0]; let mut n_samples = data.chunk().size() as usize; let Some(slice) = data.data() else { @@ -383,17 +383,15 @@ impl AudioBackend for PwBackend { let mut start = 0; while n_samples > 0 { - let Some(buffer) = stream.buffers.front_mut() else { + let Some(request) = stream.requests.front_mut() else { return; }; - let avail = usize::try_from(buffer.desc_len()) - .unwrap() - .saturating_sub(buffer.pos); + let avail = request.len().saturating_sub(request.pos); let n_bytes = n_samples.min(avail); let p = &slice[start..start + n_bytes]; - if buffer + if request .write_input(p) .expect("Could not write data to guest memory") == 0 @@ -404,13 +402,13 @@ impl AudioBackend for PwBackend { n_samples -= n_bytes; start += n_bytes; - if buffer.pos >= buffer.desc_len() as usize { - stream.buffers.pop_front(); + if request.pos >= request.len() { + stream.requests.pop_front(); } } } Direction::Output => { - let datas = buf.datas_mut(); + let datas = req.datas_mut(); let frame_size = info.channels * size_of::() as u32; let data = &mut datas[0]; let n_bytes = if let Some(slice) = data.data() { @@ -419,15 +417,13 @@ impl AudioBackend for PwBackend { let streams = streams .get_mut(stream_id as usize) .expect("Stream does not exist"); - let Some(buffer) = streams.buffers.front_mut() else { + let Some(request) = streams.requests.front_mut() else { return; }; - let mut start = buffer.pos; + let mut start = request.pos; - let avail = usize::try_from(buffer.desc_len()) - .unwrap() - .saturating_sub(start); + let avail = request.len().saturating_sub(start); if avail < n_bytes { n_bytes = avail; @@ -444,16 +440,16 @@ impl AudioBackend for PwBackend { } else { // read_output() always reads (buffer.desc_len() - // buffer.pos) bytes - buffer + request .read_output(p) .expect("failed to read buffer from guest"); start += n_bytes; - buffer.pos = start; + request.pos = start; - if start >= buffer.desc_len() as usize { - streams.buffers.pop_front(); + if start >= request.len() { + streams.requests.pop_front(); } } n_bytes @@ -516,7 +512,7 @@ impl AudioBackend for PwBackend { .get(&stream_id) .expect("Could not find stream with this id in `stream_hash`."); stream.disconnect().expect("could not disconnect stream"); - std::mem::take(&mut st_buffer[stream_id as usize].buffers); + std::mem::take(&mut st_buffer[stream_id as usize].requests); stream_hash.remove(&stream_id); stream_listener.remove(&stream_id); lock_guard.unlock(); @@ -613,7 +609,7 @@ mod tests { pw_backend.stop(0).unwrap(); pw_backend.release(0).unwrap(); let streams = streams.read().unwrap(); - assert_eq!(streams[0].buffers.len(), 0); + assert_eq!(streams[0].requests.len(), 0); } #[test] diff --git a/vhost-device-sound/src/device.rs b/vhost-device-sound/src/device.rs index 3116b58..5fea9ea 100644 --- a/vhost-device-sound/src/device.rs +++ b/vhost-device-sound/src/device.rs @@ -18,8 +18,7 @@ use virtio_bindings::{ }; use virtio_queue::{DescriptorChain, QueueOwnedT}; use vm_memory::{ - ByteValued, Bytes, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryLoadGuard, GuestMemoryMmap, - Le32, + ByteValued, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryLoadGuard, GuestMemoryMmap, Le32, }; use vmm_sys_util::{ epoll::EventSet, @@ -28,7 +27,7 @@ use vmm_sys_util::{ use crate::{ audio_backends::{alloc_audio_backend, AudioBackend}, - stream::{Buffer, Error as StreamError, Stream}, + stream::{Error as StreamError, Request, Stream}, virtio_sound::*, ControlMessageKind, Direction, Error, IOMessage, QueueIdx, Result, SoundConfig, }; @@ -421,107 +420,47 @@ impl VhostUserSoundThread { return Ok(()); } - // Instead of counting descriptor chain lengths, encode the "parsing" logic in - // an enumeration. Then, the compiler will complain about any unhandled - // match {} cases if any part of the code is changed. This makes invalid - // states unrepresentable in the source code. - #[derive(Copy, Clone, PartialEq, Debug)] - enum IoState { - Ready, - WaitingBufferForStreamId(u32), - Done, - } - // Keep log of stream IDs to wake up, in case the guest has queued more than // one. let mut stream_ids = BTreeSet::default(); + let mem = atomic_mem.memory(); + for desc_chain in requests { - let mut state = IoState::Ready; - let mut buffers = vec![]; - let descriptors: Vec<_> = desc_chain.clone().collect(); let message = Arc::new(IOMessage { vring: vring.clone(), status: VIRTIO_SND_S_OK.into(), used_len: 0.into(), latency_bytes: 0.into(), desc_chain: desc_chain.clone(), - response_descriptor: descriptors.last().cloned().ok_or_else(|| { - log::error!("Received IO request with an empty descriptor chain."); - Error::UnexpectedDescriptorCount(0) - })?, }); - for descriptor in &descriptors { - match state { - IoState::Done => { - return Err(Error::UnexpectedDescriptorCount(descriptors.len()).into()); - } - IoState::Ready - if matches!(direction, Direction::Output) && descriptor.is_write_only() => - { - if descriptor.len() as usize != size_of::() { - return Err(Error::UnexpectedDescriptorSize( - size_of::(), - descriptor.len(), - ) - .into()); - } - state = IoState::Done; - } - IoState::WaitingBufferForStreamId(stream_id) - if descriptor.len() as usize == size_of::() => - { - self.streams.write().unwrap()[stream_id as usize] - .buffers - .extend(std::mem::take(&mut buffers).into_iter()); - state = IoState::Done; - } - IoState::Ready - if descriptor.len() as usize != size_of::() => - { - return Err(Error::UnexpectedDescriptorSize( - size_of::(), - descriptor.len(), - ) - .into()); - } - IoState::Ready => { - let xfer = desc_chain - .memory() - .read_obj::(descriptor.addr()) - .map_err(|_| Error::DescriptorReadFailed)?; - let stream_id: u32 = xfer.stream_id.into(); - stream_ids.insert(stream_id); - state = IoState::WaitingBufferForStreamId(stream_id); - } - IoState::WaitingBufferForStreamId(stream_id) - if descriptor.len() as usize == size_of::() => - { - return Err(Error::UnexpectedDescriptorSize( - u32::from( - self.streams.read().unwrap()[stream_id as usize] - .params - .period_bytes, - ) as usize, - descriptor.len(), - ) - .into()); - } - IoState::WaitingBufferForStreamId(_) => { - // In the case of TX/Playback: - // - // Rather than copying the content of a descriptor, buffer keeps a pointer - // to it. When we copy just after the request is enqueued, the guest's - // userspace may or may not have updated the buffer contents. Guest driver - // simply moves buffers from the used ring to the available ring without - // knowing whether the content has been updated. The device only reads the - // buffer from guest memory when the audio engine requires it, which is - // about after a period thus ensuring that the buffer is up-to-date. - buffers.push(Buffer::new(*descriptor, Arc::clone(&message), direction)); - } + let mut reader = desc_chain + .clone() + .reader(&mem) + .map_err(|_| Error::DescriptorReadFailed)?; + + let in_header: VirtioSoundPcmXfer = + reader.read_obj().map_err(|_| Error::DescriptorReadFailed)?; + + let stream_id: u32 = in_header.stream_id.into(); + + stream_ids.insert(stream_id); + + let payload_len = match direction { + Direction::Output => reader.available_bytes() - reader.bytes_read(), + Direction::Input => { + let writer = desc_chain + .clone() + .writer(&mem) + .map_err(|_| Error::DescriptorReadFailed)?; + writer.available_bytes() - size_of::() } - } + }; + + self.streams.write().unwrap()[stream_id as usize] + .requests + .push_back(Request::new(payload_len, Arc::clone(&message), direction)); } if !stream_ids.is_empty() { @@ -738,7 +677,9 @@ mod tests { use tempfile::tempdir; use virtio_bindings::virtio_ring::VRING_DESC_F_WRITE; use virtio_queue::{mock::MockSplitQueue, Descriptor}; - use vm_memory::{Address, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap}; + use vm_memory::{ + Address, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap, + }; use super::*; use crate::BackendType; diff --git a/vhost-device-sound/src/lib.rs b/vhost-device-sound/src/lib.rs index 46323b2..ea49e7f 100644 --- a/vhost-device-sound/src/lib.rs +++ b/vhost-device-sound/src/lib.rs @@ -45,6 +45,7 @@ pub mod virtio_sound; use std::{ convert::TryFrom, io::{Error as IoError, ErrorKind}, + mem::size_of, sync::Arc, }; @@ -53,9 +54,7 @@ pub use stream::Stream; use thiserror::Error as ThisError; use vhost_user_backend::{VhostUserDaemon, VringRwLock, VringT}; use virtio_sound::*; -use vm_memory::{ - ByteValued, Bytes, GuestMemoryAtomic, GuestMemoryLoadGuard, GuestMemoryMmap, Le32, -}; +use vm_memory::{ByteValued, GuestMemoryAtomic, GuestMemoryLoadGuard, GuestMemoryMmap, Le32}; use crate::device::VhostUserSoundBackend; @@ -296,7 +295,6 @@ pub struct IOMessage { pub used_len: std::sync::atomic::AtomicU32, pub latency_bytes: std::sync::atomic::AtomicU32, desc_chain: SoundDescriptorChain, - response_descriptor: virtio_queue::Descriptor, vring: VringRwLock, } @@ -312,11 +310,27 @@ impl Drop for IOMessage { let used_len: u32 = self.used_len.load(std::sync::atomic::Ordering::SeqCst); log::trace!("dropping IOMessage {:?}", resp); - if let Err(err) = self - .desc_chain - .memory() - .write_obj(resp, self.response_descriptor.addr()) - { + let mem = self.desc_chain.memory(); + + let mut writer = match self.desc_chain.clone().writer(mem) { + Ok(writer) => writer, + Err(err) => { + log::error!("Error::DescriptorReadFailed: {}", err); + return; + } + }; + + let offset = writer.available_bytes() - size_of::(); + + let mut writer_status = match writer.split_at(offset) { + Ok(writer_status) => writer_status, + Err(err) => { + log::error!("Error::DescriptorReadFailed: {}", err); + return; + } + }; + + if let Err(err) = writer_status.write_obj(resp) { log::error!("Error::DescriptorWriteFailed: {}", err); return; } diff --git a/vhost-device-sound/src/stream.rs b/vhost-device-sound/src/stream.rs index 30480c2..7e1cfbf 100644 --- a/vhost-device-sound/src/stream.rs +++ b/vhost-device-sound/src/stream.rs @@ -1,10 +1,16 @@ // Manos Pitsidianakis // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause -use std::{collections::VecDeque, sync::Arc}; +use std::{ + collections::VecDeque, + convert::TryFrom, + io::{Read, Write}, + mem::size_of, + sync::Arc, +}; use thiserror::Error as ThisError; -use vm_memory::{Address, Bytes, Le32, Le64}; +use vm_memory::{Le32, Le64}; use crate::{virtio_sound::*, Direction, IOMessage, SUPPORTED_FORMATS, SUPPORTED_RATES}; @@ -182,7 +188,7 @@ pub struct Stream { pub channels_min: u8, pub channels_max: u8, pub state: PCMState, - pub buffers: VecDeque, + pub requests: VecDeque, } impl Default for Stream { @@ -196,7 +202,7 @@ impl Default for Stream { channels_min: 1, channels_max: 6, state: Default::default(), - buffers: VecDeque::new(), + requests: VecDeque::new(), } } } @@ -241,85 +247,110 @@ impl Default for PcmParams { } } -pub struct Buffer { - data_descriptor: virtio_queue::Descriptor, +pub struct Request { pub pos: usize, + len: usize, pub message: Arc, direction: Direction, } -impl std::fmt::Debug for Buffer { +impl std::fmt::Debug for Request { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - fmt.debug_struct(stringify!(Buffer)) + fmt.debug_struct(stringify!(Request)) .field("pos", &self.pos) + .field("len", &self.len) .field("direction", &self.direction) .field("message", &Arc::as_ptr(&self.message)) .finish() } } -impl Buffer { - pub fn new( - data_descriptor: virtio_queue::Descriptor, - message: Arc, - direction: Direction, - ) -> Self { +impl Request { + pub fn new(len: usize, message: Arc, direction: Direction) -> Self { Self { pos: 0, - data_descriptor, + len, message, direction, } } pub fn read_output(&self, buf: &mut [u8]) -> Result { - let addr = self.data_descriptor.addr(); - let offset = self.pos as u64; - let len = self + let mem = self.message.desc_chain.memory(); + + let mut reader = self .message .desc_chain - .memory() - .read( - buf, - addr.checked_add(offset) - .expect("invalid guest memory address"), - ) + .clone() + .reader(mem) .map_err(|_| Error::DescriptorReadFailed)?; - Ok(len as u32) + + let mut reader_content = reader + .split_at(size_of::() + self.pos) + .map_err(|_| Error::DescriptorReadFailed)?; + + let bytes_read = reader_content + .read(buf) + .map_err(|_| Error::DescriptorReadFailed)?; + + Ok(bytes_read as u32) } pub fn write_input(&mut self, buf: &[u8]) -> Result { - if self.desc_len() <= self.pos as u32 { - return Ok(0); - } - let addr = self.data_descriptor.addr(); - let offset = self.pos as u64; - let len = self + let mem = self.message.desc_chain.memory(); + + let mut writer = self .message .desc_chain - .memory() - .write( - buf, - addr.checked_add(offset) - .expect("invalid guest memory address"), - ) + .clone() + .writer(mem) + .map_err(|_| Error::DescriptorReadFailed)?; + + let mut _status = writer + .split_at(self.len) + .map_err(|_| Error::DescriptorReadFailed)?; + + let mut write_content = writer + .split_at(self.pos) + .map_err(|_| Error::DescriptorReadFailed)?; + + let bytes_written = write_content + .write(buf) .map_err(|_| Error::DescriptorWriteFailed)?; - self.pos += len; - Ok(len as u32) + + self.pos += bytes_written; + + Ok(bytes_written as u32) } #[inline] /// Returns the length of the sound data [`virtio_queue::Descriptor`]. - pub fn desc_len(&self) -> u32 { - self.data_descriptor.len() + pub const fn len(&self) -> usize { + self.len + } + + pub const fn is_empty(&self) -> bool { + self.len == 0 } } -impl Drop for Buffer { +impl Drop for Request { fn drop(&mut self) { + // Since used_len is 32 bits, but self.len() may be bigger + // than that, the spec is unclear about how to handle this + // case, so when converting from usize to u32, saturate + // when conversion overflows + let payload_len = match u32::try_from(self.len()) { + Ok(len) => len, + Err(len) => { + log::warn!("used_len {} overflows u32", len); + u32::MAX + } + }; + match self.direction { Direction::Input => { - let used_len = std::cmp::min(self.pos as u32, self.desc_len()); + let used_len = std::cmp::min(self.pos as u32, payload_len); self.message .used_len .fetch_add(used_len, std::sync::atomic::Ordering::SeqCst); @@ -330,10 +361,10 @@ impl Drop for Buffer { Direction::Output => { self.message .latency_bytes - .fetch_add(self.desc_len(), std::sync::atomic::Ordering::SeqCst); + .fetch_add(payload_len, std::sync::atomic::Ordering::SeqCst); } } - log::trace!("dropping {:?} buffer {:?}", self.direction, self); + log::trace!("dropping {:?} request {:?}", self.direction, self); } } @@ -345,7 +376,8 @@ mod tests { use virtio_bindings::bindings::virtio_ring::{VRING_DESC_F_NEXT, VRING_DESC_F_WRITE}; use virtio_queue::{mock::MockSplitQueue, Descriptor, Queue, QueueOwnedT}; use vm_memory::{ - Address, ByteValued, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap, + Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, + GuestMemoryMmap, }; use super::*; @@ -355,6 +387,7 @@ mod tests { fn prepare_desc_chain( start_addr: GuestAddress, hdr: R, + payload_len: u32, response_len: u32, ) -> SoundDescriptorChain { let mem = &GuestMemoryMmap::<()>::from_ranges(&[(start_addr, 0x1000)]).unwrap(); @@ -364,7 +397,9 @@ mod tests { let desc_out = Descriptor::new( next_addr, - std::mem::size_of::() as u32, + (std::mem::size_of::() as u32) + .checked_add(payload_len) + .unwrap(), VRING_DESC_F_NEXT as u16, index + 1, ); @@ -395,21 +430,22 @@ mod tests { .unwrap() } - fn iomsg() -> IOMessage { - let hdr = VirtioSndPcmSetParams::default(); + fn iomsg(payload_len: u32, response_len: u32) -> IOMessage { + let hdr = VirtioSoundPcmHeader::default(); let memr = GuestMemoryAtomic::new( GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(), ); let vring = VringRwLock::new(memr, 0x1000).unwrap(); - let mem = &GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x1000)]).unwrap(); - let vq = MockSplitQueue::new(mem, 16); - let next_addr = vq.desc_table().total_size() + 0x100; IOMessage { status: VIRTIO_SND_S_OK.into(), latency_bytes: 0.into(), used_len: 0.into(), - desc_chain: prepare_desc_chain::(GuestAddress(0), hdr, 1), - response_descriptor: Descriptor::new(next_addr, 0x200, VRING_DESC_F_NEXT as u16, 1), + desc_chain: prepare_desc_chain::( + GuestAddress(0), + hdr, + payload_len, + response_len, + ), vring, } } @@ -421,17 +457,16 @@ mod tests { #[test] fn test_logging() { - let data_descriptor = Descriptor::new(0, 0, 0, 0); - let msg = iomsg(); + let msg = iomsg(0, size_of::() as u32); let message = Arc::new(msg); let direction = Direction::Input; - let buffer = Buffer::new(data_descriptor, message, direction); + let request = Request::new(0, message, direction); assert_eq!(format!("{direction:?}"), "Input"); assert_eq!( - format!("{buffer:?}"), + format!("{request:?}"), format!( - "Buffer {{ pos: 0, direction: Input, message: {:?} }}", - &Arc::as_ptr(&buffer.message) + "Request {{ pos: 0, len: 0, direction: Input, message: {:?} }}", + &Arc::as_ptr(&request.message) ) ); } @@ -561,58 +596,52 @@ mod tests { } #[test] - fn test_buffer_read_output() { - let msg = iomsg(); - let message = Arc::new(msg); - let desc_msg = iomsg(); - let buffer = Buffer::new( - desc_msg.desc_chain.clone().readable().next().unwrap(), - message, - Direction::Output, - ); - + fn test_request_read_output() { let mut buf = vec![0; 5]; - buffer.read_output(&mut buf).unwrap(); - } - - #[test] - fn test_buffer_write_input() { - let msg = iomsg(); + let msg = iomsg(buf.len() as u32, size_of::() as u32); let message = Arc::new(msg); - let desc_msg = iomsg(); - let mut buffer = Buffer::new( - desc_msg.desc_chain.clone().readable().next().unwrap(), - message, - Direction::Input, - ); + let request = Request::new(buf.len(), message, Direction::Output); - let buf = vec![0; 5]; - buffer.write_input(&buf).unwrap(); + let len = request.read_output(&mut buf).unwrap(); + assert_eq!(len, buf.len() as u32); } #[test] - fn test_buffer_fn() { - let data_descriptor = Descriptor::new(0, 0, 0, 0); - let msg = iomsg(); + fn test_request_write_input() { + let buf = vec![0; 5]; + let msg = iomsg( + 0, + size_of::() as u32 + buf.len() as u32, + ); + let message = Arc::new(msg); + let mut request = Request::new(buf.len(), message, Direction::Input); + + request.write_input(&buf).unwrap(); + } + + #[test] + fn test_request_fn() { + let msg = iomsg(0, size_of::() as u32); let message = Arc::new(msg); let direction = Direction::Input; - let buffer = Buffer::new(data_descriptor, message, direction); + let request = Request::new(0, message, direction); - assert_eq!(buffer.desc_len() as usize, buffer.pos); - assert_eq!(buffer.desc_len(), 0); - assert_eq!(buffer.direction, Direction::Input); + assert_eq!(request.len() as usize, request.pos); + assert_eq!(request.len(), 0); + assert_eq!(request.direction, Direction::Input); - // Test debug format representation for Buffer + // Test debug format representation for Request let mut debug_output = String::new(); // Format the Debug representation into the String. - write!(&mut debug_output, "{:?}", buffer).unwrap(); + write!(&mut debug_output, "{:?}", request).unwrap(); let expected_debug = format!( - "Buffer {{ pos: {}, direction: {:?}, message: {:?} }}", - buffer.pos, - buffer.direction, - Arc::as_ptr(&buffer.message) + "Request {{ pos: {}, len: {}, direction: {:?}, message: {:?} }}", + request.len, + request.pos, + request.direction, + Arc::as_ptr(&request.message) ); assert_eq!(debug_output, expected_debug);