diff --git a/staging/vhost-device-sound/src/audio_backends.rs b/staging/vhost-device-sound/src/audio_backends.rs index 3ed42b5..a2153e6 100644 --- a/staging/vhost-device-sound/src/audio_backends.rs +++ b/staging/vhost-device-sound/src/audio_backends.rs @@ -15,14 +15,14 @@ use self::alsa::AlsaBackend; use self::null::NullBackend; #[cfg(feature = "pw-backend")] use self::pipewire::PwBackend; -use crate::{stream::Stream, BackendType, ControlMessage, Result}; +use crate::{stream::Stream, BackendType, ControlMessage, Result, VirtioSndPcmSetParams}; pub trait AudioBackend { fn write(&self, stream_id: u32) -> Result<()>; fn read(&self, stream_id: u32) -> Result<()>; - fn set_parameters(&self, _stream_id: u32, _: ControlMessage) -> Result<()> { + fn set_parameters(&self, _stream_id: u32, _: VirtioSndPcmSetParams) -> Result<()> { Ok(()) } diff --git a/staging/vhost-device-sound/src/audio_backends/alsa.rs b/staging/vhost-device-sound/src/audio_backends/alsa.rs index 618c886..ead8580 100644 --- a/staging/vhost-device-sound/src/audio_backends/alsa.rs +++ b/staging/vhost-device-sound/src/audio_backends/alsa.rs @@ -15,13 +15,11 @@ use alsa::{ pcm::{Access, Format, HwParams, State, PCM}, PollDescriptors, ValueOr, }; -use virtio_queue::Descriptor; -use vm_memory::Bytes; use super::AudioBackend; use crate::{ stream::{PCMState, Stream}, - virtio_sound::{self, VirtioSndPcmSetParams, VIRTIO_SND_S_BAD_MSG, VIRTIO_SND_S_NOT_SUPP}, + virtio_sound::{self, VirtioSndPcmSetParams, VIRTIO_SND_S_BAD_MSG}, ControlMessage, Direction, Error, Result as CrateResult, }; @@ -659,32 +657,22 @@ impl AudioBackend for AlsaBackend { Ok(()) } - fn set_parameters(&self, stream_id: u32, mut msg: ControlMessage) -> CrateResult<()> { + fn set_parameters(&self, stream_id: u32, request: VirtioSndPcmSetParams) -> CrateResult<()> { if stream_id >= self.streams.read().unwrap().len() as u32 { log::error!( "Received SetParameters action for stream id {} but there are only {} PCM streams.", stream_id, self.streams.read().unwrap().len() as u32 ); - msg.code = VIRTIO_SND_S_BAD_MSG; return Err(Error::StreamWithIdNotFound(stream_id)); } - let descriptors: Vec = msg.desc_chain.clone().collect(); - let desc_request = &descriptors[0]; - let request = msg - .desc_chain - .memory() - .read_obj::(desc_request.addr()) - .unwrap(); { let mut streams = self.streams.write().unwrap(); let st = &mut streams[stream_id as usize]; if let Err(err) = st.state.set_parameters() { log::error!("Stream {} set_parameters {}", stream_id, err); - msg.code = VIRTIO_SND_S_BAD_MSG; return Err(Error::Stream(err)); } else if !st.supports_format(request.format) || !st.supports_rate(request.rate) { - msg.code = VIRTIO_SND_S_NOT_SUPP; return Err(Error::UnexpectedAudioBackendConfiguration); } else { st.params.buffer_bytes = request.buffer_bytes; @@ -694,8 +682,6 @@ impl AudioBackend for AlsaBackend { st.params.format = request.format; st.params.rate = request.rate; } - // Manually drop msg for faster response: the kernel has a timeout. - drop(msg); } update_pcm( &self.pcms[stream_id as usize], diff --git a/staging/vhost-device-sound/src/audio_backends/pipewire.rs b/staging/vhost-device-sound/src/audio_backends/pipewire.rs index 22ed52d..e14c904 100644 --- a/staging/vhost-device-sound/src/audio_backends/pipewire.rs +++ b/staging/vhost-device-sound/src/audio_backends/pipewire.rs @@ -30,8 +30,6 @@ use spa::{ SPA_AUDIO_FORMAT_UNKNOWN, }, }; -use virtio_queue::Descriptor; -use vm_memory::Bytes; use super::AudioBackend; use crate::{ @@ -49,7 +47,6 @@ use crate::{ VIRTIO_SND_PCM_RATE_384000, VIRTIO_SND_PCM_RATE_44100, VIRTIO_SND_PCM_RATE_48000, VIRTIO_SND_PCM_RATE_5512, VIRTIO_SND_PCM_RATE_64000, VIRTIO_SND_PCM_RATE_8000, VIRTIO_SND_PCM_RATE_88200, VIRTIO_SND_PCM_RATE_96000, VIRTIO_SND_S_BAD_MSG, - VIRTIO_SND_S_NOT_SUPP, }, ControlMessage, Direction, Error, Result, Stream, }; @@ -163,42 +160,26 @@ impl AudioBackend for PwBackend { Ok(()) } - fn set_parameters(&self, stream_id: u32, mut msg: ControlMessage) -> Result<()> { - let descriptors: Vec = msg.desc_chain.clone().collect(); - let desc_request = &descriptors[0]; - let request = msg - .desc_chain - .memory() - .read_obj::(desc_request.addr()) - .unwrap(); - { - let stream_clone = self.stream_params.clone(); - let mut stream_params = stream_clone.write().unwrap(); - if let Some(st) = stream_params.get_mut(stream_id as usize) { - if let Err(err) = st.state.set_parameters() { - log::error!("Stream {} set_parameters {}", stream_id, err); - msg.code = VIRTIO_SND_S_BAD_MSG; - drop(msg); - return Err(Error::Stream(err)); - } else if !st.supports_format(request.format) || !st.supports_rate(request.rate) { - msg.code = VIRTIO_SND_S_NOT_SUPP; - drop(msg); - return Err(Error::UnexpectedAudioBackendConfiguration); - } else { - st.params.features = request.features; - st.params.buffer_bytes = request.buffer_bytes; - st.params.period_bytes = request.period_bytes; - st.params.channels = request.channels; - st.params.format = request.format; - st.params.rate = request.rate; - } + fn set_parameters(&self, stream_id: u32, request: VirtioSndPcmSetParams) -> Result<()> { + let stream_clone = self.stream_params.clone(); + let mut stream_params = stream_clone.write().unwrap(); + if let Some(st) = stream_params.get_mut(stream_id as usize) { + if let Err(err) = st.state.set_parameters() { + log::error!("Stream {} set_parameters {}", stream_id, err); + return Err(Error::Stream(err)); + } else if !st.supports_format(request.format) || !st.supports_rate(request.rate) { + return Err(Error::UnexpectedAudioBackendConfiguration); } else { - msg.code = VIRTIO_SND_S_BAD_MSG; - drop(msg); - return Err(Error::StreamWithIdNotFound(stream_id)); + st.params.features = request.features; + st.params.buffer_bytes = request.buffer_bytes; + st.params.period_bytes = request.period_bytes; + st.params.channels = request.channels; + st.params.format = request.format; + st.params.rate = request.rate; } + } else { + return Err(Error::StreamWithIdNotFound(stream_id)); } - drop(msg); Ok(()) } @@ -598,7 +579,8 @@ mod tests { use virtio_bindings::bindings::virtio_ring::{VRING_DESC_F_NEXT, VRING_DESC_F_WRITE}; use virtio_queue::{mock::MockSplitQueue, Descriptor, Queue, QueueOwnedT}; use vm_memory::{ - Address, ByteValued, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap, + Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, + GuestMemoryMmap, }; use super::{test_utils::PipewireTestHarness, *}; @@ -684,8 +666,14 @@ mod tests { let pw_backend = PwBackend::new(stream_params); assert_eq!(pw_backend.stream_hash.read().unwrap().len(), 0); assert_eq!(pw_backend.stream_listener.read().unwrap().len(), 0); - let msg = ctrlmsg(); - pw_backend.set_parameters(0, msg).unwrap(); + // set up minimal configuration for test + let request = VirtioSndPcmSetParams { + format: VIRTIO_SND_PCM_FMT_S16, + rate: VIRTIO_SND_PCM_RATE_11025, + channels: 1, + ..Default::default() + }; + pw_backend.set_parameters(0, request).unwrap(); pw_backend.prepare(0).unwrap(); pw_backend.start(0).unwrap(); pw_backend.write(0).unwrap(); @@ -706,15 +694,12 @@ mod tests { let pw_backend = PwBackend::new(stream_params); - let msg = ctrlmsg(); - - _ = pw_backend.set_parameters(0, msg.clone()); - let resp: VirtioSoundHeader = msg - .desc_chain - .memory() - .read_obj(msg.descriptor.addr()) - .unwrap(); - assert_eq!(resp.code, VIRTIO_SND_S_BAD_MSG); + let request = VirtioSndPcmSetParams::default(); + let res = pw_backend.set_parameters(0, request); + assert_eq!( + res.unwrap_err().to_string(), + Error::StreamWithIdNotFound(0).to_string() + ); for res in [ pw_backend.prepare(0), diff --git a/staging/vhost-device-sound/src/device.rs b/staging/vhost-device-sound/src/device.rs index f1fd0fc..988e13e 100644 --- a/staging/vhost-device-sound/src/device.rs +++ b/staging/vhost-device-sound/src/device.rs @@ -318,25 +318,23 @@ impl VhostUserSoundThread { if stream_id as usize >= self.streams_no { log::error!("{}", Error::from(StreamError::InvalidStreamId(stream_id))); resp.code = VIRTIO_SND_S_BAD_MSG.into(); - } else { - audio_backend - .read() - .unwrap() - .set_parameters( - stream_id, - ControlMessage { - kind: code, - code: VIRTIO_SND_S_OK, - desc_chain, - descriptor: desc_hdr, - vring: vring.clone(), - }, - ) - .unwrap(); - - // PcmSetParams needs check valid formats/rates; the audio backend will - // reply when it drops the ControlMessage. - continue; + } else if let Err(err) = audio_backend + .read() + .unwrap() + .set_parameters(stream_id, request) + { + match err { + Error::Stream(_) | Error::StreamWithIdNotFound(_) => { + resp.code = VIRTIO_SND_S_BAD_MSG.into() + } + Error::UnexpectedAudioBackendConfiguration => { + resp.code = VIRTIO_SND_S_NOT_SUPP.into() + } + _ => { + log::error!("{}", err); + resp.code = VIRTIO_SND_S_IO_ERR.into() + } + } } } ControlMessageKind::PcmPrepare => {