diff --git a/staging/vhost-device-sound/src/audio_backends/pipewire.rs b/staging/vhost-device-sound/src/audio_backends/pipewire.rs index f48d0db..09572a1 100644 --- a/staging/vhost-device-sound/src/audio_backends/pipewire.rs +++ b/staging/vhost-device-sound/src/audio_backends/pipewire.rs @@ -139,21 +139,22 @@ impl AudioBackend for PwBackend { { let stream_clone = self.stream_params.clone(); let mut stream_params = stream_clone.write().unwrap(); - let st = stream_params - .get_mut(stream_id as usize) - .expect("Stream does not exist"); - if let Err(err) = st.state.set_parameters() { - log::error!("Stream {} set_parameters {}", stream_id, err); - msg.code = VIRTIO_SND_S_BAD_MSG; - } else if !st.supports_format(request.format) || !st.supports_rate(request.rate) { - msg.code = VIRTIO_SND_S_NOT_SUPP; + if let Some(st) = stream_params.get_mut(stream_id as usize) { + if let Err(err) = st.state.set_parameters() { + log::error!("Stream {} set_parameters {}", stream_id, err); + msg.code = VIRTIO_SND_S_BAD_MSG; + } else if !st.supports_format(request.format) || !st.supports_rate(request.rate) { + msg.code = VIRTIO_SND_S_NOT_SUPP; + } else { + st.params.features = request.features; + st.params.buffer_bytes = request.buffer_bytes; + st.params.period_bytes = request.period_bytes; + st.params.channels = request.channels; + st.params.format = request.format; + st.params.rate = request.rate; + } } else { - st.params.features = request.features; - st.params.buffer_bytes = request.buffer_bytes; - st.params.period_bytes = request.period_bytes; - st.params.channels = request.channels; - st.params.format = request.format; - st.params.rate = request.rate; + msg.code = VIRTIO_SND_S_BAD_MSG; } } drop(msg); @@ -163,7 +164,12 @@ impl AudioBackend for PwBackend { fn prepare(&self, stream_id: u32) -> Result<()> { debug!("pipewire prepare"); - let prepare_result = self.stream_params.write().unwrap()[stream_id as usize] + let prepare_result = self + .stream_params + .write() + .unwrap() + .get_mut(stream_id as usize) + .ok_or_else(|| Error::StreamWithIdNotFound(stream_id))? .state .prepare(); if let Err(err) = prepare_result { @@ -401,69 +407,83 @@ impl AudioBackend for PwBackend { fn release(&self, stream_id: u32, mut msg: ControlMessage) -> Result<()> { debug!("pipewire backend, release function"); - let release_result = self.stream_params.write().unwrap()[stream_id as usize] + let release_result = self + .stream_params + .write() + .unwrap() + .get_mut(stream_id as usize) + .ok_or_else(|| { + msg.code = VIRTIO_SND_S_BAD_MSG; + Error::StreamWithIdNotFound(stream_id) + })? .state .release(); if let Err(err) = release_result { log::error!("Stream {} release {}", stream_id, err); msg.code = VIRTIO_SND_S_BAD_MSG; - } else { - let lock_guard = self.thread_loop.lock(); - let mut stream_hash = self.stream_hash.write().unwrap(); - let mut stream_listener = self.stream_listener.write().unwrap(); - let st_buffer = &mut self.stream_params.write().unwrap(); - - let Some(stream) = stream_hash.get(&stream_id) else { - return Err(Error::StreamWithIdNotFound(stream_id)); - }; - stream.disconnect().expect("could not disconnect stream"); - std::mem::take(&mut st_buffer[stream_id as usize].buffers); - stream_hash.remove(&stream_id); - stream_listener.remove(&stream_id); - - lock_guard.unlock(); + return Ok(()); } - + let lock_guard = self.thread_loop.lock(); + let mut stream_hash = self.stream_hash.write().unwrap(); + let mut stream_listener = self.stream_listener.write().unwrap(); + let st_buffer = &mut self.stream_params.write().unwrap(); + let stream = stream_hash + .get(&stream_id) + .expect("Could not find stream with this id in `stream_hash`."); + stream.disconnect().expect("could not disconnect stream"); + std::mem::take(&mut st_buffer[stream_id as usize].buffers); + stream_hash.remove(&stream_id); + stream_listener.remove(&stream_id); + lock_guard.unlock(); Ok(()) } fn start(&self, stream_id: u32) -> Result<()> { debug!("pipewire start"); - let start_result = self.stream_params.write().unwrap()[stream_id as usize] + let start_result = self + .stream_params + .write() + .unwrap() + .get_mut(stream_id as usize) + .ok_or_else(|| Error::StreamWithIdNotFound(stream_id))? .state .start(); if let Err(err) = start_result { // log the error and continue log::error!("Stream {} start {}", stream_id, err); - } else { - let lock_guard = self.thread_loop.lock(); - let stream_hash = self.stream_hash.read().unwrap(); - let Some(stream) = stream_hash.get(&stream_id) else { - return Err(Error::StreamWithIdNotFound(stream_id)); - }; - stream.set_active(true).expect("could not start stream"); - lock_guard.unlock(); + return Ok(()); } + let lock_guard = self.thread_loop.lock(); + let stream_hash = self.stream_hash.read().unwrap(); + let stream = stream_hash + .get(&stream_id) + .expect("Could not find stream with this id in `stream_hash`."); + stream.set_active(true).expect("could not start stream"); + lock_guard.unlock(); Ok(()) } fn stop(&self, stream_id: u32) -> Result<()> { debug!("pipewire stop"); - let stop_result = self.stream_params.write().unwrap()[stream_id as usize] + let stop_result = self + .stream_params + .write() + .unwrap() + .get_mut(stream_id as usize) + .ok_or_else(|| Error::StreamWithIdNotFound(stream_id))? .state .stop(); if let Err(err) = stop_result { log::error!("Stream {} stop {}", stream_id, err); - } else { - let lock_guard = self.thread_loop.lock(); - let stream_hash = self.stream_hash.read().unwrap(); - let Some(stream) = stream_hash.get(&stream_id) else { - return Err(Error::StreamWithIdNotFound(stream_id)); - }; - stream.set_active(false).expect("could not stop stream"); - lock_guard.unlock(); + return Ok(()); } - + let lock_guard = self.thread_loop.lock(); + let stream_hash = self.stream_hash.read().unwrap(); + let stream = stream_hash + .get(&stream_id) + .expect("Could not find stream with this id in `stream_hash`."); + stream.set_active(false).expect("could not stop stream"); + lock_guard.unlock(); Ok(()) } }