diff --git a/vhost-device-vsock/src/main.rs b/vhost-device-vsock/src/main.rs index c4ca367..1a0697e 100644 --- a/vhost-device-vsock/src/main.rs +++ b/vhost-device-vsock/src/main.rs @@ -9,6 +9,7 @@ mod vhu_vsock_thread; mod vsock_conn; use std::{ + any::Any, collections::HashMap, convert::TryFrom, process::exit, @@ -55,6 +56,8 @@ enum BackendError { CouldNotCreateBackend(vhu_vsock::Error), #[error("Could not create daemon: {0}")] CouldNotCreateDaemon(vhost_user_backend::Error), + #[error("Thread `{0}` panicked")] + ThreadPanic(String, Box), } #[derive(Args, Clone, Debug)] @@ -266,20 +269,32 @@ pub(crate) fn start_backend_server( pub(crate) fn start_backend_servers(configs: &[VsockConfig]) -> Result<(), BackendError> { let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); - let mut handles = Vec::new(); + let mut handles = HashMap::new(); + let (senders, receiver) = std::sync::mpsc::channel(); - for c in configs.iter() { + for (thread_id, c) in configs.iter().enumerate() { let config = c.clone(); let cid_map = cid_map.clone(); + let sender = senders.clone(); + let name = format!("vhu-vsock-cid-{}", c.get_guest_cid()); let handle = thread::Builder::new() - .name(format!("vhu-vsock-cid-{}", c.get_guest_cid())) - .spawn(move || start_backend_server(config, cid_map)) + .name(name.clone()) + .spawn(move || { + let result = + std::panic::catch_unwind(move || start_backend_server(config, cid_map)); + + // Notify the main thread that we are done. + sender.send(thread_id).unwrap(); + + result.map_err(|e| BackendError::ThreadPanic(name, e))? + }) .unwrap(); - handles.push(handle); + handles.insert(thread_id, handle); } - for handle in handles { - handle.join().unwrap()?; + while !handles.is_empty() { + let thread_id = receiver.recv().unwrap(); + handles.remove(&thread_id).unwrap().join().unwrap()?; } Ok(()) @@ -548,4 +563,48 @@ mod tests { test_dir.close().unwrap(); } + + #[test] + fn test_start_backend_servers_failure() { + const CONN_TX_BUF_SIZE: u32 = 64 * 1024; + + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let configs = [ + VsockConfig::new( + 3, + test_dir + .path() + .join("test_vsock_server1.socket") + .display() + .to_string(), + test_dir + .path() + .join("test_vsock_server1.vsock") + .display() + .to_string(), + CONN_TX_BUF_SIZE, + vec![DEFAULT_GROUP_NAME.to_string()], + ), + VsockConfig::new( + 3, + test_dir + .path() + .join("test_vsock_server2.socket") + .display() + .to_string(), + test_dir + .path() + .join("test_vsock_server2.vsock") + .display() + .to_string(), + CONN_TX_BUF_SIZE, + vec![DEFAULT_GROUP_NAME.to_string()], + ), + ]; + + start_backend_servers(&configs).unwrap_err(); + + test_dir.close().unwrap(); + } }