diff --git a/staging/vhost-device-console/Cargo.toml b/staging/vhost-device-console/Cargo.toml index b4c4f02..068dd02 100644 --- a/staging/vhost-device-console/Cargo.toml +++ b/staging/vhost-device-console/Cargo.toml @@ -17,7 +17,6 @@ xen = ["vm-memory/xen", "vhost/xen", "vhost-user-backend/xen"] [dependencies] console = "0.15.7" crossterm = "0.28.1" -nix = "0.26.4" queues = "1.0.2" clap = { version = "4.5", features = ["derive"] } env_logger = "0.11" diff --git a/staging/vhost-device-console/src/backend.rs b/staging/vhost-device-console/src/backend.rs index 62d6745..61f03a1 100644 --- a/staging/vhost-device-console/src/backend.rs +++ b/staging/vhost-device-console/src/backend.rs @@ -5,7 +5,7 @@ // // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause -use log::{error, info, warn}; +use log::{error, info}; use std::any::Any; use std::collections::HashMap; use std::path::PathBuf; @@ -28,6 +28,8 @@ pub(crate) enum Error { SocketCountInvalid(usize), #[error("Could not create console backend: {0}")] CouldNotCreateBackend(crate::vhu_console::Error), + #[error("Could not create console backend: {0}")] + CouldNotInitBackend(crate::vhu_console::Error), #[error("Could not create daemon: {0}")] CouldNotCreateDaemon(vhost_user_backend::Error), #[error("Fatal error: {0}")] @@ -92,6 +94,12 @@ pub(crate) fn start_backend_server( VhostUserConsoleBackend::new(arc_controller).map_err(Error::CouldNotCreateBackend)?, )); + vu_console_backend + .write() + .unwrap() + .assign_input_method(tcp_addr.clone()) + .map_err(Error::CouldNotInitBackend)?; + let mut daemon = VhostUserDaemon::new( String::from("vhost-device-console-backend"), vu_console_backend.clone(), @@ -102,26 +110,15 @@ pub(crate) fn start_backend_server( let vring_workers = daemon.get_epoll_handlers(); vu_console_backend .read() - .unwrap() - .set_vring_worker(&vring_workers[0]); + .expect("Cannot open as write\n") + .set_vring_worker(vring_workers[0].clone()); - // Start the corresponding console thread - let read_handle = if backend == BackendType::Nested { - VhostUserConsoleBackend::start_console_thread(&vu_console_backend) - } else { - VhostUserConsoleBackend::start_tcp_console_thread(&vu_console_backend, tcp_addr.clone()) - }; - - daemon.serve(&socket).map_err(Error::ServeFailed)?; - - // Kill console input thread - vu_console_backend.read().unwrap().kill_console_thread(); - - // Wait for read thread to exit - match read_handle.join() { - Ok(_) => info!("The read thread returned successfully"), - Err(e) => warn!("The read thread returned the error: {:?}", e), - } + daemon.serve(&socket).map_err(|e| { + // Even if daemon stops unexpectedly, the backend should + // be terminated properly (disable raw mode). + vu_console_backend.read().unwrap().prepare_exit(); + Error::ServeFailed(e) + })?; } } diff --git a/staging/vhost-device-console/src/console.rs b/staging/vhost-device-console/src/console.rs index 079fe0c..b02b5d0 100644 --- a/staging/vhost-device-console/src/console.rs +++ b/staging/vhost-device-console/src/console.rs @@ -20,7 +20,6 @@ pub enum BackendType { pub(crate) struct ConsoleController { config: VirtioConsoleConfig, pub backend: BackendType, - pub exit: bool, } impl ConsoleController { @@ -33,7 +32,6 @@ impl ConsoleController { emerg_wr: 64.into(), }, backend, - exit: false, } } diff --git a/staging/vhost-device-console/src/vhu_console.rs b/staging/vhost-device-console/src/vhu_console.rs index d5cca74..f33e2d1 100644 --- a/staging/vhost-device-console/src/vhu_console.rs +++ b/staging/vhost-device-console/src/vhu_console.rs @@ -11,22 +11,20 @@ use crate::virtio_console::{ VIRTIO_CONSOLE_F_MULTIPORT, VIRTIO_CONSOLE_PORT_ADD, VIRTIO_CONSOLE_PORT_NAME, VIRTIO_CONSOLE_PORT_OPEN, VIRTIO_CONSOLE_PORT_READY, }; -use console::Key; use crossterm::terminal::{disable_raw_mode, enable_raw_mode}; -use log::{error, trace}; -use nix::sys::select::{select, FdSet}; -use std::os::fd::AsRawFd; +use log::{error, trace, warn}; +use queues::{IsQueue, Queue}; +use std::net::TcpListener; +use std::os::fd::{AsRawFd, RawFd}; use std::slice::from_raw_parts; use std::sync::{Arc, RwLock}; -use std::thread::JoinHandle; use std::{ convert, - io::{self, Result as IoResult}, + io::{self, Read, Result as IoResult, Write}, }; use thiserror::Error as ThisError; use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; -use vhost_user_backend::VringEpollHandler; -use vhost_user_backend::{VhostUserBackendMut, VringRwLock, VringT}; +use vhost_user_backend::{VhostUserBackendMut, VringEpollHandler, VringRwLock, VringT}; use virtio_bindings::bindings::virtio_config::{VIRTIO_F_NOTIFY_ON_EMPTY, VIRTIO_F_VERSION_1}; use virtio_bindings::bindings::virtio_ring::{ VIRTIO_RING_F_EVENT_IDX, VIRTIO_RING_F_INDIRECT_DESC, @@ -38,13 +36,6 @@ use vm_memory::{ use vmm_sys_util::epoll::EventSet; use vmm_sys_util::eventfd::{EventFd, EFD_NONBLOCK}; -use console::Term; -use queues::{IsQueue, Queue}; -use std::io::Read; -use std::io::Write; -use std::net::TcpListener; -use std::thread::spawn; - /// Virtio configuration const QUEUE_SIZE: usize = 128; const NUM_QUEUES: usize = 4; @@ -63,6 +54,9 @@ const CTRL_TX_QUEUE: u16 = 3; /// needs to write to the RX control queue. const BACKEND_RX_EFD: u16 = (NUM_QUEUES + 1) as u16; const BACKEND_CTRL_RX_EFD: u16 = (NUM_QUEUES + 2) as u16; +const KEY_EFD: u16 = (NUM_QUEUES + 3) as u16; +const LISTENER_EFD: u16 = (NUM_QUEUES + 4) as u16; +const EXIT_EFD: u16 = (NUM_QUEUES + 5) as u16; /// Port name - Need to be updated when MULTIPORT feature /// is supported for more than one devices. @@ -88,6 +82,12 @@ pub(crate) enum Error { EventFdFailed, #[error("Failed to add control message in the internal queue")] RxCtrlQueueAddFailed, + #[error("Error adding epoll")] + EpollAdd, + #[error("Error removing epoll")] + EpollRemove, + #[error("Error creating epoll")] + EpollFdCreate, } impl convert::From for io::Error { @@ -96,6 +96,10 @@ impl convert::From for io::Error { } } +// Define a new trait that combines Read and Write +pub trait ReadWrite: Read + Write {} +impl ReadWrite for T {} + // SAFETY: The layout of the structure is fixed and can be initialized by // reading its content from byte array. unsafe impl ByteValued for VirtioConsoleControl {} @@ -105,10 +109,15 @@ pub(crate) struct VhostUserConsoleBackend { acked_features: u64, event_idx: bool, rx_ctrl_fifo: Queue, - rx_data_fifo: Queue, + rx_data_fifo: Queue, + epoll_fd: i32, + stream_fd: Option, pub(crate) ready: bool, pub(crate) ready_to_write: bool, pub(crate) output_queue: Queue, + pub(crate) stdin: Option>, + pub(crate) listener: Option, + pub(crate) stream: Option>, pub(crate) rx_event: EventFd, pub(crate) rx_ctrl_event: EventFd, pub(crate) exit_event: EventFd, @@ -120,14 +129,19 @@ type ConsoleDescriptorChain = DescriptorChain>) -> Result { Ok(VhostUserConsoleBackend { - controller, + controller: controller.clone(), event_idx: false, rx_ctrl_fifo: Queue::new(), rx_data_fifo: Queue::new(), + epoll_fd: epoll::create(false).map_err(|_| Error::EpollFdCreate)?, + stream_fd: None, acked_features: 0x0, ready: false, ready_to_write: false, output_queue: Queue::new(), + stdin: None, + stream: None, + listener: None, rx_event: EventFd::new(EFD_NONBLOCK).map_err(|_| Error::EventFdFailed)?, rx_ctrl_event: EventFd::new(EFD_NONBLOCK).map_err(|_| Error::EventFdFailed)?, exit_event: EventFd::new(EFD_NONBLOCK).map_err(|_| Error::EventFdFailed)?, @@ -135,6 +149,24 @@ impl VhostUserConsoleBackend { }) } + pub fn assign_input_method(&mut self, tcpaddr_str: String) -> Result<()> { + if self.controller.read().unwrap().backend == BackendType::Nested { + // Enable raw mode for local terminal if backend is nested + enable_raw_mode().expect("Raw mode error"); + + let stdin_fd = io::stdin().as_raw_fd(); + let stdin: Box = Box::new(io::stdin()); + self.stdin = Some(stdin); + + Self::epoll_register(self.epoll_fd.as_raw_fd(), stdin_fd, epoll::Events::EPOLLIN) + .map_err(|_| Error::EpollAdd)?; + } else { + let listener = TcpListener::bind(tcpaddr_str.clone()).expect("asdasd"); + self.listener = Some(listener); + } + Ok(()) + } + fn print_console_frame(&self, control_msg: VirtioConsoleControl) { trace!("id 0x{:x}", control_msg.id.to_native()); trace!("event 0x{:x}", control_msg.event.to_native()); @@ -157,22 +189,26 @@ impl VhostUserConsoleBackend { .writer(&atomic_mem) .map_err(|_| Error::DescriptorWriteFailed)?; - let response: String = match self.rx_data_fifo.remove() { - Ok(item) => item, - _ => { - return Ok(false); - } - }; + let avail_data_len = writer.available_bytes(); + let queue_len = self.rx_data_fifo.size(); + let min_limit = std::cmp::min(queue_len, avail_data_len); + + for _i in 0..min_limit { + let response: u8 = match self.rx_data_fifo.remove() { + Ok(item) => item, + _ => { + return Ok(true); + } + }; - for b in response.bytes() { writer - .write_obj::(b) + .write_obj::(response) .map_err(|_| Error::DescriptorWriteFailed)?; - } - vring - .add_used(desc_chain.head_index(), writer.bytes_written() as u32) - .map_err(|_| Error::AddUsedElemFailed(RX_QUEUE))?; + vring + .add_used(desc_chain.head_index(), writer.bytes_written() as u32) + .map_err(|_| Error::AddUsedElemFailed(RX_QUEUE))?; + } } Ok(true) @@ -210,8 +246,8 @@ impl VhostUserConsoleBackend { } else { self.output_queue .add(my_string) - .expect("Failed to add element in the output queue"); - //.map_err(|_| Error::RxCtrlQueueAddFailed)?; + .map_err(|_| Error::RxCtrlQueueAddFailed)?; + self.write_tcp_stream(); } vring @@ -438,7 +474,7 @@ impl VhostUserConsoleBackend { /// Set self's VringWorker. pub(crate) fn set_vring_worker( &self, - vring_worker: &Arc>>>, + vring_worker: Arc>>>, ) { let rx_event_fd = self.rx_event.as_raw_fd(); vring_worker @@ -453,187 +489,170 @@ impl VhostUserConsoleBackend { u64::from(BACKEND_CTRL_RX_EFD), ) .unwrap(); + + let exit_event_fd = self.exit_event.as_raw_fd(); + vring_worker + .register_listener(exit_event_fd, EventSet::IN, u64::from(EXIT_EFD)) + .unwrap(); + + let epoll_fd = self.epoll_fd.as_raw_fd(); + vring_worker + .register_listener(epoll_fd, EventSet::IN, u64::from(KEY_EFD)) + .unwrap(); + + if self.controller.read().unwrap().backend == BackendType::Network { + let listener_fd = self.listener.as_ref().expect("asd").as_raw_fd(); + vring_worker + .register_listener(listener_fd, EventSet::IN, u64::from(LISTENER_EFD)) + .unwrap(); + } } - pub(crate) fn start_tcp_console_thread( - vhu_console: &Arc>, - tcplisener_str: String, - ) -> JoinHandle> { - let vhu_console = Arc::clone(vhu_console); - spawn(move || { - loop { - let ready = vhu_console.read().unwrap().ready_to_write; - let exit = vhu_console.read().unwrap().controller.read().unwrap().exit; - - if exit { - trace!("Thread exits!"); - break; - } else if ready { - let listener = match TcpListener::bind(tcplisener_str.clone()) { - Ok(listener) => listener, - Err(e) => { - eprintln!("Failed to bind to {}: {}", tcplisener_str, e); - return Ok(()); - } - }; - listener.set_nonblocking(true).expect("Non-blocking error"); - - println!("Server listening on address: {}", tcplisener_str.clone()); - for stream in listener.incoming() { - match stream { - Ok(mut stream) => { - trace!("New connection"); - stream.set_nonblocking(true).expect("Non-blocking error"); - - let mut buffer = [0; 1024]; - loop { - let exit = - vhu_console.read().unwrap().controller.read().unwrap().exit; - if exit { - trace!("Thread exits!"); - return Ok(()); - } - // Write to the stream - if vhu_console.read().unwrap().output_queue.size() > 0 { - let byte_stream = vhu_console - .write() - .unwrap() - .output_queue - .remove() - .expect("Error removing element from output queue") - .into_bytes(); - if let Err(e) = stream.write_all(&byte_stream) { - eprintln!("Error writing to stream: {}", e); - } - } - match stream.read(&mut buffer) { - Ok(bytes_read) => { - if bytes_read == 0 { - println!("Close connection"); - break; - } - trace!( - "Received: {}", - String::from_utf8_lossy(&buffer[..bytes_read]) - ); - let input_buffer = - String::from_utf8_lossy(&buffer[..bytes_read]) - .to_string(); - vhu_console - .write() - .unwrap() - .rx_data_fifo - .add(input_buffer) - .unwrap(); - vhu_console.write().unwrap().rx_event.write(1).unwrap(); - } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - continue; - } - Err(ref e) - if e.kind() == io::ErrorKind::BrokenPipe - || e.kind() == io::ErrorKind::ConnectionReset => - { - trace!("Stream has been closed."); - break; - } - Err(e) => { - eprintln!("Error reading from socket: {}", e); - } - } - } - } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - let exit = - vhu_console.read().unwrap().controller.read().unwrap().exit; - if exit { - trace!("Thread exits!"); - return Ok(()); - } - continue; - } - Err(e) => { - eprintln!("Error accepting connection: {}", e); - break; - } - } - } - } - } - Ok(()) - }) + /// Register a file with an epoll to listen for events in evset. + pub fn epoll_register(epoll_fd: RawFd, fd: RawFd, evset: epoll::Events) -> Result<()> { + epoll::ctl( + epoll_fd, + epoll::ControlOptions::EPOLL_CTL_ADD, + fd, + epoll::Event::new(evset, fd as u64), + ) + .map_err(|_| Error::EpollAdd)?; + Ok(()) } - /// Start console thread. - pub(crate) fn start_console_thread( - vhu_console: &Arc>, - ) -> JoinHandle> { - let vhu_console = Arc::clone(vhu_console); + /// Remove a file from the epoll. + pub fn epoll_unregister(epoll_fd: RawFd, fd: RawFd) -> Result<()> { + epoll::ctl( + epoll_fd, + epoll::ControlOptions::EPOLL_CTL_DEL, + fd, + epoll::Event::new(epoll::Events::empty(), 0), + ) + .map_err(|_| Error::EpollRemove)?; - let exit_eventfd = vhu_console.read().unwrap().exit_event.as_raw_fd(); - // Spawn a new thread to handle input. - spawn(move || { - let term = Term::stdout(); - let mut fdset = FdSet::new(); - fdset.insert(term.as_raw_fd()); - fdset.insert(exit_eventfd); - let max_fd = fdset.highest().expect("Failed to read fdset!") + 1; + Ok(()) + } - loop { - let ready = vhu_console.read().unwrap().ready_to_write; - let exit = vhu_console.read().unwrap().controller.read().unwrap().exit; - - if exit { - trace!("Exit!"); - break; - } else if ready { - let mut fdset_clone = fdset; - enable_raw_mode().expect("Raw mode error"); - - match select(Some(max_fd), Some(&mut fdset_clone), None, None, None) { - Ok(_num_fds) => { - let exit = vhu_console.read().unwrap().controller.read().unwrap().exit; - if (fdset_clone.contains(exit_eventfd)) && exit { - trace!("Exit!"); - break; - } - - if fdset_clone.contains(term.as_raw_fd()) { - if let Some(character) = match term.read_key().unwrap() { - Key::Char(character) => Some(character), - Key::Enter => Some('\n'), - Key::Tab => Some('\t'), - Key::Backspace => Some('\u{8}'), - _ => None, - } { - // Pass the data to vhu_console and trigger an EventFd - let input_buffer = character.to_string(); - vhu_console - .write() - .unwrap() - .rx_data_fifo - .add(input_buffer) - .unwrap(); - vhu_console.write().unwrap().rx_event.write(1).unwrap(); - } - } - } - Err(e) => { - eprintln!("Error in select: {}", e); - break; - } + fn create_new_stream_thread(&mut self) { + // Accept only one incoming connection + if let Some(stream) = self.listener.as_ref().expect("asd").incoming().next() { + match stream { + Ok(stream) => { + let local_addr = self + .listener + .as_ref() + .expect("No listener") + .local_addr() + .unwrap(); + println!("New connection on: {}", local_addr); + let stream_raw_fd = stream.as_raw_fd(); + self.stream_fd = Some(stream_raw_fd); + if let Err(err) = Self::epoll_register( + self.epoll_fd.as_raw_fd(), + stream_raw_fd, + epoll::Events::EPOLLIN, + ) { + warn!("Failed to register with epoll: {:?}", err); } + + let stream: Box = Box::new(stream); + self.stream = Some(stream); + self.write_tcp_stream(); + } + Err(e) => { + eprintln!("Stream error: {}", e); } } + } + } + fn write_tcp_stream(&mut self) { + if self.stream.is_some() { + while self.output_queue.size() > 0 { + let byte_stream = self + .output_queue + .remove() + .expect("Error removing element from output queue") + .into_bytes(); + + if let Err(e) = self + .stream + .as_mut() + .expect("Stream not found") + .write_all(&byte_stream) + { + eprintln!("Error writing to stream: {}", e); + } + } + } + } + + fn read_tcp_stream(&mut self) { + let mut buffer = [0; 1024]; + match self.stream.as_mut().expect("No stream").read(&mut buffer) { + Ok(bytes_read) => { + if bytes_read == 0 { + let local_addr = self + .listener + .as_ref() + .expect("No listener") + .local_addr() + .unwrap(); + println!("Close connection on: {}", local_addr); + if let Err(err) = Self::epoll_unregister( + self.epoll_fd.as_raw_fd(), + self.stream_fd.expect("No stream fd"), + ) { + warn!("Failed to register with epoll: {:?}", err); + } + return; + } + if self.ready_to_write { + for byte in buffer.iter().take(bytes_read) { + self.rx_data_fifo.add(*byte).unwrap(); + } + self.rx_event.write(1).unwrap(); + } + } + Err(e) => { + eprintln!("Error reading from socket: {}", e); + } + } + } + + fn read_char_thread(&mut self) -> IoResult<()> { + let mut bytes = [0; 1]; + match self.stdin.as_mut().expect("No stdin").read(&mut bytes) { + Ok(read_len) => { + if read_len > 0 { + // If the user presses ^C then exit + if bytes[0] == 3 { + disable_raw_mode().expect("Raw mode error"); + trace!("Termination!\n"); + std::process::exit(0); + } + + // If backend is ready pass the data to vhu_console + // and trigger an EventFd. + if self.ready_to_write { + self.rx_data_fifo.add(bytes[0]).unwrap(); + self.rx_event.write(1).unwrap(); + } + } + Ok(()) + } + Err(e) => { + eprintln!("Read stdin error: {}", e); + Err(e) + } + } + } + + pub fn prepare_exit(&self) { + /* For the nested backend */ + if self.controller.read().unwrap().backend == BackendType::Nested { disable_raw_mode().expect("Raw mode error"); - Ok(()) - }) - } - pub fn kill_console_thread(&self) { - trace!("Kill thread"); - self.controller.write().unwrap().exit = true; - self.exit_event.write(1).unwrap(); + } } } @@ -703,19 +722,24 @@ impl VhostUserBackendMut for VhostUserConsoleBackend { vrings: &[VringRwLock], _thread_id: usize, ) -> IoResult<()> { - if device_event == RX_QUEUE { - // Check if there are any available data - if self.rx_data_fifo.size() == 0 { - return Ok(()); - } - }; + if device_event == EXIT_EFD { + self.prepare_exit(); + return Ok(()); + } - if device_event == CTRL_RX_QUEUE { - // Check if there are any available data and the device is ready - if (!self.ready) || (self.rx_ctrl_fifo.size() == 0) { + if device_event == LISTENER_EFD { + self.create_new_stream_thread(); + return Ok(()); + } + + if device_event == KEY_EFD { + if self.controller.read().unwrap().backend == BackendType::Nested { + return self.read_char_thread(); + } else { + self.read_tcp_stream(); return Ok(()); } - }; + } let vring = if device_event == BACKEND_RX_EFD { &vrings[RX_QUEUE as usize] @@ -729,12 +753,24 @@ impl VhostUserBackendMut for VhostUserConsoleBackend { loop { vring.disable_notification().unwrap(); match device_event { - RX_QUEUE => self.process_rx_queue(vring), + RX_QUEUE => { + if self.rx_data_fifo.size() != 0 { + self.process_rx_queue(vring) + } else { + break; + } + } TX_QUEUE => { self.ready_to_write = true; self.process_tx_queue(vring) } - CTRL_RX_QUEUE => self.process_ctrl_rx_queue(vring), + CTRL_RX_QUEUE => { + if self.ready && (self.rx_ctrl_fifo.size() != 0) { + self.process_ctrl_rx_queue(vring) + } else { + break; + } + } CTRL_TX_QUEUE => self.process_ctrl_tx_queue(vring), BACKEND_RX_EFD => { let _ = self.rx_event.read();