From db6de9960dc136ce7937c9a0f66d6c64cd334c91 Mon Sep 17 00:00:00 2001 From: Ramyak Mehra Date: Sun, 2 Jun 2024 19:51:19 +0530 Subject: [PATCH] vsock: exit early while reading connect message from socket Before this change, we read a fixed number of bytes before checking `\n`. If the user provided less than that, the application would wait indefinitely. Let`s remove this limitation by using buffered reading and checking `\n` from the beginning of the input. Fixes: #307 Signed-off-by: Ramyak Mehra --- vhost-device-vsock/src/vhu_vsock_thread.rs | 32 ++++++++-------------- 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/vhost-device-vsock/src/vhu_vsock_thread.rs b/vhost-device-vsock/src/vhu_vsock_thread.rs index f6213cb..c86818a 100644 --- a/vhost-device-vsock/src/vhu_vsock_thread.rs +++ b/vhost-device-vsock/src/vhu_vsock_thread.rs @@ -3,8 +3,7 @@ use std::{ collections::HashSet, fs::File, - io, - io::Read, + io::{self, BufRead, BufReader}, iter::FromIterator, num::Wrapping, ops::Deref, @@ -12,8 +11,10 @@ use std::{ net::{UnixListener, UnixStream}, prelude::{AsRawFd, FromRawFd, RawFd}, }, - sync::mpsc::Sender, - sync::{mpsc, Arc, RwLock}, + sync::{ + mpsc::{self, Sender}, + Arc, RwLock, + }, thread, }; @@ -437,26 +438,14 @@ impl VhostUserVsockThread { /// Read `CONNECT PORT_NUM\n` from the connected stream. fn read_local_stream_port(stream: &mut UnixStream) -> Result { - let mut buf = [0u8; 32]; + let mut buf = Vec::new(); + let mut reader = BufReader::new(stream); - // Minimum number of bytes we should be able to read - // Corresponds to 'CONNECT 0\n' - const MIN_READ_LEN: usize = 10; - - // Read in the minimum number of bytes we can read - stream - .read_exact(&mut buf[..MIN_READ_LEN]) + let n = reader + .read_until(b'\n', &mut buf) .map_err(Error::UnixRead)?; - let mut read_len = MIN_READ_LEN; - while buf[read_len - 1] != b'\n' && read_len < buf.len() { - stream - .read_exact(&mut buf[read_len..read_len + 1]) - .map_err(Error::UnixRead)?; - read_len += 1; - } - - let mut word_iter = std::str::from_utf8(&buf[..read_len]) + let mut word_iter = std::str::from_utf8(&buf[..n]) .map_err(Error::ConvertFromUtf8)? .split_whitespace(); @@ -718,6 +707,7 @@ impl Drop for VhostUserVsockThread { mod tests { use super::*; use std::collections::HashMap; + use std::io::Read; use std::io::Write; use tempfile::tempdir; use vm_memory::GuestAddress;