sound/pipewire: handle StreamWithIdNotFound consistently

When handling controlq messages, handle StreamWithIdNotFound the same
way on all methods:

- if we are passed a &mut ControlMessage, set it to
  VIRTIO_SND_S_BAD_MSG.
- if we are passed a stream_id, return StreamWithIdNotFound.

Signed-off-by: Manos Pitsidianakis <manos.pitsidianakis@linaro.org>
This commit is contained in:
Manos Pitsidianakis 2023-10-04 12:30:00 +03:00 committed by Alex Bennée
parent 208a796061
commit aa16ef0699

View File

@ -139,21 +139,22 @@ impl AudioBackend for PwBackend {
{
let stream_clone = self.stream_params.clone();
let mut stream_params = stream_clone.write().unwrap();
let st = stream_params
.get_mut(stream_id as usize)
.expect("Stream does not exist");
if let Err(err) = st.state.set_parameters() {
log::error!("Stream {} set_parameters {}", stream_id, err);
msg.code = VIRTIO_SND_S_BAD_MSG;
} else if !st.supports_format(request.format) || !st.supports_rate(request.rate) {
msg.code = VIRTIO_SND_S_NOT_SUPP;
if let Some(st) = stream_params.get_mut(stream_id as usize) {
if let Err(err) = st.state.set_parameters() {
log::error!("Stream {} set_parameters {}", stream_id, err);
msg.code = VIRTIO_SND_S_BAD_MSG;
} else if !st.supports_format(request.format) || !st.supports_rate(request.rate) {
msg.code = VIRTIO_SND_S_NOT_SUPP;
} else {
st.params.features = request.features;
st.params.buffer_bytes = request.buffer_bytes;
st.params.period_bytes = request.period_bytes;
st.params.channels = request.channels;
st.params.format = request.format;
st.params.rate = request.rate;
}
} else {
st.params.features = request.features;
st.params.buffer_bytes = request.buffer_bytes;
st.params.period_bytes = request.period_bytes;
st.params.channels = request.channels;
st.params.format = request.format;
st.params.rate = request.rate;
msg.code = VIRTIO_SND_S_BAD_MSG;
}
}
drop(msg);
@ -163,7 +164,12 @@ impl AudioBackend for PwBackend {
fn prepare(&self, stream_id: u32) -> Result<()> {
debug!("pipewire prepare");
let prepare_result = self.stream_params.write().unwrap()[stream_id as usize]
let prepare_result = self
.stream_params
.write()
.unwrap()
.get_mut(stream_id as usize)
.ok_or_else(|| Error::StreamWithIdNotFound(stream_id))?
.state
.prepare();
if let Err(err) = prepare_result {
@ -401,69 +407,83 @@ impl AudioBackend for PwBackend {
fn release(&self, stream_id: u32, mut msg: ControlMessage) -> Result<()> {
debug!("pipewire backend, release function");
let release_result = self.stream_params.write().unwrap()[stream_id as usize]
let release_result = self
.stream_params
.write()
.unwrap()
.get_mut(stream_id as usize)
.ok_or_else(|| {
msg.code = VIRTIO_SND_S_BAD_MSG;
Error::StreamWithIdNotFound(stream_id)
})?
.state
.release();
if let Err(err) = release_result {
log::error!("Stream {} release {}", stream_id, err);
msg.code = VIRTIO_SND_S_BAD_MSG;
} else {
let lock_guard = self.thread_loop.lock();
let mut stream_hash = self.stream_hash.write().unwrap();
let mut stream_listener = self.stream_listener.write().unwrap();
let st_buffer = &mut self.stream_params.write().unwrap();
let Some(stream) = stream_hash.get(&stream_id) else {
return Err(Error::StreamWithIdNotFound(stream_id));
};
stream.disconnect().expect("could not disconnect stream");
std::mem::take(&mut st_buffer[stream_id as usize].buffers);
stream_hash.remove(&stream_id);
stream_listener.remove(&stream_id);
lock_guard.unlock();
return Ok(());
}
let lock_guard = self.thread_loop.lock();
let mut stream_hash = self.stream_hash.write().unwrap();
let mut stream_listener = self.stream_listener.write().unwrap();
let st_buffer = &mut self.stream_params.write().unwrap();
let stream = stream_hash
.get(&stream_id)
.expect("Could not find stream with this id in `stream_hash`.");
stream.disconnect().expect("could not disconnect stream");
std::mem::take(&mut st_buffer[stream_id as usize].buffers);
stream_hash.remove(&stream_id);
stream_listener.remove(&stream_id);
lock_guard.unlock();
Ok(())
}
fn start(&self, stream_id: u32) -> Result<()> {
debug!("pipewire start");
let start_result = self.stream_params.write().unwrap()[stream_id as usize]
let start_result = self
.stream_params
.write()
.unwrap()
.get_mut(stream_id as usize)
.ok_or_else(|| Error::StreamWithIdNotFound(stream_id))?
.state
.start();
if let Err(err) = start_result {
// log the error and continue
log::error!("Stream {} start {}", stream_id, err);
} else {
let lock_guard = self.thread_loop.lock();
let stream_hash = self.stream_hash.read().unwrap();
let Some(stream) = stream_hash.get(&stream_id) else {
return Err(Error::StreamWithIdNotFound(stream_id));
};
stream.set_active(true).expect("could not start stream");
lock_guard.unlock();
return Ok(());
}
let lock_guard = self.thread_loop.lock();
let stream_hash = self.stream_hash.read().unwrap();
let stream = stream_hash
.get(&stream_id)
.expect("Could not find stream with this id in `stream_hash`.");
stream.set_active(true).expect("could not start stream");
lock_guard.unlock();
Ok(())
}
fn stop(&self, stream_id: u32) -> Result<()> {
debug!("pipewire stop");
let stop_result = self.stream_params.write().unwrap()[stream_id as usize]
let stop_result = self
.stream_params
.write()
.unwrap()
.get_mut(stream_id as usize)
.ok_or_else(|| Error::StreamWithIdNotFound(stream_id))?
.state
.stop();
if let Err(err) = stop_result {
log::error!("Stream {} stop {}", stream_id, err);
} else {
let lock_guard = self.thread_loop.lock();
let stream_hash = self.stream_hash.read().unwrap();
let Some(stream) = stream_hash.get(&stream_id) else {
return Err(Error::StreamWithIdNotFound(stream_id));
};
stream.set_active(false).expect("could not stop stream");
lock_guard.unlock();
return Ok(());
}
let lock_guard = self.thread_loop.lock();
let stream_hash = self.stream_hash.read().unwrap();
let stream = stream_hash
.get(&stream_id)
.expect("Could not find stream with this id in `stream_hash`.");
stream.set_active(false).expect("could not stop stream");
lock_guard.unlock();
Ok(())
}
}