From 8646373e9645c41783d41fb12419f65f4d8df71f Mon Sep 17 00:00:00 2001 From: Dorjoy Chowdhury Date: Sat, 31 Aug 2024 22:31:36 +0600 Subject: [PATCH] vsock: Support for vsock backend This commit adds support for proxying the communication with a VM using vsock, similar to already existing support using unix domain socket through the uds-path option. Two new options have been introduced: - forward-cid - forward-listen The forward-cid option (u32) allows users to specify the CID to which all the connections from the VM should be forwarded to, regardless of the target CID of those connections. Users would typically forward to CID 1 i.e., the host machine. The forward-listen option (string) is a list of ports separated by '+' for forwarding connections from the host machine to the VM. Signed-off-by: Dorjoy Chowdhury --- Cargo.lock | 24 +- vhost-device-vsock/CHANGELOG.md | 1 + vhost-device-vsock/Cargo.toml | 4 + vhost-device-vsock/README.md | 62 ++- vhost-device-vsock/src/main.rs | 619 +++++++++++++++++++-- vhost-device-vsock/src/thread_backend.rs | 175 ++++-- vhost-device-vsock/src/vhu_vsock.rs | 135 +++-- vhost-device-vsock/src/vhu_vsock_thread.rs | 162 +++++- vhost-device-vsock/src/vsock_conn.rs | 6 +- 9 files changed, 1039 insertions(+), 149 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 08b36e2..49fdd31 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -739,6 +739,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -755,7 +764,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "memoffset", + "memoffset 0.6.5", ] [[package]] @@ -779,6 +788,7 @@ dependencies = [ "cfg-if", "cfg_aliases", "libc", + "memoffset 0.9.1", ] [[package]] @@ -1527,6 +1537,7 @@ dependencies = [ "env_logger", "epoll", 
"figment", + "libc", "log", "serde", "tempfile", @@ -1538,6 +1549,7 @@ dependencies = [ "virtio-vsock", "vm-memory", "vmm-sys-util", + "vsock", ] [[package]] @@ -1608,6 +1620,16 @@ dependencies = [ "libc", ] +[[package]] +name = "vsock" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e8b4d00e672f147fc86a09738fadb1445bd1c0a40542378dfb82909deeee688" +dependencies = [ + "libc", + "nix 0.29.0", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/vhost-device-vsock/CHANGELOG.md b/vhost-device-vsock/CHANGELOG.md index 235c199..a3b35b6 100644 --- a/vhost-device-vsock/CHANGELOG.md +++ b/vhost-device-vsock/CHANGELOG.md @@ -3,6 +3,7 @@ ### Added - [#698](https://github.com/rust-vmm/vhost-device/pull/698) vsock: add mdoc page +- [#706](https://github.com/rust-vmm/vhost-device/pull/706) Support proxying using vsock ### Changed diff --git a/vhost-device-vsock/Cargo.toml b/vhost-device-vsock/Cargo.toml index a356347..345d249 100644 --- a/vhost-device-vsock/Cargo.toml +++ b/vhost-device-vsock/Cargo.toml @@ -10,6 +10,8 @@ license = "Apache-2.0 OR BSD-3-Clause" edition = "2021" [features] +default = ["backend_vsock"] +backend_vsock = ["vsock", "libc"] xen = ["vm-memory/xen", "vhost/xen", "vhost-user-backend/xen"] [dependencies] @@ -27,6 +29,8 @@ virtio-vsock = "0.6" vm-memory = "0.14.1" vmm-sys-util = "0.12" figment = { version = "0.10.19", features = ["yaml"] } +vsock = { version = "0.5.0", optional = true } +libc = { version = "0.2.39", optional = true } serde = { version = "1", features = ["derive"] } [dev-dependencies] diff --git a/vhost-device-vsock/README.md b/vhost-device-vsock/README.md index c58f3d6..8c8b915 100644 --- a/vhost-device-vsock/README.md +++ b/vhost-device-vsock/README.md @@ -6,7 +6,10 @@ The crate introduces a vhost-device-vsock device that enables communication betw application running in the guest i.e inside a VM and an application running on the host i.e outside the VM. 
The application running in the guest communicates over VM sockets i.e over AF_VSOCK sockets. The application running on the host connects to a -unix socket on the host i.e communicates over AF_UNIX sockets. The main components of +unix socket on the host i.e communicates over AF_UNIX sockets when using the unix domain +socket backend through the uds-path option or the application in the host listens or +connects to vsock on the host i.e communicates over AF_VSOCK sockets when using the +vsock backend through the forward-cid, forward-listen options. The main components of the crate are split into various files as described below: - [packet.rs](src/packet.rs) @@ -38,7 +41,7 @@ the crate are split into various files as described below: ## Usage -Run the vhost-device-vsock device: +Run the vhost-device-vsock device with unix domain socket backend: ``` vhost-device-vsock --guest-cid= \ --socket= \ @@ -52,11 +55,26 @@ or vhost-device-vsock --vm guest_cid=,socket=,uds-path=[,tx-buffer-size=host packets)>][,queue-size=][,groups=] ``` +Run the vhost-device-vsock device with vsock backend: +``` +vhost-device-vsock --guest-cid= \ + --socket= \ + --forward-cid= \ + [--forward-listen=] \ + [--tx-buffer-size=host packets)>] \ + [--queue-size=] +``` +or +``` +vhost-device-vsock --vm guest_cid=,socket=,forward-cid=[,forward-listen=][,tx-buffer-size=host packets)>][,queue-size=][,groups=] +``` + Specify the `--vm` argument multiple times to specify multiple devices like this: ``` vhost-device-vsock \ --vm guest-cid=3,socket=/tmp/vhost3.socket,uds-path=/tmp/vm3.vsock,groups=group1+groupA \ ---vm guest-cid=4,socket=/tmp/vhost4.socket,uds-path=/tmp/vm4.vsock,tx-buffer-size=32768,queue-size=256 +--vm guest-cid=4,socket=/tmp/vhost4.socket,uds-path=/tmp/vm4.vsock,tx-buffer-size=32768,queue-size=256 \ +--vm guest-cid=5,socket=/tmp/vhost5.socket,forward-cid=1,forward-listen=9001+9002,tx-buffer-size=32768,queue-size=1024 ``` Or use a configuration file: @@ -79,6 +97,12 @@ vms:
tx_buffer_size: 32768 queue_size: 256 groups: group2+groupB + - guest_cid: 5 + socket: /tmp/vhost5.socket + forward_cid: 1 + forward_listen: 9001+9002 + tx_buffer_size: 32768 + queue_size: 1024 ``` Run VMM (e.g. QEMU): @@ -185,6 +209,38 @@ guest_cid3$ nc-vsock -l 1234 guest_cid4$ nc-vsock 3 1234 ``` +### Using the vsock backend + +The vsock backend is available under the `backend_vsock` feature (enabled by default). If you want to test a guest VM that +has built-in applications which communicate with another VM over AF_VSOCK, you can forward the connections from the guest +to the host machine instead of running a separate VM for easier testing using the forward-cid option. In such a case, you +would run the corresponding applications that listen for or connect with applications in the guest VM using AF_VSOCK in the +host instead of running the separate VM. For forwarding AF_VSOCK connections from the host, you can use the forward-listen +option. + +For example, if the guest VM that you want to test has an application that connects to (CID 3, port 9000) upon boot and applications +that listen on ports 9001 and 9002 for connections, first run vhost-device-vsock: + +```sh +shell1$ vhost-device-vsock --vm guest-cid=4,forward-cid=1,forward-listen=9001+9002,socket=/tmp/vhost4.socket +``` + +Now run the application listening for connections to port 9000 on the host machine and then run the guest VM: + +```sh +shell2$ qemu-system-x86_64 \ + -drive file=vm1.qcow2,format=qcow2,if=virtio -smp 2 \ + -object memory-backend-memfd,id=mem0,size=512M \ + -machine q35,accel=kvm,memory-backend=mem0 \ + -chardev socket,id=char0,reconnect=0,path=/tmp/vhost4.socket \ + -device vhost-user-vsock-pci,chardev=char0 +``` + +After the guest VM boots, the application inside the guest connecting to (CID 3, port 9000) should successfully connect to the +application running on the host.
Assuming the applications listening on port 9001 and 9002 are running in the guest VM, you can +now run the applications that connect to port 9001 and 9002 (you need to modify the CID they connect to be the host CID i.e. 1) +on the host machine. + ## License This project is licensed under either of diff --git a/vhost-device-vsock/src/main.rs b/vhost-device-vsock/src/main.rs index 0536841..b8ed4bc 100644 --- a/vhost-device-vsock/src/main.rs +++ b/vhost-device-vsock/src/main.rs @@ -17,6 +17,8 @@ use std::{ thread, }; +#[cfg(feature = "backend_vsock")] +use crate::vhu_vsock::VsockProxyInfo; use crate::vhu_vsock::{BackendType, CidMap, VhostUserVsockBackend, VsockConfig}; use clap::{Args, Parser}; use figment::{ @@ -82,8 +84,40 @@ struct VsockParam { socket: String, /// Unix socket to which a host-side application connects to. + #[cfg(not(feature = "backend_vsock"))] #[arg(long, conflicts_with = "config", conflicts_with = "vm")] - uds_path: String, + uds_path: Option, + + /// Unix socket to which a host-side application connects to. 
+ #[cfg(feature = "backend_vsock")] + #[arg( + long, + conflicts_with = "forward_cid", + conflicts_with = "forward_listen", + conflicts_with = "config", + conflicts_with = "vm" + )] + uds_path: Option, + + /// The vsock CID to forward connections from guest + #[cfg(feature = "backend_vsock")] + #[clap( + long, + conflicts_with = "uds_path", + conflicts_with = "config", + conflicts_with = "vm" + )] + forward_cid: Option, + + /// The vsock ports to forward connections from host + #[cfg(feature = "backend_vsock")] + #[clap( + long, + conflicts_with = "uds_path", + conflicts_with = "config", + conflicts_with = "vm" + )] + forward_listen: Option, /// The size of the buffer used for the TX virtqueue #[clap(long, default_value_t = DEFAULT_TX_BUFFER_SIZE, conflicts_with = "config", conflicts_with = "vm")] @@ -109,7 +143,11 @@ struct VsockParam { struct ConfigFileVsockParam { guest_cid: Option, socket: String, - uds_path: String, + uds_path: Option, + #[cfg(feature = "backend_vsock")] + forward_cid: Option, + #[cfg(feature = "backend_vsock")] + forward_listen: Option, tx_buffer_size: Option, queue_size: Option, groups: Option, @@ -126,6 +164,19 @@ struct VsockArgs { /// Example: /// --vm guest-cid=3,socket=/tmp/vhost3.socket,uds-path=/tmp/vm3.vsock,tx-buffer-size=65536,queue-size=1024,groups=group1+group2 /// Multiple instances of this argument can be provided to configure devices for multiple guests. + #[cfg(not(feature = "backend_vsock"))] + #[arg(long, conflicts_with = "config", verbatim_doc_comment, value_parser = parse_vm_params)] + vm: Option>, + + /// Device parameters corresponding to a VM in the form of comma separated key=value pairs. + /// The allowed keys are: guest_cid, socket, uds_path, forward_cid, forward_listen, tx_buffer_size, queue_size and group. + /// uds_path and (forward_cid, forward_listen) are mutually exclusive. Use uds_path when you want unix domain socket + /// backend, otherwise forward_cid, forward_listen for vsock backend. 
+ /// Example: + /// --vm guest-cid=3,socket=/tmp/vhost3.socket,uds-path=/tmp/vm3.vsock,tx-buffer-size=65536,queue-size=1024,groups=group1+group2 + /// --vm guest-cid=3,socket=/tmp/vhost3.socket,forward-cid=1,forward-listen=9001,queue-size=1024 + /// Multiple instances of this argument can be provided to configure devices for multiple guests. + #[cfg(feature = "backend_vsock")] #[arg(long, conflicts_with = "config", verbatim_doc_comment, value_parser = parse_vm_params)] vm: Option>, @@ -142,6 +193,11 @@ fn parse_vm_params(s: &str) -> Result { let mut queue_size = None; let mut groups = None; + #[cfg(feature = "backend_vsock")] + let mut forward_cid = None; + #[cfg(feature = "backend_vsock")] + let mut forward_listen: Option> = None; + for arg in s.trim().split(',') { let mut parts = arg.split('='); let key = parts.next().ok_or(VmArgsParseError::BadArgument)?; @@ -153,6 +209,16 @@ fn parse_vm_params(s: &str) -> Result { } "socket" => socket = Some(val.to_string()), "uds_path" | "uds-path" => uds_path = Some(val.to_string()), + + #[cfg(feature = "backend_vsock")] + "forward_cid" | "forward-cid" => { + forward_cid = Some(val.parse().map_err(VmArgsParseError::ParseInteger)?) + } + #[cfg(feature = "backend_vsock")] + "forward_listen" | "forward-listen" => { + forward_listen = Some(val.split('+').map(|s| s.parse().unwrap()).collect()) + } + "tx_buffer_size" | "tx-buffer-size" => { tx_buffer_size = Some(val.parse().map_err(VmArgsParseError::ParseInteger)?) 
} @@ -164,10 +230,42 @@ fn parse_vm_params(s: &str) -> Result { } } + #[cfg(feature = "backend_vsock")] + let backend_info = match (uds_path, forward_cid) { + (Some(path), None) => BackendType::UnixDomainSocket(path), + (None, Some(cid)) => { + let listen_ports: Vec = forward_listen.unwrap_or_default(); + BackendType::Vsock(VsockProxyInfo { + forward_cid: cid, + listen_ports, + }) + } + (None, None) => { + return Err(VmArgsParseError::RequiredKeyNotFound( + "uds-path or forward-cid".to_string(), + )) + } + _ => { + return Err(VmArgsParseError::RequiredKeyNotFound( + "Only one of uds-path or forward-cid can be provided".to_string(), + )) + } + }; + + #[cfg(not(feature = "backend_vsock"))] + let backend_info = match uds_path { + Some(path) => BackendType::UnixDomainSocket(path), + _ => { + return Err(VmArgsParseError::RequiredKeyNotFound( + "uds-path".to_string(), + )) + } + }; + Ok(VsockConfig::new( guest_cid.unwrap_or(DEFAULT_GUEST_CID), socket.ok_or_else(|| VmArgsParseError::RequiredKeyNotFound("socket".to_string()))?, - BackendType::UnixDomainSocket(uds_path.ok_or_else(|| VmArgsParseError::RequiredKeyNotFound("uds-path".to_string()))?), + backend_info.clone(), tx_buffer_size.unwrap_or(DEFAULT_TX_BUFFER_SIZE), queue_size.unwrap_or(DEFAULT_QUEUE_SIZE), groups.unwrap_or(vec![DEFAULT_GROUP_NAME.to_string()]), @@ -184,21 +282,46 @@ impl VsockArgs { { let vms_param = config_map.get_mut("vms").unwrap(); if !vms_param.is_empty() { - let parsed: Vec = vms_param - .drain(..) - .map(|p| { - VsockConfig::new( - p.guest_cid.unwrap_or(DEFAULT_GUEST_CID), - p.socket.trim().to_string(), - BackendType::UnixDomainSocket(p.uds_path.trim().to_string()), - p.tx_buffer_size.unwrap_or(DEFAULT_TX_BUFFER_SIZE), - p.queue_size.unwrap_or(DEFAULT_QUEUE_SIZE), - p.groups.map_or(vec![DEFAULT_GROUP_NAME.to_string()], |g| { - g.trim().split('+').map(String::from).collect() - }), - ) - }) - .collect(); + let mut parsed = Vec::new(); + for p in vms_param.drain(..) 
{ + #[cfg(feature = "backend_vsock")] + let backend_info = match (p.uds_path, p.forward_cid) { + (Some(path), None) => { + BackendType::UnixDomainSocket(path.trim().to_string()) + } + (None, Some(cid)) => { + let listen_ports: Vec = match p.forward_listen { + None => Vec::new(), + Some(ports) => { + ports.split('+').map(|s| s.parse().unwrap()).collect() + } + }; + BackendType::Vsock(VsockProxyInfo { + forward_cid: cid, + listen_ports, + }) + } + _ => return Some(Err(CliError::ConfigParse)), + }; + + #[cfg(not(feature = "backend_vsock"))] + let backend_info = match p.uds_path { + Some(path) => BackendType::UnixDomainSocket(path.trim().to_string()), + _ => return Some(Err(CliError::ConfigParse)), + }; + + let config = VsockConfig::new( + p.guest_cid.unwrap_or(DEFAULT_GUEST_CID), + p.socket.trim().to_string(), + backend_info, + p.tx_buffer_size.unwrap_or(DEFAULT_TX_BUFFER_SIZE), + p.queue_size.unwrap_or(DEFAULT_QUEUE_SIZE), + p.groups.map_or(vec![DEFAULT_GROUP_NAME.to_string()], |g| { + g.trim().split('+').map(String::from).collect() + }), + ); + parsed.push(config); + } return Some(Ok(parsed)); } else { return Some(Err(CliError::ConfigParse)); @@ -221,10 +344,36 @@ impl TryFrom for Vec { _ => match cmd_args.vm { Some(v) => Ok(v), _ => cmd_args.param.map_or(Err(CliError::NoArgsProvided), |p| { + #[cfg(feature = "backend_vsock")] + let backend_info = match (p.uds_path, p.forward_cid) { + (Some(path), None) => { + BackendType::UnixDomainSocket(path.trim().to_string()) + } + (None, Some(cid)) => { + let listen_ports: Vec = match p.forward_listen { + None => Vec::new(), + Some(ports) => { + ports.split('+').map(|s| s.parse().unwrap()).collect() + } + }; + BackendType::Vsock(VsockProxyInfo { + forward_cid: cid, + listen_ports, + }) + } + _ => return Err(CliError::ConfigParse), + }; + + #[cfg(not(feature = "backend_vsock"))] + let backend_info = match p.uds_path { + Some(path) => BackendType::UnixDomainSocket(path.trim().to_string()), + _ => return 
Err(CliError::ConfigParse), + }; + Ok(vec![VsockConfig::new( p.guest_cid, p.socket.trim().to_string(), - BackendType::UnixDomainSocket(p.uds_path.trim().to_string()), + backend_info, p.tx_buffer_size, p.queue_size, p.groups.trim().split('+').map(String::from).collect(), @@ -336,7 +485,7 @@ mod tests { use tempfile::tempdir; impl VsockArgs { - fn from_args( + fn from_args_unix( guest_cid: u64, socket: &str, uds_path: &str, @@ -348,7 +497,13 @@ mod tests { param: Some(VsockParam { guest_cid, socket: socket.to_string(), - uds_path: uds_path.to_string(), + uds_path: Some(uds_path.to_string()), + + #[cfg(feature = "backend_vsock")] + forward_cid: None, + #[cfg(feature = "backend_vsock")] + forward_listen: None, + tx_buffer_size, queue_size, groups: groups.to_string(), @@ -357,6 +512,33 @@ mod tests { config: None, } } + + #[cfg(feature = "backend_vsock")] + fn from_args_vsock( + guest_cid: u64, + socket: &str, + forward_cid: u32, + forward_listen: &str, + tx_buffer_size: u32, + queue_size: usize, + groups: &str, + ) -> Self { + VsockArgs { + param: Some(VsockParam { + guest_cid, + socket: socket.to_string(), + uds_path: None, + forward_cid: Some(forward_cid), + forward_listen: Some(forward_listen.to_string()), + tx_buffer_size, + queue_size, + groups: groups.to_string(), + }), + vm: None, + config: None, + } + } + fn from_file(config: &str) -> Self { VsockArgs { param: None, @@ -367,12 +549,12 @@ mod tests { } #[test] - fn test_vsock_config_setup() { + fn test_vsock_config_setup_unix() { let test_dir = tempdir().expect("Could not create a temp test directory."); let socket_path = test_dir.path().join("vhost4.socket").display().to_string(); let uds_path = test_dir.path().join("vm4.vsock").display().to_string(); - let args = VsockArgs::from_args(3, &socket_path, &uds_path, 64 * 1024, 1024, "group1"); + let args = VsockArgs::from_args_unix(3, &socket_path, &uds_path, 64 * 1024, 1024, "group1"); let configs = Vec::::try_from(args); assert!(configs.is_ok()); @@ -394,8 
+576,40 @@ mod tests { test_dir.close().unwrap(); } + #[cfg(feature = "backend_vsock")] #[test] - fn test_vsock_config_setup_from_vm_args() { + fn test_vsock_config_setup_vsock() { + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let socket_path = test_dir.path().join("vhost4.socket").display().to_string(); + let args = + VsockArgs::from_args_vsock(3, &socket_path, 1, "1234+4321", 64 * 1024, 1024, "group1"); + + let configs = Vec::::try_from(args); + assert!(configs.is_ok()); + + let configs = configs.unwrap(); + assert_eq!(configs.len(), 1); + + let config = &configs[0]; + assert_eq!(config.get_guest_cid(), 3); + assert_eq!(config.get_socket_path(), socket_path); + assert_eq!( + config.get_backend_info(), + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![1234, 4321] + }) + ); + assert_eq!(config.get_tx_buffer_size(), 64 * 1024); + assert_eq!(config.get_queue_size(), 1024); + assert_eq!(config.get_groups(), vec!["group1".to_string()]); + + test_dir.close().unwrap(); + } + + #[test] + fn test_vsock_config_setup_from_vm_args_unix() { let test_dir = tempdir().expect("Could not create a temp test directory."); let socket_paths = [ @@ -479,8 +693,114 @@ mod tests { test_dir.close().unwrap(); } + #[cfg(feature = "backend_vsock")] #[test] - fn test_vsock_config_setup_from_file() { + fn test_vsock_config_setup_from_vm_args_vsock() { + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let socket_paths = [ + test_dir.path().join("vhost3.socket"), + test_dir.path().join("vhost4.socket"), + test_dir.path().join("vhost5.socket"), + test_dir.path().join("vhost6.socket"), + ]; + let uds_paths = [ + test_dir.path().join("vm3.vsock"), + test_dir.path().join("vm4.vsock"), + test_dir.path().join("vm5.vsock"), + ]; + let params = format!( + "--vm socket={vhost3_socket},uds_path={vm3_vsock} \ + --vm 
socket={vhost4_socket},uds-path={vm4_vsock},guest-cid=4,tx_buffer_size=65536,queue_size=1024,groups=group1 \ + --vm groups=group2+group3,guest-cid=5,socket={vhost5_socket},uds_path={vm5_vsock},tx-buffer-size=32768,queue_size=256 \ + --vm guest-cid=6,socket={vhost6_socket},forward-cid=1,forward-listen=1234+4321,queue-size=2048", + vhost3_socket = socket_paths[0].display(), + vhost4_socket = socket_paths[1].display(), + vhost5_socket = socket_paths[2].display(), + vhost6_socket = socket_paths[3].display(), + vm3_vsock = uds_paths[0].display(), + vm4_vsock = uds_paths[1].display(), + vm5_vsock = uds_paths[2].display(), + ); + + let mut params = params.split_whitespace().collect::>(); + params.insert(0, ""); // to make the test binary name agnostic + + let args = VsockArgs::parse_from(params); + + let configs = Vec::::try_from(args); + assert!(configs.is_ok()); + + let configs = configs.unwrap(); + assert_eq!(configs.len(), 4); + + let config = configs.first().unwrap(); + assert_eq!(config.get_guest_cid(), 3); + assert_eq!( + config.get_socket_path(), + socket_paths[0].display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::UnixDomainSocket(uds_paths[0].display().to_string()) + ); + assert_eq!(config.get_tx_buffer_size(), 65536); + assert_eq!(config.get_queue_size(), 1024); + assert_eq!(config.get_groups(), vec![DEFAULT_GROUP_NAME.to_string()]); + + let config = configs.get(1).unwrap(); + assert_eq!(config.get_guest_cid(), 4); + assert_eq!( + config.get_socket_path(), + socket_paths[1].display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::UnixDomainSocket(uds_paths[1].display().to_string()) + ); + assert_eq!(config.get_tx_buffer_size(), 65536); + assert_eq!(config.get_queue_size(), 1024); + assert_eq!(config.get_groups(), vec!["group1".to_string()]); + + let config = configs.get(2).unwrap(); + assert_eq!(config.get_guest_cid(), 5); + assert_eq!( + config.get_socket_path(), + 
socket_paths[2].display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::UnixDomainSocket(uds_paths[2].display().to_string()) + ); + assert_eq!(config.get_tx_buffer_size(), 32768); + assert_eq!(config.get_queue_size(), 256); + assert_eq!( + config.get_groups(), + vec!["group2".to_string(), "group3".to_string()] + ); + + let config = configs.get(3).unwrap(); + assert_eq!(config.get_guest_cid(), 6); + assert_eq!( + config.get_socket_path(), + socket_paths[3].display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![1234, 4321] + }) + ); + assert_eq!(config.get_tx_buffer_size(), 65536); + assert_eq!(config.get_queue_size(), 2048); + assert_eq!(config.get_groups(), vec![DEFAULT_GROUP_NAME.to_string()]); + + test_dir.close().unwrap(); + } + + #[test] + fn test_vsock_config_setup_from_file_unix() { let test_dir = tempdir().expect("Could not create a temp test directory."); let config_path = test_dir.path().join("config.yaml"); @@ -555,8 +875,142 @@ mod tests { test_dir.close().unwrap(); } + #[cfg(feature = "backend_vsock")] #[test] - fn test_vsock_server() { + fn test_vsock_config_setup_from_file_vsock() { + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let config_path = test_dir.path().join("config.yaml"); + let socket_path_unix = test_dir.path().join("vhost4.socket"); + let socket_path_vsock = test_dir.path().join("vhost5.socket"); + let uds_path = test_dir.path().join("vm4.vsock"); + + let mut yaml = File::create(&config_path).unwrap(); + yaml.write_all( + format!( + "vms: + - guest_cid: 4 + socket: {} + uds_path: {} + tx_buffer_size: 32768 + queue_size: 256 + groups: group1+group2 + - guest_cid: 5 + socket: {} + forward_cid: 1 + forward_listen: 1234+4321 + tx_buffer_size: 32768", + socket_path_unix.display(), + uds_path.display(), + socket_path_vsock.display(), + ) + .as_bytes(), + ) + .unwrap(); + let args = 
VsockArgs::from_file(&config_path.display().to_string()); + + let configs = Vec::::try_from(args).unwrap(); + assert_eq!(configs.len(), 2); + + let config = &configs[0]; + assert_eq!(config.get_guest_cid(), 4); + assert_eq!( + config.get_socket_path(), + socket_path_unix.display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::UnixDomainSocket(uds_path.display().to_string()) + ); + assert_eq!(config.get_tx_buffer_size(), 32768); + assert_eq!(config.get_queue_size(), 256); + assert_eq!( + config.get_groups(), + vec!["group1".to_string(), "group2".to_string()] + ); + + let config = &configs[1]; + assert_eq!(config.get_guest_cid(), 5); + assert_eq!( + config.get_socket_path(), + socket_path_vsock.display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![1234, 4321] + }) + ); + assert_eq!(config.get_tx_buffer_size(), 32768); + assert_eq!(config.get_queue_size(), 1024); + assert_eq!(config.get_groups(), vec![DEFAULT_GROUP_NAME.to_string()]); + + // Now test that optional parameters are correctly set to their default values. 
+ let mut yaml = File::create(&config_path).unwrap(); + yaml.write_all( + format!( + "vms: + - socket: {} + uds_path: {}", + socket_path_unix.display(), + uds_path.display(), + ) + .as_bytes(), + ) + .unwrap(); + let args = VsockArgs::from_file(&config_path.display().to_string()); + + let configs = Vec::::try_from(args).unwrap(); + assert_eq!(configs.len(), 1); + + let config = &configs[0]; + assert_eq!(config.get_guest_cid(), DEFAULT_GUEST_CID); + assert_eq!( + config.get_socket_path(), + socket_path_unix.display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::UnixDomainSocket(uds_path.display().to_string()) + ); + assert_eq!(config.get_tx_buffer_size(), DEFAULT_TX_BUFFER_SIZE); + assert_eq!(config.get_queue_size(), DEFAULT_QUEUE_SIZE); + assert_eq!(config.get_groups(), vec![DEFAULT_GROUP_NAME.to_string()]); + + std::fs::remove_file(&config_path).unwrap(); + test_dir.close().unwrap(); + } + + fn test_vsock_server(config: VsockConfig) { + let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); + + let backend = Arc::new(VhostUserVsockBackend::new(config, cid_map).unwrap()); + + let daemon = VhostUserDaemon::new( + String::from("vhost-device-vsock"), + backend.clone(), + GuestMemoryAtomic::new(GuestMemoryMmap::new()), + ) + .unwrap(); + + let mut epoll_handlers = daemon.get_epoll_handlers(); + + // VhostUserVsockBackend support a single thread that handles the TX and RX queues + assert_eq!(backend.threads.len(), 1); + + assert_eq!(epoll_handlers.len(), backend.threads.len()); + + for thread in backend.threads.iter() { + thread + .lock() + .unwrap() + .register_listeners(epoll_handlers.remove(0)); + } + } + + #[test] + fn test_vsock_server_unix() { const CID: u64 = 3; const CONN_TX_BUF_SIZE: u32 = 64 * 1024; const QUEUE_SIZE: usize = 1024; @@ -583,30 +1037,39 @@ mod tests { vec![DEFAULT_GROUP_NAME.to_string()], ); - let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); + test_vsock_server(config); - let backend = 
Arc::new(VhostUserVsockBackend::new(config, cid_map).unwrap()); + test_dir.close().unwrap(); + } - let daemon = VhostUserDaemon::new( - String::from("vhost-device-vsock"), - backend.clone(), - GuestMemoryAtomic::new(GuestMemoryMmap::new()), - ) - .unwrap(); + #[cfg(feature = "backend_vsock")] + #[test] + fn test_vsock_server_vsock() { + const CID: u64 = 3; + const CONN_TX_BUF_SIZE: u32 = 64 * 1024; + const QUEUE_SIZE: usize = 1024; - let mut epoll_handlers = daemon.get_epoll_handlers(); + let test_dir = tempdir().expect("Could not create a temp test directory."); - // VhostUserVsockBackend support a single thread that handles the TX and RX queues - assert_eq!(backend.threads.len(), 1); + let vhost_socket_path = test_dir + .path() + .join("test_vsock_server.socket") + .display() + .to_string(); - assert_eq!(epoll_handlers.len(), backend.threads.len()); + let config = VsockConfig::new( + CID, + vhost_socket_path, + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![9000], + }), + CONN_TX_BUF_SIZE, + QUEUE_SIZE, + vec![DEFAULT_GROUP_NAME.to_string()], + ); - for thread in backend.threads.iter() { - thread - .lock() - .unwrap() - .register_listeners(epoll_handlers.remove(0)); - } + test_vsock_server(config); test_dir.close().unwrap(); } @@ -674,8 +1137,9 @@ mod tests { let _ = test_dir.close(); } + #[cfg(not(feature = "backend_vsock"))] #[test] - fn test_main_structs() { + fn test_main_structs_unix() { let error = parse_vm_params("").unwrap_err(); assert_matches!(error, VmArgsParseError::BadArgument); assert_eq!(format!("{error:?}"), "BadArgument"); @@ -689,21 +1153,76 @@ mod tests { assert_matches!(error, CliError::NoArgsProvided); assert_eq!(format!("{error:?}"), "NoArgsProvided"); - let args = VsockArgs::from_args(0, "", "", 0, 0, ""); - assert_eq!(format!("{args:?}"), "VsockArgs { param: Some(VsockParam { guest_cid: 0, socket: \"\", uds_path: \"\", tx_buffer_size: 0, queue_size: 0, groups: \"\" }), vm: None, config: None }"); + let args = 
VsockArgs::from_args_unix(0, "", "", 0, 0, ""); + assert_eq!(format!("{args:?}"), "VsockArgs { param: Some(VsockParam { guest_cid: 0, socket: \"\", uds_path: Some(\"\"), tx_buffer_size: 0, queue_size: 0, groups: \"\" }), vm: None, config: None }"); let param = args.param.unwrap().clone(); - assert_eq!(format!("{param:?}"), "VsockParam { guest_cid: 0, socket: \"\", uds_path: \"\", tx_buffer_size: 0, queue_size: 0, groups: \"\" }"); + assert_eq!(format!("{param:?}"), "VsockParam { guest_cid: 0, socket: \"\", uds_path: Some(\"\"), tx_buffer_size: 0, queue_size: 0, groups: \"\" }"); let config = ConfigFileVsockParam { guest_cid: None, socket: String::new(), - uds_path: String::new(), + uds_path: Some(String::new()), tx_buffer_size: None, queue_size: None, groups: None, } .clone(); - assert_eq!(format!("{config:?}"), "ConfigFileVsockParam { guest_cid: None, socket: \"\", uds_path: \"\", tx_buffer_size: None, queue_size: None, groups: None }"); + assert_eq!(format!("{config:?}"), "ConfigFileVsockParam { guest_cid: None, socket: \"\", uds_path: Some(\"\"), tx_buffer_size: None, queue_size: None, groups: None }"); + } + + #[cfg(feature = "backend_vsock")] + #[test] + fn test_main_structs_vsock() { + let error = parse_vm_params("").unwrap_err(); + assert_matches!(error, VmArgsParseError::BadArgument); + assert_eq!(format!("{error:?}"), "BadArgument"); + + let args = VsockArgs { + param: None, + vm: None, + config: None, + }; + let error = Vec::::try_from(args).unwrap_err(); + assert_matches!(error, CliError::NoArgsProvided); + assert_eq!(format!("{error:?}"), "NoArgsProvided"); + + let args = VsockArgs::from_args_unix(0, "", "", 0, 0, ""); + assert_eq!(format!("{args:?}"), "VsockArgs { param: Some(VsockParam { guest_cid: 0, socket: \"\", uds_path: Some(\"\"), forward_cid: None, forward_listen: None, tx_buffer_size: 0, queue_size: 0, groups: \"\" }), vm: None, config: None }"); + + let param = args.param.unwrap().clone(); + assert_eq!(format!("{param:?}"), "VsockParam { 
guest_cid: 0, socket: \"\", uds_path: Some(\"\"), forward_cid: None, forward_listen: None, tx_buffer_size: 0, queue_size: 0, groups: \"\" }"); + + let args = VsockArgs::from_args_vsock(0, "", 1, "", 0, 0, ""); + assert_eq!(format!("{args:?}"), "VsockArgs { param: Some(VsockParam { guest_cid: 0, socket: \"\", uds_path: None, forward_cid: Some(1), forward_listen: Some(\"\"), tx_buffer_size: 0, queue_size: 0, groups: \"\" }), vm: None, config: None }"); + + let param = args.param.unwrap().clone(); + assert_eq!(format!("{param:?}"), "VsockParam { guest_cid: 0, socket: \"\", uds_path: None, forward_cid: Some(1), forward_listen: Some(\"\"), tx_buffer_size: 0, queue_size: 0, groups: \"\" }"); + + let config = ConfigFileVsockParam { + guest_cid: None, + socket: String::new(), + uds_path: Some(String::new()), + forward_cid: None, + forward_listen: None, + tx_buffer_size: None, + queue_size: None, + groups: None, + } + .clone(); + assert_eq!(format!("{config:?}"), "ConfigFileVsockParam { guest_cid: None, socket: \"\", uds_path: Some(\"\"), forward_cid: None, forward_listen: None, tx_buffer_size: None, queue_size: None, groups: None }"); + + let config = ConfigFileVsockParam { + guest_cid: None, + socket: String::new(), + uds_path: None, + forward_cid: Some(1), + forward_listen: Some(String::new()), + tx_buffer_size: None, + queue_size: None, + groups: None, + } + .clone(); + assert_eq!(format!("{config:?}"), "ConfigFileVsockParam { guest_cid: None, socket: \"\", uds_path: None, forward_cid: Some(1), forward_listen: Some(\"\"), tx_buffer_size: None, queue_size: None, groups: None }"); } } diff --git a/vhost-device-vsock/src/thread_backend.rs b/vhost-device-vsock/src/thread_backend.rs index 9661adb..55f5b00 100644 --- a/vhost-device-vsock/src/thread_backend.rs +++ b/vhost-device-vsock/src/thread_backend.rs @@ -17,6 +17,8 @@ use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE}; use vm_memory::{ bitmap::BitmapSlice, ReadVolatile, VolatileMemoryError, VolatileSlice, 
WriteVolatile, }; +#[cfg(feature = "backend_vsock")] +use vsock::VsockStream; use crate::{ rxops::*, @@ -55,6 +57,8 @@ impl RawVsockPacket { pub(crate) enum StreamType { Unix(UnixStream), + #[cfg(feature = "backend_vsock")] + Vsock(VsockStream), } impl StreamType { @@ -64,6 +68,11 @@ impl StreamType { let cloned_stream = stream.try_clone()?; Ok(StreamType::Unix(cloned_stream)) } + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => { + let cloned_stream = stream.try_clone()?; + Ok(StreamType::Vsock(cloned_stream)) + } } } } @@ -72,6 +81,8 @@ impl Read for StreamType { fn read(&mut self, buf: &mut [u8]) -> StdIOResult { match self { StreamType::Unix(stream) => stream.read(buf), + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => stream.read(buf), } } } @@ -80,12 +91,16 @@ impl Write for StreamType { fn write(&mut self, buf: &[u8]) -> StdIOResult { match self { StreamType::Unix(stream) => stream.write(buf), + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => stream.write(buf), } } fn flush(&mut self) -> StdIOResult<()> { match self { StreamType::Unix(stream) => stream.flush(), + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => stream.flush(), } } } @@ -94,6 +109,8 @@ impl AsRawFd for StreamType { fn as_raw_fd(&self) -> RawFd { match self { StreamType::Unix(stream) => stream.as_raw_fd(), + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => stream.as_raw_fd(), } } } @@ -105,6 +122,30 @@ impl ReadVolatile for StreamType { ) -> StdResult { match self { StreamType::Unix(stream) => stream.read_volatile(buf), + // Copied from vm_memory crate's ReadVolatile implementation for UnixStream + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => { + let fd = stream.as_raw_fd(); + let guard = buf.ptr_guard_mut(); + + let dst = guard.as_ptr().cast::(); + + // SAFETY: We got a valid file descriptor from `AsRawFd`. 
The memory pointed to by `dst` is + // valid for writes of length `buf.len()` by the invariants upheld by the constructor + // of `VolatileSlice`. + let bytes_read = unsafe { libc::read(fd, dst, buf.len()) }; + + if bytes_read < 0 { + // We don't know if a partial read might have happened, so mark everything as dirty + buf.bitmap().mark_dirty(0, buf.len()); + + Err(VolatileMemoryError::IOError(std::io::Error::last_os_error())) + } else { + let bytes_read = bytes_read.try_into().unwrap(); + buf.bitmap().mark_dirty(0, bytes_read); + Ok(bytes_read) + } + } } } } @@ -116,6 +157,25 @@ impl WriteVolatile for StreamType { ) -> StdResult { match self { StreamType::Unix(stream) => stream.write_volatile(buf), + // Copied from vm_memory crate's WriteVolatile implementation for UnixStream + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => { + let fd = stream.as_raw_fd(); + let guard = buf.ptr_guard(); + + let src = guard.as_ptr().cast::(); + + // SAFETY: We got a valid file descriptor from `AsRawFd`. The memory pointed to by `src` is + // valid for reads of length `buf.len()` by the invariants upheld by the constructor + // of `VolatileSlice`. + let bytes_written = unsafe { libc::write(fd, src, buf.len()) }; + + if bytes_written < 0 { + Err(VolatileMemoryError::IOError(std::io::Error::last_os_error())) + } else { + Ok(bytes_written.try_into().unwrap()) + } + } } } } @@ -137,7 +197,7 @@ pub(crate) struct VsockThreadBackend { pub conn_map: HashMap>, /// Queue of ConnMapKey objects indicating pending rx operations. pub backend_rxq: VecDeque, - /// Map of host-side unix streams indexed by raw file descriptors. + /// Map of host-side unix or vsock streams indexed by raw file descriptors. pub stream_map: HashMap, /// Host side socket info for listening to new connections from the host. 
backend_info: BackendType, @@ -262,36 +322,39 @@ impl VsockThreadBackend { return Ok(()); } - let dst_cid = pkt.dst_cid(); - if dst_cid != VSOCK_HOST_CID { - let cid_map = self.cid_map.read().unwrap(); - if cid_map.contains_key(&dst_cid) { - let (sibling_raw_pkts_queue, sibling_groups_set, sibling_event_fd) = - cid_map.get(&dst_cid).unwrap(); + #[allow(irrefutable_let_patterns)] + if let BackendType::UnixDomainSocket(_) = &self.backend_info { + let dst_cid = pkt.dst_cid(); + if dst_cid != VSOCK_HOST_CID { + let cid_map = self.cid_map.read().unwrap(); + if cid_map.contains_key(&dst_cid) { + let (sibling_raw_pkts_queue, sibling_groups_set, sibling_event_fd) = + cid_map.get(&dst_cid).unwrap(); - if self - .groups_set - .read() - .unwrap() - .is_disjoint(sibling_groups_set.read().unwrap().deref()) - { - info!( - "vsock: dropping packet for cid: {:?} due to group mismatch", - dst_cid - ); - return Ok(()); + if self + .groups_set + .read() + .unwrap() + .is_disjoint(sibling_groups_set.read().unwrap().deref()) + { + info!( + "vsock: dropping packet for cid: {:?} due to group mismatch", + dst_cid + ); + return Ok(()); + } + + sibling_raw_pkts_queue + .write() + .unwrap() + .push_back(RawVsockPacket::from_vsock_packet(pkt)?); + let _ = sibling_event_fd.write(1); + } else { + warn!("vsock: dropping packet for unknown cid: {:?}", dst_cid); } - sibling_raw_pkts_queue - .write() - .unwrap() - .push_back(RawVsockPacket::from_vsock_packet(pkt)?); - let _ = sibling_event_fd.write(1); - } else { - warn!("vsock: dropping packet for unknown cid: {:?}", dst_cid); + return Ok(()); } - - return Ok(()); } // TODO: Rst if packet has unsupported type @@ -371,9 +434,11 @@ impl VsockThreadBackend { /// Handle a new guest initiated connection, i.e from the peer, the guest driver. 
/// - /// Attempts to connect to a host side unix socket listening on a path - /// corresponding to the destination port as follows: + /// In case of proxying using unix domain socket, attempts to connect to a host side unix socket + /// listening on a path corresponding to the destination port as follows: /// - "{self.host_sock_path}_{local_port}"" + /// + /// In case of proxying using vsock, attempts to connect to the {forward_cid, local_port} fn handle_new_guest_conn(&mut self, pkt: &VsockPacket) { match &self.backend_info { BackendType::UnixDomainSocket(uds_path) => { @@ -385,6 +450,14 @@ impl VsockThreadBackend { .and_then(|stream| self.add_new_guest_conn(StreamType::Unix(stream), pkt)) .unwrap_or_else(|_| self.enq_rst()); } + #[cfg(feature = "backend_vsock")] + BackendType::Vsock(vsock_info) => { + VsockStream::connect_with_cid_port(vsock_info.forward_cid, pkt.dst_port()) + .and_then(|stream| stream.set_nonblocking(true).map(|_| stream)) + .map_err(Error::VsockConnect) + .and_then(|stream| self.add_new_guest_conn(StreamType::Vsock(stream), pkt)) + .unwrap_or_else(|_| self.enq_rst()); + } } } @@ -397,6 +470,8 @@ impl VsockThreadBackend { let conn = VsockConnection::new_peer_init( stream.try_clone().map_err(match stream { StreamType::Unix(_) => Error::UnixConnect, + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(_) => Error::VsockConnect, })?, pkt.dst_cid(), pkt.dst_port(), @@ -436,28 +511,23 @@ impl VsockThreadBackend { #[cfg(test)] mod tests { use super::*; + #[cfg(feature = "backend_vsock")] + use crate::vhu_vsock::VsockProxyInfo; use crate::vhu_vsock::{BackendType, VhostUserVsockBackend, VsockConfig, VSOCK_OP_RW}; use std::os::unix::net::UnixListener; use tempfile::tempdir; use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE}; + #[cfg(feature = "backend_vsock")] + use vsock::{VsockListener, VMADDR_CID_ANY}; const DATA_LEN: usize = 16; const CONN_TX_BUF_SIZE: u32 = 64 * 1024; const QUEUE_SIZE: usize = 1024; const GROUP_NAME: &str = "default"; + 
const VSOCK_PEER_PORT: u32 = 1234; - #[test] - fn test_vsock_thread_backend() { + fn test_vsock_thread_backend(backend_info: BackendType) { const CID: u64 = 3; - const VSOCK_PEER_PORT: u32 = 1234; - - let test_dir = tempdir().expect("Could not create a temp test directory."); - - let vsock_socket_path = test_dir.path().join("test_vsock_thread_backend.vsock"); - let vsock_peer_path = test_dir.path().join("test_vsock_thread_backend.vsock_1234"); - - let _listener = UnixListener::bind(&vsock_peer_path).unwrap(); - let backend_info = BackendType::UnixDomainSocket(vsock_socket_path.display().to_string()); let epoll_fd = epoll::create(false).unwrap(); @@ -510,6 +580,19 @@ mod tests { // TODO: it is a nop for now vtp.enq_rst(); + } + + #[test] + fn test_vsock_thread_backend_unix() { + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let vsock_socket_path = test_dir.path().join("test_vsock_thread_backend.vsock"); + let vsock_peer_path = test_dir.path().join("test_vsock_thread_backend.vsock_1234"); + + let _listener = UnixListener::bind(&vsock_peer_path).unwrap(); + let backend_info = BackendType::UnixDomainSocket(vsock_socket_path.display().to_string()); + + test_vsock_thread_backend(backend_info); // cleanup let _ = std::fs::remove_file(&vsock_peer_path); @@ -518,6 +601,18 @@ mod tests { test_dir.close().unwrap(); } + #[cfg(feature = "backend_vsock")] + #[test] + fn test_vsock_thread_backend_vsock() { + let _listener = VsockListener::bind_with_cid_port(VMADDR_CID_ANY, VSOCK_PEER_PORT).unwrap(); + let backend_info = BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![], + }); + + test_vsock_thread_backend(backend_info); + } + #[test] fn test_vsock_thread_backend_sibling_vms() { const CID: u64 = 3; diff --git a/vhost-device-vsock/src/vhu_vsock.rs b/vhost-device-vsock/src/vhu_vsock.rs index fa07e5d..e2efe0e 100644 --- a/vhost-device-vsock/src/vhu_vsock.rs +++ b/vhost-device-vsock/src/vhu_vsock.rs @@ -118,8 +118,17 @@ 
pub(crate) enum Error { PktBufMissing, #[error("Failed to connect to unix socket")] UnixConnect(std::io::Error), - #[error("Unable to write to unix stream")] - UnixWrite, + #[cfg(feature = "backend_vsock")] + #[error("Failed to accept new local vsock socket connection")] + VsockAccept(std::io::Error), + #[cfg(feature = "backend_vsock")] + #[error("Failed to connect to vsock socket")] + VsockConnect(std::io::Error), + #[cfg(feature = "backend_vsock")] + #[error("Failed to bind a vsock stream")] + VsockBind(std::io::Error), + #[error("Unable to write to stream")] + StreamWrite, #[error("Unable to push data to local tx buffer")] LocalTxBufFull, #[error("Unable to flush data from local tx buffer")] @@ -142,10 +151,20 @@ impl std::convert::From for std::io::Error { } } +#[cfg(feature = "backend_vsock")] +#[derive(Debug, PartialEq, Clone)] +pub(crate) struct VsockProxyInfo { + pub forward_cid: u32, + pub listen_ports: Vec, +} + #[derive(Debug, PartialEq, Clone)] pub(crate) enum BackendType { /// unix domain socket path UnixDomainSocket(String), + /// the vsock CID and ports + #[cfg(feature = "backend_vsock")] + Vsock(VsockProxyInfo), } #[derive(Debug, Clone)] @@ -390,34 +409,7 @@ mod tests { const CONN_TX_BUF_SIZE: u32 = 64 * 1024; const QUEUE_SIZE: usize = 1024; - #[test] - fn test_vsock_backend() { - const CID: u64 = 3; - - let groups_list: Vec = vec![String::from("default")]; - - let test_dir = tempdir().expect("Could not create a temp test directory."); - - let vhost_socket_path = test_dir - .path() - .join("test_vsock_backend.socket") - .display() - .to_string(); - let vsock_socket_path = test_dir - .path() - .join("test_vsock_backend.vsock") - .display() - .to_string(); - - let config = VsockConfig::new( - CID, - vhost_socket_path.to_string(), - BackendType::UnixDomainSocket(vsock_socket_path.to_string()), - CONN_TX_BUF_SIZE, - QUEUE_SIZE, - groups_list, - ); - + fn test_vsock_backend(config: VsockConfig, expected_cid: u64) { let cid_map: Arc> = 
Arc::new(RwLock::new(HashMap::new())); let backend = VhostUserVsockBackend::new(config, cid_map); @@ -452,7 +444,7 @@ mod tests { let config = backend.get_config(0, 8); assert_eq!(config.len(), 8); let cid = u64::from_le_bytes(config.try_into().unwrap()); - assert_eq!(cid, CID); + assert_eq!(cid, expected_cid); let exit = backend.exit_event(0); assert!(exit.is_some()); @@ -469,11 +461,74 @@ mod tests { let ret = backend.handle_event(BACKEND_EVENT, EventSet::IN, &vrings, 0); assert!(ret.is_ok()); + } + + #[test] + fn test_vsock_backend_unix() { + const CID: u64 = 3; + + let groups_list: Vec = vec![String::from("default")]; + + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let vhost_socket_path = test_dir + .path() + .join("test_vsock_backend_unix.socket") + .display() + .to_string(); + let vsock_socket_path = test_dir + .path() + .join("test_vsock_backend.vsock") + .display() + .to_string(); + + let config = VsockConfig::new( + CID, + vhost_socket_path.to_string(), + BackendType::UnixDomainSocket(vsock_socket_path.to_string()), + CONN_TX_BUF_SIZE, + QUEUE_SIZE, + groups_list, + ); + + test_vsock_backend(config, CID); // cleanup let _ = std::fs::remove_file(vhost_socket_path); let _ = std::fs::remove_file(vsock_socket_path); + test_dir.close().unwrap(); + } + #[cfg(feature = "backend_vsock")] + #[test] + fn test_vsock_backend_vsock() { + const CID: u64 = 3; + + let groups_list: Vec = vec![String::from("default")]; + + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let vhost_socket_path = test_dir + .path() + .join("test_vsock_backend.socket") + .display() + .to_string(); + let config = VsockConfig::new( + CID, + vhost_socket_path.to_string(), + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![9001, 9002], + }), + CONN_TX_BUF_SIZE, + QUEUE_SIZE, + groups_list, + ); + + test_vsock_backend(config, CID); + + // cleanup + let _ = std::fs::remove_file(vhost_socket_path); 
test_dir.close().unwrap(); } @@ -558,7 +613,7 @@ mod tests { #[test] fn test_vhu_vsock_structs() { - let config = VsockConfig::new( + let unix_config = VsockConfig::new( 0, String::new(), BackendType::UnixDomainSocket(String::new()), @@ -566,8 +621,22 @@ mod tests { 0, vec![String::new()], ); + assert_eq!(format!("{unix_config:?}"), "VsockConfig { guest_cid: 0, socket: \"\", backend_info: UnixDomainSocket(\"\"), tx_buffer_size: 0, queue_size: 0, groups: [\"\"] }"); - assert_eq!(format!("{config:?}"), "VsockConfig { guest_cid: 0, socket: \"\", backend_info: UnixDomainSocket(\"\"), tx_buffer_size: 0, queue_size: 0, groups: [\"\"] }"); + #[cfg(feature = "backend_vsock")] + let vsock_config = VsockConfig::new( + 0, + String::new(), + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![9001, 9002], + }), + 0, + 0, + vec![String::new()], + ); + #[cfg(feature = "backend_vsock")] + assert_eq!(format!("{vsock_config:?}"), "VsockConfig { guest_cid: 0, socket: \"\", backend_info: Vsock(VsockProxyInfo { forward_cid: 1, listen_ports: [9001, 9002] }), tx_buffer_size: 0, queue_size: 0, groups: [\"\"] }"); let conn_map = ConnMapKey::new(0, 0); assert_eq!( diff --git a/vhost-device-vsock/src/vhu_vsock_thread.rs b/vhost-device-vsock/src/vhu_vsock_thread.rs index 8acf724..4ce45d0 100644 --- a/vhost-device-vsock/src/vhu_vsock_thread.rs +++ b/vhost-device-vsock/src/vhu_vsock_thread.rs @@ -18,6 +18,8 @@ use std::{ thread, }; +#[cfg(feature = "backend_vsock")] +use log::error; use log::warn; use vhost_user_backend::{VringEpollHandler, VringRwLock, VringT}; use virtio_queue::QueueOwnedT; @@ -27,6 +29,8 @@ use vmm_sys_util::{ epoll::EventSet, eventfd::{EventFd, EFD_NONBLOCK}, }; +#[cfg(feature = "backend_vsock")] +use vsock::{VsockListener, VMADDR_CID_ANY}; use crate::{ rxops::*, @@ -55,6 +59,8 @@ struct EventData { enum ListenerType { Unix(UnixListener), + #[cfg(feature = "backend_vsock")] + Vsock(VsockListener), } pub(crate) struct VhostUserVsockThread { @@ -105,6 
+111,16 @@ impl VhostUserVsockThread { let host_sock = host_listener.as_raw_fd(); host_listeners_map.insert(host_sock, ListenerType::Unix(host_listener)); } + #[cfg(feature = "backend_vsock")] + BackendType::Vsock(vsock_info) => { + for p in &vsock_info.listen_ports { + let host_listener = VsockListener::bind_with_cid_port(VMADDR_CID_ANY, *p) + .and_then(|sock| sock.set_nonblocking(true).map(|_| sock)) + .map_err(Error::VsockBind)?; + let host_sock = host_listener.as_raw_fd(); + host_listeners_map.insert(host_sock, ListenerType::Vsock(host_listener)); + } + } } let epoll_fd = epoll::create(true).map_err(Error::EpollFdCreate)?; @@ -150,6 +166,7 @@ impl VhostUserVsockThread { }; Self::vring_handle_event(event_data); }); + let thread = VhostUserVsockThread { mem: None, event_idx: false, @@ -312,6 +329,51 @@ impl VhostUserVsockThread { }); } } + #[cfg(feature = "backend_vsock")] + ListenerType::Vsock(vsock_listener) => { + let conn = vsock_listener.accept().map_err(Error::VsockAccept); + if self.mem.is_some() { + match conn { + Ok((stream, addr)) => { + if let Err(err) = stream.set_nonblocking(true) { + warn!("Failed to set stream to non-blocking: {:?}", err); + return; + } + + let peer_port = match vsock_listener.local_addr() { + Ok(listener_addr) => listener_addr.port(), + Err(err) => { + warn!("Failed to get peer address: {:?}", err); + return; + } + }; + + let local_port = addr.port(); + let stream_raw_fd = stream.as_raw_fd(); + self.add_new_connection_from_host( + stream_raw_fd, + StreamType::Vsock(stream), + local_port, + peer_port, + ); + if let Err(err) = Self::epoll_register( + self.get_epoll_fd(), + stream_raw_fd, + epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT, + ) { + warn!("Failed to register with epoll: {:?}", err); + } + } + Err(err) => { + warn!("Unable to accept new local connection: {:?}", err); + } + } + } else { + conn.map(drop).unwrap_or_else(|err| { + warn!("Error closing an incoming connection: {:?}", err); + }); + } + } } } else { // Check if 
the stream represented by fd has already established a @@ -333,6 +395,10 @@ impl VhostUserVsockThread { }; match stream { + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(_) => { + error!("Stream type should not be of type vsock"); + } StreamType::Unix(ref mut unix_stream) => { // Local peer is sending a "connect PORT\n" command let peer_port = match Self::read_local_stream_port(unix_stream) { @@ -726,6 +792,10 @@ impl Drop for VhostUserVsockThread { BackendType::UnixDomainSocket(uds_path) => { let _ = std::fs::remove_file(uds_path); } + #[cfg(feature = "backend_vsock")] + BackendType::Vsock(_) => { + // Nothing to do + } } self.thread_backend .cid_map @@ -737,12 +807,16 @@ impl Drop for VhostUserVsockThread { #[cfg(test)] mod tests { use super::*; + #[cfg(feature = "backend_vsock")] + use crate::vhu_vsock::VsockProxyInfo; use std::collections::HashMap; use std::io::Read; use std::io::Write; use tempfile::tempdir; use vm_memory::GuestAddress; use vmm_sys_util::eventfd::EventFd; + #[cfg(feature = "backend_vsock")] + use vsock::VsockStream; const CONN_TX_BUF_SIZE: u32 = 64 * 1024; @@ -752,28 +826,12 @@ mod tests { } } - #[test] - fn test_vsock_thread() { + fn test_vsock_thread(backend_info: BackendType) { let groups: Vec = vec![String::from("default")]; let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); - let test_dir = tempdir().expect("Could not create a temp test directory."); - let backend_info = BackendType::UnixDomainSocket( - test_dir - .path() - .join("test_vsock_thread.vsock") - .display() - .to_string(), - ); - - let t = VhostUserVsockThread::new( - backend_info, - 3, - CONN_TX_BUF_SIZE, - groups, - cid_map, - ); + let t = VhostUserVsockThread::new(backend_info, 3, CONN_TX_BUF_SIZE, groups, cid_map); assert!(t.is_ok()); let mut t = t.unwrap(); @@ -838,10 +896,32 @@ mod tests { dummy_fd.write(1).unwrap(); t.process_backend_evt(EventSet::empty()); + } + #[test] + fn test_vsock_thread_unix() { + let test_dir = tempdir().expect("Could not create 
a temp test directory."); + let backend_info = BackendType::UnixDomainSocket( + test_dir + .path() + .join("test_vsock_thread.vsock") + .display() + .to_string(), + ); + test_vsock_thread(backend_info); test_dir.close().unwrap(); } + #[cfg(feature = "backend_vsock")] + #[test] + fn test_vsock_thread_vsock() { + let backend_info = BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![], + }); + test_vsock_thread(backend_info); + } + #[test] fn test_vsock_thread_failures() { let groups: Vec = vec![String::from("default")]; @@ -910,8 +990,9 @@ mod tests { test_dir.close().unwrap(); } + #[test] - fn test_vsock_thread_unix() { + fn test_vsock_thread_unix_backend() { let groups: Vec = vec![String::from("default")]; let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); @@ -956,4 +1037,47 @@ mod tests { test_dir.close().unwrap(); } + + #[cfg(feature = "backend_vsock")] + #[test] + fn test_vsock_thread_vsock_backend() { + let groups: Vec = vec![String::from("default")]; + let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); + + let t = VhostUserVsockThread::new( + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![9003, 9004], + }), + 3, + CONN_TX_BUF_SIZE, + groups, + cid_map, + ); + + let mut t = t.unwrap(); + + let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(), + ); + + t.mem = Some(mem.clone()); + + let mut vs1 = VsockStream::connect_with_cid_port(1, 9003).unwrap(); + let mut vs2 = VsockStream::connect_with_cid_port(1, 9004).unwrap(); + t.process_backend_evt(EventSet::empty()); + + vs1.write_all(b"some data").unwrap(); + vs2.write_all(b"some data").unwrap(); + t.process_backend_evt(EventSet::empty()); + + let mut buf = vec![0u8; 16]; + vs1.set_nonblocking(true).unwrap(); + vs2.set_nonblocking(true).unwrap(); + // There isn't any peer responding, so we don't expect data + vs1.read(&mut buf).unwrap_err(); + vs2.read(&mut buf).unwrap_err(); + + 
t.process_backend_evt(EventSet::empty()); + } } diff --git a/vhost-device-vsock/src/vsock_conn.rs b/vhost-device-vsock/src/vsock_conn.rs index 3762e7a..beb3a01 100644 --- a/vhost-device-vsock/src/vsock_conn.rs +++ b/vhost-device-vsock/src/vsock_conn.rs @@ -301,7 +301,7 @@ impl VsockCon /// /// Returns: /// - Ok(cnt) where cnt is the number of bytes written to the stream - /// - Err(Error::UnixWrite) if there was an error writing to the stream + /// - Err(Error::StreamWrite) if there was an error writing to the stream fn send_bytes(&mut self, buf: &VolatileSlice) -> Result<()> { if !self.tx_buf.is_empty() { // Data is already present in the buffer and the backend @@ -318,12 +318,12 @@ impl VsockCon 0 } else { dbg!("send_bytes error: {:?}", e); - return Err(Error::UnixWrite); + return Err(Error::StreamWrite); } } Err(e) => { dbg!("send_bytes error: {:?}", e); - return Err(Error::UnixWrite); + return Err(Error::StreamWrite); } };