diff --git a/staging/vhost-device-sound/src/device.rs b/staging/vhost-device-sound/src/device.rs index 2b0ed63..9325c9b 100644 --- a/staging/vhost-device-sound/src/device.rs +++ b/staging/vhost-device-sound/src/device.rs @@ -30,7 +30,7 @@ use crate::{ audio_backends::{alloc_audio_backend, AudioBackend}, stream::{Buffer, Error as StreamError, Stream}, virtio_sound::*, - ControlMessageKind, Direction, Error, IOMessage, Result, SoundConfig, + ControlMessageKind, Direction, Error, IOMessage, QueueIdx, Result, SoundConfig, }; pub struct VhostUserSoundThread { @@ -38,7 +38,7 @@ pub struct VhostUserSoundThread { event_idx: bool, chmaps: Arc>>, jacks: Arc>>, - queue_indexes: Vec, + queue_indexes: Vec, streams: Arc>>, streams_no: usize, } @@ -49,11 +49,11 @@ impl VhostUserSoundThread { pub fn new( chmaps: Arc>>, jacks: Arc>>, - mut queue_indexes: Vec, + mut queue_indexes: Vec, streams: Arc>>, streams_no: usize, ) -> Result { - queue_indexes.sort(); + queue_indexes.sort_by_key(|idx| *idx as u16); Ok(Self { event_idx: false, @@ -70,7 +70,7 @@ impl VhostUserSoundThread { let mut queues_per_thread = 0u64; for idx in self.queue_indexes.iter() { - queues_per_thread |= 1u64 << idx + queues_per_thread |= 1u64 << *idx as u16 } queues_per_thread @@ -94,7 +94,10 @@ impl VhostUserSoundThread { let vring = &vrings .get(device_event as usize) .ok_or_else(|| Error::HandleUnknownEvent(device_event))?; - let queue_idx = self.queue_indexes[device_event as usize]; + let queue_idx = self + .queue_indexes + .get(device_event as usize) + .ok_or_else(|| Error::HandleUnknownEvent(device_event))?; if self.event_idx { // vm-virtio's Queue implementation only checks avail_index // once, so to properly support EVENT_IDX we need to keep @@ -103,11 +106,10 @@ impl VhostUserSoundThread { loop { vring.disable_notification().unwrap(); match queue_idx { - CONTROL_QUEUE_IDX => self.process_control(vring, audio_backend), - EVENT_QUEUE_IDX => self.process_event(vring), - TX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Output), - RX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Input), - _ => Err(Error::HandleUnknownEvent(queue_idx).into()), + QueueIdx::Control => self.process_control(vring, audio_backend), + QueueIdx::Event => self.process_event(vring), + QueueIdx::Tx => self.process_io(vring, audio_backend, Direction::Output), + QueueIdx::Rx => self.process_io(vring, audio_backend, Direction::Input), }?; if !vring.enable_notification().unwrap() { break; @@ -116,11 +118,10 @@ impl VhostUserSoundThread { } else { // Without EVENT_IDX, a single call is enough. match queue_idx { - CONTROL_QUEUE_IDX => self.process_control(vring, audio_backend), - EVENT_QUEUE_IDX => self.process_event(vring), - TX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Output), - RX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Input), - _ => Err(Error::HandleUnknownEvent(queue_idx).into()), + QueueIdx::Control => self.process_control(vring, audio_backend), + QueueIdx::Event => self.process_event(vring), + QueueIdx::Tx => self.process_io(vring, audio_backend, Direction::Output), + QueueIdx::Rx => self.process_io(vring, audio_backend, Direction::Input), }?; } Ok(()) @@ -635,21 +636,21 @@ impl VhostUserSoundBackend { RwLock::new(VhostUserSoundThread::new( chmaps.clone(), jacks.clone(), - vec![CONTROL_QUEUE_IDX, EVENT_QUEUE_IDX], + vec![QueueIdx::Control, QueueIdx::Event], streams.clone(), streams_no, )?), RwLock::new(VhostUserSoundThread::new( chmaps.clone(), jacks.clone(), - vec![TX_QUEUE_IDX], + vec![QueueIdx::Tx], streams.clone(), streams_no, )?), RwLock::new(VhostUserSoundThread::new( chmaps, jacks, - vec![RX_QUEUE_IDX], + vec![QueueIdx::Rx], streams.clone(), streams_no, )?), @@ -659,10 +660,10 @@ impl VhostUserSoundBackend { chmaps, jacks, vec![ - CONTROL_QUEUE_IDX, - EVENT_QUEUE_IDX, - TX_QUEUE_IDX, - RX_QUEUE_IDX, + QueueIdx::Control, + QueueIdx::Event, + QueueIdx::Tx, + QueueIdx::Rx, ], streams.clone(), streams_no, @@ -832,7 +833,7 @@ mod tests { let chmaps = Arc::new(RwLock::new(vec![])); let jacks = Arc::new(RwLock::new(vec![])); - let queue_indexes = vec![1, 2, 3]; + let queue_indexes = vec![QueueIdx::Event, QueueIdx::Tx, QueueIdx::Rx]; let streams = vec![Stream::default()]; let streams_no = streams.len(); let streams = Arc::new(RwLock::new(streams)); @@ -927,7 +928,7 @@ mod tests { let chmaps = Arc::new(RwLock::new(vec![])); let jacks = Arc::new(RwLock::new(vec![])); - let queue_indexes = vec![1, 2, 3]; + let queue_indexes = vec![QueueIdx::Event, QueueIdx::Tx, QueueIdx::Rx]; let streams = Arc::new(RwLock::new(vec![])); let streams_no = 0; let thread = diff --git a/staging/vhost-device-sound/src/lib.rs b/staging/vhost-device-sound/src/lib.rs index aa1c6d7..d2ceefe 100644 --- a/staging/vhost-device-sound/src/lib.rs +++ b/staging/vhost-device-sound/src/lib.rs @@ -94,7 +94,43 @@ impl TryFrom for Direction { Ok(match val { virtio_sound::VIRTIO_SND_D_OUTPUT => Self::Output, virtio_sound::VIRTIO_SND_D_INPUT => Self::Input, - other => return Err(Error::InvalidMessageValue(stringify!(Direction), other)), + other => { + return Err(Error::InvalidMessageValue( + stringify!(Direction), + other.into(), + )) + } + }) + } +} + +/// Queue index. +/// +/// Type safe enum for CONTROL_QUEUE_IDX, EVENT_QUEUE_IDX, TX_QUEUE_IDX, +/// RX_QUEUE_IDX. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u16)] +pub enum QueueIdx { + #[doc(alias = "CONTROL_QUEUE_IDX")] + Control = virtio_sound::CONTROL_QUEUE_IDX, + #[doc(alias = "EVENT_QUEUE_IDX")] + Event = virtio_sound::EVENT_QUEUE_IDX, + #[doc(alias = "TX_QUEUE_IDX")] + Tx = virtio_sound::TX_QUEUE_IDX, + #[doc(alias = "RX_QUEUE_IDX")] + Rx = virtio_sound::RX_QUEUE_IDX, +} + +impl TryFrom for QueueIdx { + type Error = Error; + + fn try_from(val: u16) -> std::result::Result { + Ok(match val { + virtio_sound::CONTROL_QUEUE_IDX => Self::Control, + virtio_sound::EVENT_QUEUE_IDX => Self::Event, + virtio_sound::TX_QUEUE_IDX => Self::Tx, + virtio_sound::RX_QUEUE_IDX => Self::Rx, + other => return Err(Error::InvalidMessageValue(stringify!(QueueIdx), other)), }) } } @@ -117,7 +153,7 @@ pub enum Error { #[error("Invalid control message code {0}")] InvalidControlMessage(u32), #[error("Invalid value in {0}: {1}")] - InvalidMessageValue(&'static str, u8), + InvalidMessageValue(&'static str, u16), #[error("Failed to create a new EventFd")] EventFdCreate(IoError), #[error("Request missing data buffer")] @@ -389,6 +425,25 @@ mod tests { let val = 42; Direction::try_from(val).unwrap_err(); + + assert_eq!( + QueueIdx::try_from(virtio_sound::CONTROL_QUEUE_IDX).unwrap(), + QueueIdx::Control + ); + assert_eq!( + QueueIdx::try_from(virtio_sound::EVENT_QUEUE_IDX).unwrap(), + QueueIdx::Event + ); + assert_eq!( + QueueIdx::try_from(virtio_sound::TX_QUEUE_IDX).unwrap(), + QueueIdx::Tx + ); + assert_eq!( + QueueIdx::try_from(virtio_sound::RX_QUEUE_IDX).unwrap(), + QueueIdx::Rx + ); + let val = virtio_sound::NUM_QUEUES; + QueueIdx::try_from(val).unwrap_err(); } #[test]