sound: add QueueIdx enum for virtio queue indices

Add type safe enum to use instead of raw u16 values, which we have to
validate every time we use them.

Signed-off-by: Manos Pitsidianakis <manos.pitsidianakis@linaro.org>
This commit is contained in:
Manos Pitsidianakis 2023-12-13 13:48:50 +02:00 committed by Manos Pitsidianakis
parent 35de89df16
commit 5012ffa7fc
2 changed files with 83 additions and 27 deletions

View File

@ -30,7 +30,7 @@ use crate::{
audio_backends::{alloc_audio_backend, AudioBackend},
stream::{Buffer, Error as StreamError, Stream},
virtio_sound::*,
ControlMessageKind, Direction, Error, IOMessage, Result, SoundConfig,
ControlMessageKind, Direction, Error, IOMessage, QueueIdx, Result, SoundConfig,
};
pub struct VhostUserSoundThread {
@ -38,7 +38,7 @@ pub struct VhostUserSoundThread {
event_idx: bool,
chmaps: Arc<RwLock<Vec<VirtioSoundChmapInfo>>>,
jacks: Arc<RwLock<Vec<VirtioSoundJackInfo>>>,
queue_indexes: Vec<u16>,
queue_indexes: Vec<QueueIdx>,
streams: Arc<RwLock<Vec<Stream>>>,
streams_no: usize,
}
@ -49,11 +49,11 @@ impl VhostUserSoundThread {
pub fn new(
chmaps: Arc<RwLock<Vec<VirtioSoundChmapInfo>>>,
jacks: Arc<RwLock<Vec<VirtioSoundJackInfo>>>,
mut queue_indexes: Vec<u16>,
mut queue_indexes: Vec<QueueIdx>,
streams: Arc<RwLock<Vec<Stream>>>,
streams_no: usize,
) -> Result<Self> {
queue_indexes.sort();
queue_indexes.sort_by_key(|idx| *idx as u16);
Ok(Self {
event_idx: false,
@ -70,7 +70,7 @@ impl VhostUserSoundThread {
let mut queues_per_thread = 0u64;
for idx in self.queue_indexes.iter() {
queues_per_thread |= 1u64 << idx
queues_per_thread |= 1u64 << *idx as u16
}
queues_per_thread
@ -94,7 +94,10 @@ impl VhostUserSoundThread {
let vring = &vrings
.get(device_event as usize)
.ok_or_else(|| Error::HandleUnknownEvent(device_event))?;
let queue_idx = self.queue_indexes[device_event as usize];
let queue_idx = self
.queue_indexes
.get(device_event as usize)
.ok_or_else(|| Error::HandleUnknownEvent(device_event))?;
if self.event_idx {
// vm-virtio's Queue implementation only checks avail_index
// once, so to properly support EVENT_IDX we need to keep
@ -103,11 +106,10 @@ impl VhostUserSoundThread {
loop {
vring.disable_notification().unwrap();
match queue_idx {
CONTROL_QUEUE_IDX => self.process_control(vring, audio_backend),
EVENT_QUEUE_IDX => self.process_event(vring),
TX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Output),
RX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Input),
_ => Err(Error::HandleUnknownEvent(queue_idx).into()),
QueueIdx::Control => self.process_control(vring, audio_backend),
QueueIdx::Event => self.process_event(vring),
QueueIdx::Tx => self.process_io(vring, audio_backend, Direction::Output),
QueueIdx::Rx => self.process_io(vring, audio_backend, Direction::Input),
}?;
if !vring.enable_notification().unwrap() {
break;
@ -116,11 +118,10 @@ impl VhostUserSoundThread {
} else {
// Without EVENT_IDX, a single call is enough.
match queue_idx {
CONTROL_QUEUE_IDX => self.process_control(vring, audio_backend),
EVENT_QUEUE_IDX => self.process_event(vring),
TX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Output),
RX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Input),
_ => Err(Error::HandleUnknownEvent(queue_idx).into()),
QueueIdx::Control => self.process_control(vring, audio_backend),
QueueIdx::Event => self.process_event(vring),
QueueIdx::Tx => self.process_io(vring, audio_backend, Direction::Output),
QueueIdx::Rx => self.process_io(vring, audio_backend, Direction::Input),
}?;
}
Ok(())
@ -635,21 +636,21 @@ impl VhostUserSoundBackend {
RwLock::new(VhostUserSoundThread::new(
chmaps.clone(),
jacks.clone(),
vec![CONTROL_QUEUE_IDX, EVENT_QUEUE_IDX],
vec![QueueIdx::Control, QueueIdx::Event],
streams.clone(),
streams_no,
)?),
RwLock::new(VhostUserSoundThread::new(
chmaps.clone(),
jacks.clone(),
vec![TX_QUEUE_IDX],
vec![QueueIdx::Tx],
streams.clone(),
streams_no,
)?),
RwLock::new(VhostUserSoundThread::new(
chmaps,
jacks,
vec![RX_QUEUE_IDX],
vec![QueueIdx::Rx],
streams.clone(),
streams_no,
)?),
@ -659,10 +660,10 @@ impl VhostUserSoundBackend {
chmaps,
jacks,
vec![
CONTROL_QUEUE_IDX,
EVENT_QUEUE_IDX,
TX_QUEUE_IDX,
RX_QUEUE_IDX,
QueueIdx::Control,
QueueIdx::Event,
QueueIdx::Tx,
QueueIdx::Rx,
],
streams.clone(),
streams_no,
@ -832,7 +833,7 @@ mod tests {
let chmaps = Arc::new(RwLock::new(vec![]));
let jacks = Arc::new(RwLock::new(vec![]));
let queue_indexes = vec![1, 2, 3];
let queue_indexes = vec![QueueIdx::Event, QueueIdx::Tx, QueueIdx::Rx];
let streams = vec![Stream::default()];
let streams_no = streams.len();
let streams = Arc::new(RwLock::new(streams));
@ -927,7 +928,7 @@ mod tests {
let chmaps = Arc::new(RwLock::new(vec![]));
let jacks = Arc::new(RwLock::new(vec![]));
let queue_indexes = vec![1, 2, 3];
let queue_indexes = vec![QueueIdx::Event, QueueIdx::Tx, QueueIdx::Rx];
let streams = Arc::new(RwLock::new(vec![]));
let streams_no = 0;
let thread =

View File

@ -94,7 +94,43 @@ impl TryFrom<u8> for Direction {
Ok(match val {
virtio_sound::VIRTIO_SND_D_OUTPUT => Self::Output,
virtio_sound::VIRTIO_SND_D_INPUT => Self::Input,
other => return Err(Error::InvalidMessageValue(stringify!(Direction), other)),
other => {
return Err(Error::InvalidMessageValue(
stringify!(Direction),
other.into(),
))
}
})
}
}
/// Queue index.
///
/// Type safe enum for CONTROL_QUEUE_IDX, EVENT_QUEUE_IDX, TX_QUEUE_IDX,
/// RX_QUEUE_IDX.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum QueueIdx {
#[doc(alias = "CONTROL_QUEUE_IDX")]
Control = virtio_sound::CONTROL_QUEUE_IDX,
#[doc(alias = "EVENT_QUEUE_IDX")]
Event = virtio_sound::EVENT_QUEUE_IDX,
#[doc(alias = "TX_QUEUE_IDX")]
Tx = virtio_sound::TX_QUEUE_IDX,
#[doc(alias = "RX_QUEUE_IDX")]
Rx = virtio_sound::RX_QUEUE_IDX,
}
impl TryFrom<u16> for QueueIdx {
type Error = Error;
fn try_from(val: u16) -> std::result::Result<Self, Self::Error> {
Ok(match val {
virtio_sound::CONTROL_QUEUE_IDX => Self::Control,
virtio_sound::EVENT_QUEUE_IDX => Self::Event,
virtio_sound::TX_QUEUE_IDX => Self::Tx,
virtio_sound::RX_QUEUE_IDX => Self::Rx,
other => return Err(Error::InvalidMessageValue(stringify!(QueueIdx), other)),
})
}
}
@ -117,7 +153,7 @@ pub enum Error {
#[error("Invalid control message code {0}")]
InvalidControlMessage(u32),
#[error("Invalid value in {0}: {1}")]
InvalidMessageValue(&'static str, u8),
InvalidMessageValue(&'static str, u16),
#[error("Failed to create a new EventFd")]
EventFdCreate(IoError),
#[error("Request missing data buffer")]
@ -389,6 +425,25 @@ mod tests {
let val = 42;
Direction::try_from(val).unwrap_err();
assert_eq!(
QueueIdx::try_from(virtio_sound::CONTROL_QUEUE_IDX).unwrap(),
QueueIdx::Control
);
assert_eq!(
QueueIdx::try_from(virtio_sound::EVENT_QUEUE_IDX).unwrap(),
QueueIdx::Event
);
assert_eq!(
QueueIdx::try_from(virtio_sound::TX_QUEUE_IDX).unwrap(),
QueueIdx::Tx
);
assert_eq!(
QueueIdx::try_from(virtio_sound::RX_QUEUE_IDX).unwrap(),
QueueIdx::Rx
);
let val = virtio_sound::NUM_QUEUES;
QueueIdx::try_from(val).unwrap_err();
}
#[test]