diff --git a/Cargo.lock b/Cargo.lock index 08b36e2..49fdd31 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -739,6 +739,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -755,7 +764,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "memoffset", + "memoffset 0.6.5", ] [[package]] @@ -779,6 +788,7 @@ dependencies = [ "cfg-if", "cfg_aliases", "libc", + "memoffset 0.9.1", ] [[package]] @@ -1527,6 +1537,7 @@ dependencies = [ "env_logger", "epoll", "figment", + "libc", "log", "serde", "tempfile", @@ -1538,6 +1549,7 @@ dependencies = [ "virtio-vsock", "vm-memory", "vmm-sys-util", + "vsock", ] [[package]] @@ -1608,6 +1620,16 @@ dependencies = [ "libc", ] +[[package]] +name = "vsock" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e8b4d00e672f147fc86a09738fadb1445bd1c0a40542378dfb82909deeee688" +dependencies = [ + "libc", + "nix 0.29.0", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/vhost-device-vsock/CHANGELOG.md b/vhost-device-vsock/CHANGELOG.md index 235c199..a3b35b6 100644 --- a/vhost-device-vsock/CHANGELOG.md +++ b/vhost-device-vsock/CHANGELOG.md @@ -3,6 +3,7 @@ ### Added - [#698](https://github.com/rust-vmm/vhost-device/pull/698) vsock: add mdoc page +- [#706](https://github.com/rust-vmm/vhost-device/pull/706) Support proxying using vsock ### Changed diff --git a/vhost-device-vsock/Cargo.toml b/vhost-device-vsock/Cargo.toml index a356347..345d249 100644 --- a/vhost-device-vsock/Cargo.toml +++ b/vhost-device-vsock/Cargo.toml @@ -10,6 +10,8 @@ license = "Apache-2.0 OR BSD-3-Clause" edition = "2021" [features] +default = ["backend_vsock"] +backend_vsock = ["vsock", "libc"] xen = ["vm-memory/xen", "vhost/xen", "vhost-user-backend/xen"] [dependencies] @@ -27,6 +29,8 @@ virtio-vsock = "0.6" vm-memory = "0.14.1" vmm-sys-util = "0.12" figment = { version = "0.10.19", features = ["yaml"] } +vsock = { version = "0.5.0", optional = true } +libc = { version = "0.2.39", optional = true } serde = { version = "1", features = ["derive"] } [dev-dependencies] diff --git a/vhost-device-vsock/README.md b/vhost-device-vsock/README.md index c58f3d6..8c8b915 100644 --- a/vhost-device-vsock/README.md +++ b/vhost-device-vsock/README.md @@ -6,7 +6,10 @@ The crate introduces a vhost-device-vsock device that enables communication betw application running in the guest i.e inside a VM and an application running on the host i.e outside the VM. The application running in the guest communicates over VM sockets i.e over AF_VSOCK sockets. The application running on the host connects to a -unix socket on the host i.e communicates over AF_UNIX sockets. The main components of +unix socket on the host i.e communicates over AF_UNIX sockets when using the unix domain +socket backend through the uds-path option or the application in the host listens or +connects to vsock on the host i.e communicates over AF_VSOCK sockets when using the +vsock backend through the forward-cid, forward-listen options. The main components of the crate are split into various files as described below: - [packet.rs](src/packet.rs) @@ -38,7 +41,7 @@ the crate are split into various files as described below: ## Usage -Run the vhost-device-vsock device: +Run the vhost-device-vsock device with unix domain socket backend: ``` vhost-device-vsock --guest-cid= \ --socket= \ @@ -52,11 +55,26 @@ or vhost-device-vsock --vm guest_cid=,socket=,uds-path=[,tx-buffer-size=host packets)>][,queue-size=][,groups=] ``` +Run the vhost-device-vsock device with vsock backend: +``` +vhost-device-vsock --guest-cid= \ + --socket= \ + --forward-cid= \ + [--forward-listen= \ + [--tx-buffer-size=host packets)>] \ + [--queue-size=] \ +``` +or +``` +vhost-device-vsock --vm guest_cid=,socket=,forward-cid=[,forward-listen=][,tx-buffer-size=host packets)>][,queue-size=][,groups=] +``` + Specify the `--vm` argument multiple times to specify multiple devices like this: ``` vhost-device-vsock \ --vm guest-cid=3,socket=/tmp/vhost3.socket,uds-path=/tmp/vm3.vsock,groups=group1+groupA \ ---vm guest-cid=4,socket=/tmp/vhost4.socket,uds-path=/tmp/vm4.vsock,tx-buffer-size=32768,queue-size=256 +--vm guest-cid=4,socket=/tmp/vhost4.socket,uds-path=/tmp/vm4.vsock,tx-buffer-size=32768,queue-size=256 \ +--vm guest-cid=5,socket=/tmp/vhost5.socket,forward-cid=1,forward-listen=9001+9002,tx-buffer-size=32768,queue-size=1024 ``` Or use a configuration file: @@ -79,6 +97,12 @@ vms: tx_buffer_size: 32768 queue_size: 256 groups: group2+groupB + - guest_cid: 5 + socket: /tmp/vhost5.socket + forward-cid: 1 + forward-listen: 9001+9002 + tx_buffer_size: 32768 + queue_size: 1024 ``` Run VMM (e.g. QEMU): @@ -185,6 +209,38 @@ guest_cid3$ nc-vsock -l 1234 guest_cid4$ nc-vsock 3 1234 ``` +### Using the vsock backend + +The vsock backend is available under the `backend_vsock` feature (enabled by default). If you want to test a guest VM that +has built-in applications which communicate with another VM over AF_VSOCK, you can forward the connections from the guest +to the host machine instead of running a separate VM for easier testing using the forward-cid option. In such a case, you +would run the corresponding applications that listen for or connect with applications in the guest VM using AF_VSOCK in the +host instead of running the separate VM. For forwarding AF_VSOCK connections from the host, you can use the forward-listen +option. + +For example, if the guest VM that you want to test has an application that connects to (CID 3, port 9000) upon boot and applications +that listen on port 9001 and 9002 for connections, first run vhost-device-vsock: + +```sh +shell1$ vhost-device-vsock --vm guest-cid=4,forward-cid=1,forward-listen=9001+9002,socket=/tmp/vhost4.socket +``` + +Now run the application listening for connections to port 9000 on the host machine and then run the guest VM: + +```sh +shell2$ qemu-system-x86_64 \ + -drive file=vm1.qcow2,format=qcow2,if=virtio -smp 2 \ + -object memory-backend-memfd,id=mem0,size=512M \ + -machine q35,accel=kvm,memory-backend=mem0 \ + -chardev socket,id=char0,reconnect=0,path=/tmp/vhost4.socket \ + -device vhost-user-vsock-pci,chardev=char0 +``` + +After the guest VM boots, the application inside the guest connecting to (CID 3, port 9000) should successfully connect to the +application running on the host. Assuming the applications listening on port 9001 and 9002 are running in the guest VM, you can +now run the applications that connect to port 9001 and 9002 (you need to modify the CID they connect to be the host CID i.e. 1) +on the host machine. + ## License This project is licensed under either of diff --git a/vhost-device-vsock/src/main.rs b/vhost-device-vsock/src/main.rs index 0536841..b8ed4bc 100644 --- a/vhost-device-vsock/src/main.rs +++ b/vhost-device-vsock/src/main.rs @@ -17,6 +17,8 @@ use std::{ thread, }; +#[cfg(feature = "backend_vsock")] +use crate::vhu_vsock::VsockProxyInfo; use crate::vhu_vsock::{BackendType, CidMap, VhostUserVsockBackend, VsockConfig}; use clap::{Args, Parser}; use figment::{ @@ -82,8 +84,40 @@ struct VsockParam { socket: String, /// Unix socket to which a host-side application connects to. + #[cfg(not(feature = "backend_vsock"))] #[arg(long, conflicts_with = "config", conflicts_with = "vm")] - uds_path: String, + uds_path: Option, + + /// Unix socket to which a host-side application connects to. + #[cfg(feature = "backend_vsock")] + #[arg( + long, + conflicts_with = "forward_cid", + conflicts_with = "forward_listen", + conflicts_with = "config", + conflicts_with = "vm" + )] + uds_path: Option, + + /// The vsock CID to forward connections from guest + #[cfg(feature = "backend_vsock")] + #[clap( + long, + conflicts_with = "uds_path", + conflicts_with = "config", + conflicts_with = "vm" + )] + forward_cid: Option, + + /// The vsock ports to forward connections from host + #[cfg(feature = "backend_vsock")] + #[clap( + long, + conflicts_with = "uds_path", + conflicts_with = "config", + conflicts_with = "vm" + )] + forward_listen: Option, /// The size of the buffer used for the TX virtqueue #[clap(long, default_value_t = DEFAULT_TX_BUFFER_SIZE, conflicts_with = "config", conflicts_with = "vm")] @@ -109,7 +143,11 @@ struct VsockParam { struct ConfigFileVsockParam { guest_cid: Option, socket: String, - uds_path: String, + uds_path: Option, + #[cfg(feature = "backend_vsock")] + forward_cid: Option, + #[cfg(feature = "backend_vsock")] + forward_listen: Option, tx_buffer_size: Option, queue_size: Option, groups: Option, @@ -126,6 +164,19 @@ struct VsockArgs { /// Example: /// --vm guest-cid=3,socket=/tmp/vhost3.socket,uds-path=/tmp/vm3.vsock,tx-buffer-size=65536,queue-size=1024,groups=group1+group2 /// Multiple instances of this argument can be provided to configure devices for multiple guests. + #[cfg(not(feature = "backend_vsock"))] + #[arg(long, conflicts_with = "config", verbatim_doc_comment, value_parser = parse_vm_params)] + vm: Option>, + + /// Device parameters corresponding to a VM in the form of comma separated key=value pairs. + /// The allowed keys are: guest_cid, socket, uds_path, forward_cid, forward_listen, tx_buffer_size, queue_size and group. + /// uds_path and (forward_cid, forward_listen) are mutually exclusive. Use uds_path when you want unix domain socket + /// backend, otherwise forward_cid, forward_listen for vsock backend. + /// Example: + /// --vm guest-cid=3,socket=/tmp/vhost3.socket,uds-path=/tmp/vm3.vsock,tx-buffer-size=65536,queue-size=1024,groups=group1+group2 + /// --vm guest-cid=3,socket=/tmp/vhost3.socket,forward-cid=1,forward-listen=9001,queue-size=1024 + /// Multiple instances of this argument can be provided to configure devices for multiple guests. + #[cfg(feature = "backend_vsock")] #[arg(long, conflicts_with = "config", verbatim_doc_comment, value_parser = parse_vm_params)] vm: Option>, @@ -142,6 +193,11 @@ fn parse_vm_params(s: &str) -> Result { let mut queue_size = None; let mut groups = None; + #[cfg(feature = "backend_vsock")] + let mut forward_cid = None; + #[cfg(feature = "backend_vsock")] + let mut forward_listen: Option> = None; + for arg in s.trim().split(',') { let mut parts = arg.split('='); let key = parts.next().ok_or(VmArgsParseError::BadArgument)?; @@ -153,6 +209,16 @@ fn parse_vm_params(s: &str) -> Result { } "socket" => socket = Some(val.to_string()), "uds_path" | "uds-path" => uds_path = Some(val.to_string()), + + #[cfg(feature = "backend_vsock")] + "forward_cid" | "forward-cid" => { + forward_cid = Some(val.parse().map_err(VmArgsParseError::ParseInteger)?) + } + #[cfg(feature = "backend_vsock")] + "forward_listen" | "forward-listen" => { + forward_listen = Some(val.split('+').map(|s| s.parse().unwrap()).collect()) + } + "tx_buffer_size" | "tx-buffer-size" => { tx_buffer_size = Some(val.parse().map_err(VmArgsParseError::ParseInteger)?) } @@ -164,10 +230,42 @@ fn parse_vm_params(s: &str) -> Result { } } + #[cfg(feature = "backend_vsock")] + let backend_info = match (uds_path, forward_cid) { + (Some(path), None) => BackendType::UnixDomainSocket(path), + (None, Some(cid)) => { + let listen_ports: Vec = forward_listen.unwrap_or_default(); + BackendType::Vsock(VsockProxyInfo { + forward_cid: cid, + listen_ports, + }) + } + (None, None) => { + return Err(VmArgsParseError::RequiredKeyNotFound( + "uds-path or forward-cid".to_string(), + )) + } + _ => { + return Err(VmArgsParseError::RequiredKeyNotFound( + "Only one of uds-path or forward-cid can be provided".to_string(), + )) + } + }; + + #[cfg(not(feature = "backend_vsock"))] + let backend_info = match uds_path { + Some(path) => BackendType::UnixDomainSocket(path), + _ => { + return Err(VmArgsParseError::RequiredKeyNotFound( + "uds-path".to_string(), + )) + } + }; + Ok(VsockConfig::new( guest_cid.unwrap_or(DEFAULT_GUEST_CID), socket.ok_or_else(|| VmArgsParseError::RequiredKeyNotFound("socket".to_string()))?, - BackendType::UnixDomainSocket(uds_path.ok_or_else(|| VmArgsParseError::RequiredKeyNotFound("uds-path".to_string()))?), + backend_info.clone(), tx_buffer_size.unwrap_or(DEFAULT_TX_BUFFER_SIZE), queue_size.unwrap_or(DEFAULT_QUEUE_SIZE), groups.unwrap_or(vec![DEFAULT_GROUP_NAME.to_string()]), @@ -184,21 +282,46 @@ impl VsockArgs { { let vms_param = config_map.get_mut("vms").unwrap(); if !vms_param.is_empty() { - let parsed: Vec = vms_param - .drain(..) - .map(|p| { - VsockConfig::new( - p.guest_cid.unwrap_or(DEFAULT_GUEST_CID), - p.socket.trim().to_string(), - BackendType::UnixDomainSocket(p.uds_path.trim().to_string()), - p.tx_buffer_size.unwrap_or(DEFAULT_TX_BUFFER_SIZE), - p.queue_size.unwrap_or(DEFAULT_QUEUE_SIZE), - p.groups.map_or(vec![DEFAULT_GROUP_NAME.to_string()], |g| { - g.trim().split('+').map(String::from).collect() - }), - ) - }) - .collect(); + let mut parsed = Vec::new(); + for p in vms_param.drain(..) { + #[cfg(feature = "backend_vsock")] + let backend_info = match (p.uds_path, p.forward_cid) { + (Some(path), None) => { + BackendType::UnixDomainSocket(path.trim().to_string()) + } + (None, Some(cid)) => { + let listen_ports: Vec = match p.forward_listen { + None => Vec::new(), + Some(ports) => { + ports.split('+').map(|s| s.parse().unwrap()).collect() + } + }; + BackendType::Vsock(VsockProxyInfo { + forward_cid: cid, + listen_ports, + }) + } + _ => return Some(Err(CliError::ConfigParse)), + }; + + #[cfg(not(feature = "backend_vsock"))] + let backend_info = match p.uds_path { + Some(path) => BackendType::UnixDomainSocket(path.trim().to_string()), + _ => return Some(Err(CliError::ConfigParse)), + }; + + let config = VsockConfig::new( + p.guest_cid.unwrap_or(DEFAULT_GUEST_CID), + p.socket.trim().to_string(), + backend_info, + p.tx_buffer_size.unwrap_or(DEFAULT_TX_BUFFER_SIZE), + p.queue_size.unwrap_or(DEFAULT_QUEUE_SIZE), + p.groups.map_or(vec![DEFAULT_GROUP_NAME.to_string()], |g| { + g.trim().split('+').map(String::from).collect() + }), + ); + parsed.push(config); + } return Some(Ok(parsed)); } else { return Some(Err(CliError::ConfigParse)); @@ -221,10 +344,36 @@ impl TryFrom for Vec { _ => match cmd_args.vm { Some(v) => Ok(v), _ => cmd_args.param.map_or(Err(CliError::NoArgsProvided), |p| { + #[cfg(feature = "backend_vsock")] + let backend_info = match (p.uds_path, p.forward_cid) { + (Some(path), None) => { + BackendType::UnixDomainSocket(path.trim().to_string()) + } + (None, Some(cid)) => { + let listen_ports: Vec = match p.forward_listen { + None => Vec::new(), + Some(ports) => { + ports.split('+').map(|s| s.parse().unwrap()).collect() + } + }; + BackendType::Vsock(VsockProxyInfo { + forward_cid: cid, + listen_ports, + }) + } + _ => return Err(CliError::ConfigParse), + }; + + #[cfg(not(feature = "backend_vsock"))] + let backend_info = match p.uds_path { + Some(path) => BackendType::UnixDomainSocket(path.trim().to_string()), + _ => return Err(CliError::ConfigParse), + }; + Ok(vec![VsockConfig::new( p.guest_cid, p.socket.trim().to_string(), - BackendType::UnixDomainSocket(p.uds_path.trim().to_string()), + backend_info, p.tx_buffer_size, p.queue_size, p.groups.trim().split('+').map(String::from).collect(), @@ -336,7 +485,7 @@ mod tests { use tempfile::tempdir; impl VsockArgs { - fn from_args( + fn from_args_unix( guest_cid: u64, socket: &str, uds_path: &str, @@ -348,7 +497,13 @@ mod tests { param: Some(VsockParam { guest_cid, socket: socket.to_string(), - uds_path: uds_path.to_string(), + uds_path: Some(uds_path.to_string()), + + #[cfg(feature = "backend_vsock")] + forward_cid: None, + #[cfg(feature = "backend_vsock")] + forward_listen: None, + tx_buffer_size, queue_size, groups: groups.to_string(), @@ -357,6 +512,33 @@ mod tests { config: None, } } + + #[cfg(feature = "backend_vsock")] + fn from_args_vsock( + guest_cid: u64, + socket: &str, + forward_cid: u32, + forward_listen: &str, + tx_buffer_size: u32, + queue_size: usize, + groups: &str, + ) -> Self { + VsockArgs { + param: Some(VsockParam { + guest_cid, + socket: socket.to_string(), + uds_path: None, + forward_cid: Some(forward_cid), + forward_listen: Some(forward_listen.to_string()), + tx_buffer_size, + queue_size, + groups: groups.to_string(), + }), + vm: None, + config: None, + } + } + fn from_file(config: &str) -> Self { VsockArgs { param: None, @@ -367,12 +549,12 @@ mod tests { } #[test] - fn test_vsock_config_setup() { + fn test_vsock_config_setup_unix() { let test_dir = tempdir().expect("Could not create a temp test directory."); let socket_path = test_dir.path().join("vhost4.socket").display().to_string(); let uds_path = test_dir.path().join("vm4.vsock").display().to_string(); - let args = VsockArgs::from_args(3, &socket_path, &uds_path, 64 * 1024, 1024, "group1"); + let args = VsockArgs::from_args_unix(3, &socket_path, &uds_path, 64 * 1024, 1024, "group1"); let configs = Vec::::try_from(args); assert!(configs.is_ok()); @@ -394,8 +576,40 @@ mod tests { test_dir.close().unwrap(); } + #[cfg(feature = "backend_vsock")] #[test] - fn test_vsock_config_setup_from_vm_args() { + fn test_vsock_config_setup_vsock() { + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let socket_path = test_dir.path().join("vhost4.socket").display().to_string(); + let args = + VsockArgs::from_args_vsock(3, &socket_path, 1, "1234+4321", 64 * 1024, 1024, "group1"); + + let configs = Vec::::try_from(args); + assert!(configs.is_ok()); + + let configs = configs.unwrap(); + assert_eq!(configs.len(), 1); + + let config = &configs[0]; + assert_eq!(config.get_guest_cid(), 3); + assert_eq!(config.get_socket_path(), socket_path); + assert_eq!( + config.get_backend_info(), + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![1234, 4321] + }) + ); + assert_eq!(config.get_tx_buffer_size(), 64 * 1024); + assert_eq!(config.get_queue_size(), 1024); + assert_eq!(config.get_groups(), vec!["group1".to_string()]); + + test_dir.close().unwrap(); + } + + #[test] + fn test_vsock_config_setup_from_vm_args_unix() { let test_dir = tempdir().expect("Could not create a temp test directory."); let socket_paths = [ @@ -479,8 +693,114 @@ mod tests { test_dir.close().unwrap(); } + #[cfg(feature = "backend_vsock")] #[test] - fn test_vsock_config_setup_from_file() { + fn test_vsock_config_setup_from_vm_args_vsock() { + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let socket_paths = [ + test_dir.path().join("vhost3.socket"), + test_dir.path().join("vhost4.socket"), + test_dir.path().join("vhost5.socket"), + test_dir.path().join("vhost6.socket"), + ]; + let uds_paths = [ + test_dir.path().join("vm3.vsock"), + test_dir.path().join("vm4.vsock"), + test_dir.path().join("vm5.vsock"), + ]; + let params = format!( + "--vm socket={vhost3_socket},uds_path={vm3_vsock} \ + --vm socket={vhost4_socket},uds-path={vm4_vsock},guest-cid=4,tx_buffer_size=65536,queue_size=1024,groups=group1 \ + --vm groups=group2+group3,guest-cid=5,socket={vhost5_socket},uds_path={vm5_vsock},tx-buffer-size=32768,queue_size=256 \ + --vm guest-cid=6,socket={vhost6_socket},forward-cid=1,forward-listen=1234+4321,queue-size=2048", + vhost3_socket = socket_paths[0].display(), + vhost4_socket = socket_paths[1].display(), + vhost5_socket = socket_paths[2].display(), + vhost6_socket = socket_paths[3].display(), + vm3_vsock = uds_paths[0].display(), + vm4_vsock = uds_paths[1].display(), + vm5_vsock = uds_paths[2].display(), + ); + + let mut params = params.split_whitespace().collect::>(); + params.insert(0, ""); // to make the test binary name agnostic + + let args = VsockArgs::parse_from(params); + + let configs = Vec::::try_from(args); + assert!(configs.is_ok()); + + let configs = configs.unwrap(); + assert_eq!(configs.len(), 4); + + let config = configs.first().unwrap(); + assert_eq!(config.get_guest_cid(), 3); + assert_eq!( + config.get_socket_path(), + socket_paths[0].display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::UnixDomainSocket(uds_paths[0].display().to_string()) + ); + assert_eq!(config.get_tx_buffer_size(), 65536); + assert_eq!(config.get_queue_size(), 1024); + assert_eq!(config.get_groups(), vec![DEFAULT_GROUP_NAME.to_string()]); + + let config = configs.get(1).unwrap(); + assert_eq!(config.get_guest_cid(), 4); + assert_eq!( + config.get_socket_path(), + socket_paths[1].display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::UnixDomainSocket(uds_paths[1].display().to_string()) + ); + assert_eq!(config.get_tx_buffer_size(), 65536); + assert_eq!(config.get_queue_size(), 1024); + assert_eq!(config.get_groups(), vec!["group1".to_string()]); + + let config = configs.get(2).unwrap(); + assert_eq!(config.get_guest_cid(), 5); + assert_eq!( + config.get_socket_path(), + socket_paths[2].display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::UnixDomainSocket(uds_paths[2].display().to_string()) + ); + assert_eq!(config.get_tx_buffer_size(), 32768); + assert_eq!(config.get_queue_size(), 256); + assert_eq!( + config.get_groups(), + vec!["group2".to_string(), "group3".to_string()] + ); + + let config = configs.get(3).unwrap(); + assert_eq!(config.get_guest_cid(), 6); + assert_eq!( + config.get_socket_path(), + socket_paths[3].display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![1234, 4321] + }) + ); + assert_eq!(config.get_tx_buffer_size(), 65536); + assert_eq!(config.get_queue_size(), 2048); + assert_eq!(config.get_groups(), vec![DEFAULT_GROUP_NAME.to_string()]); + + test_dir.close().unwrap(); + } + + #[test] + fn test_vsock_config_setup_from_file_unix() { let test_dir = tempdir().expect("Could not create a temp test directory."); let config_path = test_dir.path().join("config.yaml"); @@ -555,8 +875,142 @@ mod tests { test_dir.close().unwrap(); } + #[cfg(feature = "backend_vsock")] #[test] - fn test_vsock_server() { + fn test_vsock_config_setup_from_file_vsock() { + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let config_path = test_dir.path().join("config.yaml"); + let socket_path_unix = test_dir.path().join("vhost4.socket"); + let socket_path_vsock = test_dir.path().join("vhost5.socket"); + let uds_path = test_dir.path().join("vm4.vsock"); + + let mut yaml = File::create(&config_path).unwrap(); + yaml.write_all( + format!( + "vms: + - guest_cid: 4 + socket: {} + uds_path: {} + tx_buffer_size: 32768 + queue_size: 256 + groups: group1+group2 + - guest_cid: 5 + socket: {} + forward_cid: 1 + forward_listen: 1234+4321 + tx_buffer_size: 32768", + socket_path_unix.display(), + uds_path.display(), + socket_path_vsock.display(), + ) + .as_bytes(), + ) + .unwrap(); + let args = VsockArgs::from_file(&config_path.display().to_string()); + + let configs = Vec::::try_from(args).unwrap(); + assert_eq!(configs.len(), 2); + + let config = &configs[0]; + assert_eq!(config.get_guest_cid(), 4); + assert_eq!( + config.get_socket_path(), + socket_path_unix.display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::UnixDomainSocket(uds_path.display().to_string()) + ); + assert_eq!(config.get_tx_buffer_size(), 32768); + assert_eq!(config.get_queue_size(), 256); + assert_eq!( + config.get_groups(), + vec!["group1".to_string(), "group2".to_string()] + ); + + let config = &configs[1]; + assert_eq!(config.get_guest_cid(), 5); + assert_eq!( + config.get_socket_path(), + socket_path_vsock.display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![1234, 4321] + }) + ); + assert_eq!(config.get_tx_buffer_size(), 32768); + assert_eq!(config.get_queue_size(), 1024); + assert_eq!(config.get_groups(), vec![DEFAULT_GROUP_NAME.to_string()]); + + // Now test that optional parameters are correctly set to their default values. + let mut yaml = File::create(&config_path).unwrap(); + yaml.write_all( + format!( + "vms: + - socket: {} + uds_path: {}", + socket_path_unix.display(), + uds_path.display(), + ) + .as_bytes(), + ) + .unwrap(); + let args = VsockArgs::from_file(&config_path.display().to_string()); + + let configs = Vec::::try_from(args).unwrap(); + assert_eq!(configs.len(), 1); + + let config = &configs[0]; + assert_eq!(config.get_guest_cid(), DEFAULT_GUEST_CID); + assert_eq!( + config.get_socket_path(), + socket_path_unix.display().to_string() + ); + assert_eq!( + config.get_backend_info(), + BackendType::UnixDomainSocket(uds_path.display().to_string()) + ); + assert_eq!(config.get_tx_buffer_size(), DEFAULT_TX_BUFFER_SIZE); + assert_eq!(config.get_queue_size(), DEFAULT_QUEUE_SIZE); + assert_eq!(config.get_groups(), vec![DEFAULT_GROUP_NAME.to_string()]); + + std::fs::remove_file(&config_path).unwrap(); + test_dir.close().unwrap(); + } + + fn test_vsock_server(config: VsockConfig) { + let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); + + let backend = Arc::new(VhostUserVsockBackend::new(config, cid_map).unwrap()); + + let daemon = VhostUserDaemon::new( + String::from("vhost-device-vsock"), + backend.clone(), + GuestMemoryAtomic::new(GuestMemoryMmap::new()), + ) + .unwrap(); + + let mut epoll_handlers = daemon.get_epoll_handlers(); + + // VhostUserVsockBackend support a single thread that handles the TX and RX queues + assert_eq!(backend.threads.len(), 1); + + assert_eq!(epoll_handlers.len(), backend.threads.len()); + + for thread in backend.threads.iter() { + thread + .lock() + .unwrap() + .register_listeners(epoll_handlers.remove(0)); + } + } + + #[test] + fn test_vsock_server_unix() { const CID: u64 = 3; const CONN_TX_BUF_SIZE: u32 = 64 * 1024; const QUEUE_SIZE: usize = 1024; @@ -583,30 +1037,39 @@ mod tests { vec![DEFAULT_GROUP_NAME.to_string()], ); - let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); + test_vsock_server(config); - let backend = Arc::new(VhostUserVsockBackend::new(config, cid_map).unwrap()); + test_dir.close().unwrap(); + } - let daemon = VhostUserDaemon::new( - String::from("vhost-device-vsock"), - backend.clone(), - GuestMemoryAtomic::new(GuestMemoryMmap::new()), - ) - .unwrap(); + #[cfg(feature = "backend_vsock")] + #[test] + fn test_vsock_server_vsock() { + const CID: u64 = 3; + const CONN_TX_BUF_SIZE: u32 = 64 * 1024; + const QUEUE_SIZE: usize = 1024; - let mut epoll_handlers = daemon.get_epoll_handlers(); + let test_dir = tempdir().expect("Could not create a temp test directory."); - // VhostUserVsockBackend support a single thread that handles the TX and RX queues - assert_eq!(backend.threads.len(), 1); + let vhost_socket_path = test_dir + .path() + .join("test_vsock_server.socket") + .display() + .to_string(); - assert_eq!(epoll_handlers.len(), backend.threads.len()); + let config = VsockConfig::new( + CID, + vhost_socket_path, + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![9000], + }), + CONN_TX_BUF_SIZE, + QUEUE_SIZE, + vec![DEFAULT_GROUP_NAME.to_string()], + ); - for thread in backend.threads.iter() { - thread - .lock() - .unwrap() - .register_listeners(epoll_handlers.remove(0)); - } + test_vsock_server(config); test_dir.close().unwrap(); } @@ -674,8 +1137,9 @@ mod tests { let _ = test_dir.close(); } + #[cfg(not(feature = "backend_vsock"))] #[test] - fn test_main_structs() { + fn test_main_structs_unix() { let error = parse_vm_params("").unwrap_err(); assert_matches!(error, VmArgsParseError::BadArgument); assert_eq!(format!("{error:?}"), "BadArgument"); @@ -689,21 +1153,76 @@ mod tests { assert_matches!(error, CliError::NoArgsProvided); assert_eq!(format!("{error:?}"), "NoArgsProvided"); - let args = VsockArgs::from_args(0, "", "", 0, 0, ""); - assert_eq!(format!("{args:?}"), "VsockArgs { param: Some(VsockParam { guest_cid: 0, socket: \"\", uds_path: \"\", tx_buffer_size: 0, queue_size: 0, groups: \"\" }), vm: None, config: None }"); + let args = VsockArgs::from_args_unix(0, "", "", 0, 0, ""); + assert_eq!(format!("{args:?}"), "VsockArgs { param: Some(VsockParam { guest_cid: 0, socket: \"\", uds_path: Some(\"\"), tx_buffer_size: 0, queue_size: 0, groups: \"\" }), vm: None, config: None }"); let param = args.param.unwrap().clone(); - assert_eq!(format!("{param:?}"), "VsockParam { guest_cid: 0, socket: \"\", uds_path: \"\", tx_buffer_size: 0, queue_size: 0, groups: \"\" }"); + assert_eq!(format!("{param:?}"), "VsockParam { guest_cid: 0, socket: \"\", uds_path: Some(\"\"), tx_buffer_size: 0, queue_size: 0, groups: \"\" }"); let config = ConfigFileVsockParam { guest_cid: None, socket: String::new(), - uds_path: String::new(), + uds_path: Some(String::new()), tx_buffer_size: None, queue_size: None, groups: None, } .clone(); - assert_eq!(format!("{config:?}"), "ConfigFileVsockParam { guest_cid: None, socket: \"\", uds_path: \"\", tx_buffer_size: None, queue_size: None, groups: None }"); + assert_eq!(format!("{config:?}"), "ConfigFileVsockParam { guest_cid: None, socket: \"\", uds_path: Some(\"\"), tx_buffer_size: None, queue_size: None, groups: None }"); + } + + #[cfg(feature = "backend_vsock")] + #[test] + fn test_main_structs_vsock() { + let error = parse_vm_params("").unwrap_err(); + assert_matches!(error, VmArgsParseError::BadArgument); + assert_eq!(format!("{error:?}"), "BadArgument"); + + let args = VsockArgs { + param: None, + vm: None, + config: None, + }; + let error = Vec::::try_from(args).unwrap_err(); + assert_matches!(error, CliError::NoArgsProvided); + assert_eq!(format!("{error:?}"), "NoArgsProvided"); + + let args = VsockArgs::from_args_unix(0, "", "", 0, 0, ""); + assert_eq!(format!("{args:?}"), "VsockArgs { param: Some(VsockParam { guest_cid: 0, socket: \"\", uds_path: Some(\"\"), forward_cid: None, forward_listen: None, tx_buffer_size: 0, queue_size: 0, groups: \"\" }), vm: None, config: None }"); + + let param = args.param.unwrap().clone(); + assert_eq!(format!("{param:?}"), "VsockParam { guest_cid: 0, socket: \"\", uds_path: Some(\"\"), forward_cid: None, forward_listen: None, tx_buffer_size: 0, queue_size: 0, groups: \"\" }"); + + let args = VsockArgs::from_args_vsock(0, "", 1, "", 0, 0, ""); + assert_eq!(format!("{args:?}"), "VsockArgs { param: Some(VsockParam { guest_cid: 0, socket: \"\", uds_path: None, forward_cid: Some(1), forward_listen: Some(\"\"), tx_buffer_size: 0, queue_size: 0, groups: \"\" }), vm: None, config: None }"); + + let param = args.param.unwrap().clone(); + assert_eq!(format!("{param:?}"), "VsockParam { guest_cid: 0, socket: \"\", uds_path: None, forward_cid: Some(1), forward_listen: Some(\"\"), tx_buffer_size: 0, queue_size: 0, groups: \"\" }"); + + let config = ConfigFileVsockParam { + guest_cid: None, + socket: String::new(), + uds_path: Some(String::new()), + forward_cid: None, + forward_listen: None, + tx_buffer_size: None, + queue_size: None, + groups: None, + } + .clone(); + assert_eq!(format!("{config:?}"), "ConfigFileVsockParam { guest_cid: None, socket: \"\", uds_path: Some(\"\"), forward_cid: None, forward_listen: None, tx_buffer_size: None, queue_size: None, groups: None }"); + + let config = ConfigFileVsockParam { + guest_cid: None, + socket: String::new(), + uds_path: None, + forward_cid: Some(1), + forward_listen: Some(String::new()), + tx_buffer_size: None, + queue_size: None, + groups: None, + } + .clone(); + assert_eq!(format!("{config:?}"), "ConfigFileVsockParam { guest_cid: None, socket: \"\", uds_path: None, forward_cid: Some(1), forward_listen: Some(\"\"), tx_buffer_size: None, queue_size: None, groups: None }"); } } diff --git a/vhost-device-vsock/src/thread_backend.rs b/vhost-device-vsock/src/thread_backend.rs index 9661adb..55f5b00 100644 --- a/vhost-device-vsock/src/thread_backend.rs +++ b/vhost-device-vsock/src/thread_backend.rs @@ -17,6 +17,8 @@ use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE}; use vm_memory::{ bitmap::BitmapSlice, ReadVolatile, VolatileMemoryError, VolatileSlice, WriteVolatile, }; +#[cfg(feature = "backend_vsock")] +use vsock::VsockStream; use crate::{ rxops::*, @@ -55,6 +57,8 @@ impl RawVsockPacket { pub(crate) enum StreamType { Unix(UnixStream), + #[cfg(feature = "backend_vsock")] + Vsock(VsockStream), } impl StreamType { @@ -64,6 +68,11 @@ impl StreamType { let cloned_stream = stream.try_clone()?; Ok(StreamType::Unix(cloned_stream)) } + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => { + let cloned_stream = stream.try_clone()?; + Ok(StreamType::Vsock(cloned_stream)) + } } } } @@ -72,6 +81,8 @@ impl Read for StreamType { fn read(&mut self, buf: &mut [u8]) -> StdIOResult { match self { StreamType::Unix(stream) => stream.read(buf), + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => stream.read(buf), } } } @@ -80,12 +91,16 @@ impl Write for StreamType { fn write(&mut self, buf: &[u8]) -> StdIOResult { match self { StreamType::Unix(stream) => stream.write(buf), + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => stream.write(buf), } } fn flush(&mut self) -> StdIOResult<()> { match self { StreamType::Unix(stream) => stream.flush(), + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => stream.flush(), } } } @@ -94,6 +109,8 @@ impl AsRawFd for StreamType { fn as_raw_fd(&self) -> RawFd { match self { StreamType::Unix(stream) => stream.as_raw_fd(), + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => stream.as_raw_fd(), } } } @@ -105,6 +122,30 @@ impl ReadVolatile for StreamType { ) -> StdResult { match self { StreamType::Unix(stream) => stream.read_volatile(buf), + // Copied from vm_memory crate's ReadVolatile implementation for UnixStream + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => { + let fd = stream.as_raw_fd(); + let guard = buf.ptr_guard_mut(); + + let dst = guard.as_ptr().cast::(); + + // SAFETY: We got a valid file descriptor from `AsRawFd`. The memory pointed to by `dst` is + // valid for writes of length `buf.len() by the invariants upheld by the constructor + // of `VolatileSlice`. + let bytes_read = unsafe { libc::read(fd, dst, buf.len()) }; + + if bytes_read < 0 { + // We don't know if a partial read might have happened, so mark everything as dirty + buf.bitmap().mark_dirty(0, buf.len()); + + Err(VolatileMemoryError::IOError(std::io::Error::last_os_error())) + } else { + let bytes_read = bytes_read.try_into().unwrap(); + buf.bitmap().mark_dirty(0, bytes_read); + Ok(bytes_read) + } + } } } } @@ -116,6 +157,25 @@ impl WriteVolatile for StreamType { ) -> StdResult { match self { StreamType::Unix(stream) => stream.write_volatile(buf), + // Copied from vm_memory crate's WriteVolatile implementation for UnixStream + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(stream) => { + let fd = stream.as_raw_fd(); + let guard = buf.ptr_guard(); + + let src = guard.as_ptr().cast::(); + + // SAFETY: We got a valid file descriptor from `AsRawFd`. The memory pointed to by `src` is + // valid for reads of length `buf.len() by the invariants upheld by the constructor + // of `VolatileSlice`. + let bytes_written = unsafe { libc::write(fd, src, buf.len()) }; + + if bytes_written < 0 { + Err(VolatileMemoryError::IOError(std::io::Error::last_os_error())) + } else { + Ok(bytes_written.try_into().unwrap()) + } + } } } } @@ -137,7 +197,7 @@ pub(crate) struct VsockThreadBackend { pub conn_map: HashMap>, /// Queue of ConnMapKey objects indicating pending rx operations. pub backend_rxq: VecDeque, - /// Map of host-side unix streams indexed by raw file descriptors. + /// Map of host-side unix or vsock streams indexed by raw file descriptors. pub stream_map: HashMap, /// Host side socket info for listening to new connections from the host. backend_info: BackendType, @@ -262,36 +322,39 @@ impl VsockThreadBackend { return Ok(()); } - let dst_cid = pkt.dst_cid(); - if dst_cid != VSOCK_HOST_CID { - let cid_map = self.cid_map.read().unwrap(); - if cid_map.contains_key(&dst_cid) { - let (sibling_raw_pkts_queue, sibling_groups_set, sibling_event_fd) = - cid_map.get(&dst_cid).unwrap(); + #[allow(irrefutable_let_patterns)] + if let BackendType::UnixDomainSocket(_) = &self.backend_info { + let dst_cid = pkt.dst_cid(); + if dst_cid != VSOCK_HOST_CID { + let cid_map = self.cid_map.read().unwrap(); + if cid_map.contains_key(&dst_cid) { + let (sibling_raw_pkts_queue, sibling_groups_set, sibling_event_fd) = + cid_map.get(&dst_cid).unwrap(); - if self - .groups_set - .read() - .unwrap() - .is_disjoint(sibling_groups_set.read().unwrap().deref()) - { - info!( - "vsock: dropping packet for cid: {:?} due to group mismatch", - dst_cid - ); - return Ok(()); + if self + .groups_set + .read() + .unwrap() + .is_disjoint(sibling_groups_set.read().unwrap().deref()) + { + info!( + "vsock: dropping packet for cid: {:?} due to group mismatch", + dst_cid + ); + return Ok(()); + } + + sibling_raw_pkts_queue + .write() + .unwrap() + .push_back(RawVsockPacket::from_vsock_packet(pkt)?); + let _ = sibling_event_fd.write(1); + } else { + warn!("vsock: dropping packet for unknown cid: {:?}", dst_cid); } - sibling_raw_pkts_queue - .write() - .unwrap() - .push_back(RawVsockPacket::from_vsock_packet(pkt)?); - let _ = sibling_event_fd.write(1); - } else { - warn!("vsock: dropping packet for unknown cid: {:?}", dst_cid); + return Ok(()); } - - return Ok(()); } // TODO: Rst if packet has unsupported type @@ -371,9 +434,11 @@ impl VsockThreadBackend { /// Handle a new guest initiated connection, i.e from the peer, the guest driver. /// - /// Attempts to connect to a host side unix socket listening on a path - /// corresponding to the destination port as follows: + /// In case of proxying using unix domain socket, attempts to connect to a host side unix socket + /// listening on a path corresponding to the destination port as follows: /// - "{self.host_sock_path}_{local_port}"" + /// + /// In case of proxying using vosck, attempts to connect to the {forward_cid, local_port} fn handle_new_guest_conn(&mut self, pkt: &VsockPacket) { match &self.backend_info { BackendType::UnixDomainSocket(uds_path) => { @@ -385,6 +450,14 @@ impl VsockThreadBackend { .and_then(|stream| self.add_new_guest_conn(StreamType::Unix(stream), pkt)) .unwrap_or_else(|_| self.enq_rst()); } + #[cfg(feature = "backend_vsock")] + BackendType::Vsock(vsock_info) => { + VsockStream::connect_with_cid_port(vsock_info.forward_cid, pkt.dst_port()) + .and_then(|stream| stream.set_nonblocking(true).map(|_| stream)) + .map_err(Error::VsockConnect) + .and_then(|stream| self.add_new_guest_conn(StreamType::Vsock(stream), pkt)) + .unwrap_or_else(|_| self.enq_rst()); + } } } @@ -397,6 +470,8 @@ impl VsockThreadBackend { let conn = VsockConnection::new_peer_init( stream.try_clone().map_err(match stream { StreamType::Unix(_) => Error::UnixConnect, + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(_) => Error::VsockConnect, })?, pkt.dst_cid(), pkt.dst_port(), @@ -436,28 +511,23 @@ impl VsockThreadBackend { #[cfg(test)] mod tests { use super::*; + #[cfg(feature = "backend_vsock")] + use crate::vhu_vsock::VsockProxyInfo; use crate::vhu_vsock::{BackendType, VhostUserVsockBackend, VsockConfig, VSOCK_OP_RW}; use std::os::unix::net::UnixListener; use tempfile::tempdir; use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE}; + #[cfg(feature = "backend_vsock")] + use vsock::{VsockListener, VMADDR_CID_ANY}; const DATA_LEN: usize = 16; const CONN_TX_BUF_SIZE: u32 = 64 * 1024; const QUEUE_SIZE: usize = 1024; const GROUP_NAME: &str = "default"; + const VSOCK_PEER_PORT: u32 = 1234; - #[test] - fn test_vsock_thread_backend() { + fn test_vsock_thread_backend(backend_info: BackendType) { const CID: u64 = 3; - const VSOCK_PEER_PORT: u32 = 1234; - - let test_dir = tempdir().expect("Could not create a temp test directory."); - - let vsock_socket_path = test_dir.path().join("test_vsock_thread_backend.vsock"); - let vsock_peer_path = test_dir.path().join("test_vsock_thread_backend.vsock_1234"); - - let _listener = UnixListener::bind(&vsock_peer_path).unwrap(); - let backend_info = BackendType::UnixDomainSocket(vsock_socket_path.display().to_string()); let epoll_fd = epoll::create(false).unwrap(); @@ -510,6 +580,19 @@ mod tests { // TODO: it is a nop for now vtp.enq_rst(); + } + + #[test] + fn test_vsock_thread_backend_unix() { + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let vsock_socket_path = test_dir.path().join("test_vsock_thread_backend.vsock"); + let vsock_peer_path = test_dir.path().join("test_vsock_thread_backend.vsock_1234"); + + let _listener = UnixListener::bind(&vsock_peer_path).unwrap(); + let backend_info = BackendType::UnixDomainSocket(vsock_socket_path.display().to_string()); + + test_vsock_thread_backend(backend_info); // cleanup let _ = std::fs::remove_file(&vsock_peer_path); @@ -518,6 +601,18 @@ mod tests { test_dir.close().unwrap(); } + #[cfg(feature = "backend_vsock")] + #[test] + fn test_vsock_thread_backend_vsock() { + let _listener = VsockListener::bind_with_cid_port(VMADDR_CID_ANY, VSOCK_PEER_PORT).unwrap(); + let backend_info = BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![], + }); + + test_vsock_thread_backend(backend_info); + } + #[test] fn test_vsock_thread_backend_sibling_vms() { const CID: u64 = 3; diff --git a/vhost-device-vsock/src/vhu_vsock.rs b/vhost-device-vsock/src/vhu_vsock.rs index fa07e5d..e2efe0e 100644 --- a/vhost-device-vsock/src/vhu_vsock.rs +++ b/vhost-device-vsock/src/vhu_vsock.rs @@ -118,8 +118,17 @@ pub(crate) enum Error { PktBufMissing, #[error("Failed to connect to unix socket")] UnixConnect(std::io::Error), - #[error("Unable to write to unix stream")] - UnixWrite, + #[cfg(feature = "backend_vsock")] + #[error("Failed to accept new local vsock socket connection")] + VsockAccept(std::io::Error), + #[cfg(feature = "backend_vsock")] + #[error("Failed to connect to vsock socket")] + VsockConnect(std::io::Error), + #[cfg(feature = "backend_vsock")] + #[error("Failed to bind a vsock stream")] + VsockBind(std::io::Error), + #[error("Unable to write to stream")] + StreamWrite, #[error("Unable to push data to local tx buffer")] LocalTxBufFull, #[error("Unable to flush data from local tx buffer")] @@ -142,10 +151,20 @@ impl std::convert::From for std::io::Error { } } +#[cfg(feature = "backend_vsock")] +#[derive(Debug, PartialEq, Clone)] +pub(crate) struct VsockProxyInfo { + pub forward_cid: u32, + pub listen_ports: Vec, +} + #[derive(Debug, PartialEq, Clone)] pub(crate) enum BackendType { /// unix domain socket path UnixDomainSocket(String), + /// the vsock CID and ports + #[cfg(feature = "backend_vsock")] + Vsock(VsockProxyInfo), } #[derive(Debug, Clone)] @@ -390,34 +409,7 @@ mod tests { const CONN_TX_BUF_SIZE: u32 = 64 * 1024; const QUEUE_SIZE: usize = 1024; - #[test] - fn test_vsock_backend() { - const CID: u64 = 3; - - let groups_list: Vec = vec![String::from("default")]; - - let test_dir = tempdir().expect("Could not create a temp test directory."); - - let vhost_socket_path = test_dir - .path() - .join("test_vsock_backend.socket") - .display() - .to_string(); - let vsock_socket_path = test_dir - .path() - .join("test_vsock_backend.vsock") - .display() - .to_string(); - - let config = VsockConfig::new( - CID, - vhost_socket_path.to_string(), - BackendType::UnixDomainSocket(vsock_socket_path.to_string()), - CONN_TX_BUF_SIZE, - QUEUE_SIZE, - groups_list, - ); - + fn test_vsock_backend(config: VsockConfig, expected_cid: u64) { let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); let backend = VhostUserVsockBackend::new(config, cid_map); @@ -452,7 +444,7 @@ mod tests { let config = backend.get_config(0, 8); assert_eq!(config.len(), 8); let cid = u64::from_le_bytes(config.try_into().unwrap()); - assert_eq!(cid, CID); + assert_eq!(cid, expected_cid); let exit = backend.exit_event(0); assert!(exit.is_some()); @@ -469,11 +461,74 @@ mod tests { let ret = backend.handle_event(BACKEND_EVENT, EventSet::IN, &vrings, 0); assert!(ret.is_ok()); + } + + #[test] + fn test_vsock_backend_unix() { + const CID: u64 = 3; + + let groups_list: Vec = vec![String::from("default")]; + + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let vhost_socket_path = test_dir + .path() + .join("test_vsock_backend_unix.socket") + .display() + .to_string(); + let vsock_socket_path = test_dir + .path() + .join("test_vsock_backend.vsock") + .display() + .to_string(); + + let config = VsockConfig::new( + CID, + vhost_socket_path.to_string(), + BackendType::UnixDomainSocket(vsock_socket_path.to_string()), + CONN_TX_BUF_SIZE, + QUEUE_SIZE, + groups_list, + ); + + test_vsock_backend(config, CID); // cleanup let _ = std::fs::remove_file(vhost_socket_path); let _ = std::fs::remove_file(vsock_socket_path); + test_dir.close().unwrap(); + } + #[cfg(feature = "backend_vsock")] + #[test] + fn test_vsock_backend_vsock() { + const CID: u64 = 3; + + let groups_list: Vec = vec![String::from("default")]; + + let test_dir = tempdir().expect("Could not create a temp test directory."); + + let vhost_socket_path = test_dir + .path() + .join("test_vsock_backend.socket") + .display() + .to_string(); + let config = VsockConfig::new( + CID, + vhost_socket_path.to_string(), + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![9001, 9002], + }), + CONN_TX_BUF_SIZE, + QUEUE_SIZE, + groups_list, + ); + + test_vsock_backend(config, CID); + + // cleanup + let _ = std::fs::remove_file(vhost_socket_path); test_dir.close().unwrap(); } @@ -558,7 +613,7 @@ mod tests { #[test] fn test_vhu_vsock_structs() { - let config = VsockConfig::new( + let unix_config = VsockConfig::new( 0, String::new(), BackendType::UnixDomainSocket(String::new()), @@ -566,8 +621,22 @@ mod tests { 0, vec![String::new()], ); + assert_eq!(format!("{unix_config:?}"), "VsockConfig { guest_cid: 0, socket: \"\", backend_info: UnixDomainSocket(\"\"), tx_buffer_size: 0, queue_size: 0, groups: [\"\"] }"); - assert_eq!(format!("{config:?}"), "VsockConfig { guest_cid: 0, socket: \"\", backend_info: UnixDomainSocket(\"\"), tx_buffer_size: 0, queue_size: 0, groups: [\"\"] }"); + #[cfg(feature = "backend_vsock")] + let vsock_config = VsockConfig::new( + 0, + String::new(), + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![9001, 9002], + }), + 0, + 0, + vec![String::new()], + ); + #[cfg(feature = "backend_vsock")] + assert_eq!(format!("{vsock_config:?}"), "VsockConfig { guest_cid: 0, socket: \"\", backend_info: Vsock(VsockProxyInfo { forward_cid: 1, listen_ports: [9001, 9002] }), tx_buffer_size: 0, queue_size: 0, groups: [\"\"] }"); let conn_map = ConnMapKey::new(0, 0); assert_eq!( diff --git a/vhost-device-vsock/src/vhu_vsock_thread.rs b/vhost-device-vsock/src/vhu_vsock_thread.rs index 8acf724..4ce45d0 100644 --- a/vhost-device-vsock/src/vhu_vsock_thread.rs +++ b/vhost-device-vsock/src/vhu_vsock_thread.rs @@ -18,6 +18,8 @@ use std::{ thread, }; +#[cfg(feature = "backend_vsock")] +use log::error; use log::warn; use vhost_user_backend::{VringEpollHandler, VringRwLock, VringT}; use virtio_queue::QueueOwnedT; @@ -27,6 +29,8 @@ use vmm_sys_util::{ epoll::EventSet, eventfd::{EventFd, EFD_NONBLOCK}, }; +#[cfg(feature = "backend_vsock")] +use vsock::{VsockListener, VMADDR_CID_ANY}; use crate::{ rxops::*, @@ -55,6 +59,8 @@ struct EventData { enum ListenerType { Unix(UnixListener), + #[cfg(feature = "backend_vsock")] + Vsock(VsockListener), } pub(crate) struct VhostUserVsockThread { @@ -105,6 +111,16 @@ impl VhostUserVsockThread { let host_sock = host_listener.as_raw_fd(); host_listeners_map.insert(host_sock, ListenerType::Unix(host_listener)); } + #[cfg(feature = "backend_vsock")] + BackendType::Vsock(vsock_info) => { + for p in &vsock_info.listen_ports { + let host_listener = VsockListener::bind_with_cid_port(VMADDR_CID_ANY, *p) + .and_then(|sock| sock.set_nonblocking(true).map(|_| sock)) + .map_err(Error::VsockBind)?; + let host_sock = host_listener.as_raw_fd(); + host_listeners_map.insert(host_sock, ListenerType::Vsock(host_listener)); + } + } } let epoll_fd = epoll::create(true).map_err(Error::EpollFdCreate)?; @@ -150,6 +166,7 @@ impl VhostUserVsockThread { }; Self::vring_handle_event(event_data); }); + let thread = VhostUserVsockThread { mem: None, event_idx: false, @@ -312,6 +329,51 @@ impl VhostUserVsockThread { }); } } + #[cfg(feature = "backend_vsock")] + ListenerType::Vsock(vsock_listener) => { + let conn = vsock_listener.accept().map_err(Error::VsockAccept); + if self.mem.is_some() { + match conn { + Ok((stream, addr)) => { + if let Err(err) = stream.set_nonblocking(true) { + warn!("Failed to set stream to non-blocking: {:?}", err); + return; + } + + let peer_port = match vsock_listener.local_addr() { + Ok(listener_addr) => listener_addr.port(), + Err(err) => { + warn!("Failed to get peer address: {:?}", err); + return; + } + }; + + let local_port = addr.port(); + let stream_raw_fd = stream.as_raw_fd(); + self.add_new_connection_from_host( + stream_raw_fd, + StreamType::Vsock(stream), + local_port, + peer_port, + ); + if let Err(err) = Self::epoll_register( + self.get_epoll_fd(), + stream_raw_fd, + epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT, + ) { + warn!("Failed to register with epoll: {:?}", err); + } + } + Err(err) => { + warn!("Unable to accept new local connection: {:?}", err); + } + } + } else { + conn.map(drop).unwrap_or_else(|err| { + warn!("Error closing an incoming connection: {:?}", err); + }); + } + } } } else { // Check if the stream represented by fd has already established a @@ -333,6 +395,10 @@ impl VhostUserVsockThread { }; match stream { + #[cfg(feature = "backend_vsock")] + StreamType::Vsock(_) => { + error!("Stream type should not be of type vsock"); + } StreamType::Unix(ref mut unix_stream) => { // Local peer is sending a "connect PORT\n" command let peer_port = match Self::read_local_stream_port(unix_stream) { @@ -726,6 +792,10 @@ impl Drop for VhostUserVsockThread { BackendType::UnixDomainSocket(uds_path) => { let _ = std::fs::remove_file(uds_path); } + #[cfg(feature = "backend_vsock")] + BackendType::Vsock(_) => { + // Nothing to do + } } self.thread_backend .cid_map @@ -737,12 +807,16 @@ impl Drop for VhostUserVsockThread { #[cfg(test)] mod tests { use super::*; + #[cfg(feature = "backend_vsock")] + use crate::vhu_vsock::VsockProxyInfo; use std::collections::HashMap; use std::io::Read; use std::io::Write; use tempfile::tempdir; use vm_memory::GuestAddress; use vmm_sys_util::eventfd::EventFd; + #[cfg(feature = "backend_vsock")] + use vsock::VsockStream; const CONN_TX_BUF_SIZE: u32 = 64 * 1024; @@ -752,28 +826,12 @@ mod tests { } } - #[test] - fn test_vsock_thread() { + fn test_vsock_thread(backend_info: BackendType) { let groups: Vec = vec![String::from("default")]; let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); - let test_dir = tempdir().expect("Could not create a temp test directory."); - let backend_info = BackendType::UnixDomainSocket( - test_dir - .path() - .join("test_vsock_thread.vsock") - .display() - .to_string(), - ); - - let t = VhostUserVsockThread::new( - backend_info, - 3, - CONN_TX_BUF_SIZE, - groups, - cid_map, - ); + let t = VhostUserVsockThread::new(backend_info, 3, CONN_TX_BUF_SIZE, groups, cid_map); assert!(t.is_ok()); let mut t = t.unwrap(); @@ -838,10 +896,32 @@ mod tests { dummy_fd.write(1).unwrap(); t.process_backend_evt(EventSet::empty()); + } + #[test] + fn test_vsock_thread_unix() { + let test_dir = tempdir().expect("Could not create a temp test directory."); + let backend_info = BackendType::UnixDomainSocket( + test_dir + .path() + .join("test_vsock_thread.vsock") + .display() + .to_string(), + ); + test_vsock_thread(backend_info); test_dir.close().unwrap(); } + #[cfg(feature = "backend_vsock")] + #[test] + fn test_vsock_thread_vsock() { + let backend_info = BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![], + }); + test_vsock_thread(backend_info); + } + #[test] fn test_vsock_thread_failures() { let groups: Vec = vec![String::from("default")]; @@ -910,8 +990,9 @@ mod tests { test_dir.close().unwrap(); } + #[test] - fn test_vsock_thread_unix() { + fn test_vsock_thread_unix_backend() { let groups: Vec = vec![String::from("default")]; let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); @@ -956,4 +1037,47 @@ mod tests { test_dir.close().unwrap(); } + + #[cfg(feature = "backend_vsock")] + #[test] + fn test_vsock_thread_vsock_backend() { + let groups: Vec = vec![String::from("default")]; + let cid_map: Arc> = Arc::new(RwLock::new(HashMap::new())); + + let t = VhostUserVsockThread::new( + BackendType::Vsock(VsockProxyInfo { + forward_cid: 1, + listen_ports: vec![9003, 9004], + }), + 3, + CONN_TX_BUF_SIZE, + groups, + cid_map, + ); + + let mut t = t.unwrap(); + + let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(), + ); + + t.mem = Some(mem.clone()); + + let mut vs1 = VsockStream::connect_with_cid_port(1, 9003).unwrap(); + let mut vs2 = VsockStream::connect_with_cid_port(1, 9004).unwrap(); + t.process_backend_evt(EventSet::empty()); + + vs1.write_all(b"some data").unwrap(); + vs2.write_all(b"some data").unwrap(); + t.process_backend_evt(EventSet::empty()); + + let mut buf = vec![0u8; 16]; + vs1.set_nonblocking(true).unwrap(); + vs2.set_nonblocking(true).unwrap(); + // There isn't any peer responding, so we don't expect data + vs1.read(&mut buf).unwrap_err(); + vs2.read(&mut buf).unwrap_err(); + + t.process_backend_evt(EventSet::empty()); + } } diff --git a/vhost-device-vsock/src/vsock_conn.rs b/vhost-device-vsock/src/vsock_conn.rs index 3762e7a..beb3a01 100644 --- a/vhost-device-vsock/src/vsock_conn.rs +++ b/vhost-device-vsock/src/vsock_conn.rs @@ -301,7 +301,7 @@ impl VsockCon /// /// Returns: /// - Ok(cnt) where cnt is the number of bytes written to the stream - /// - Err(Error::UnixWrite) if there was an error writing to the stream + /// - Err(Error::StreamWrite) if there was an error writing to the stream fn send_bytes(&mut self, buf: &VolatileSlice) -> Result<()> { if !self.tx_buf.is_empty() { // Data is already present in the buffer and the backend @@ -318,12 +318,12 @@ impl VsockCon 0 } else { dbg!("send_bytes error: {:?}", e); - return Err(Error::UnixWrite); + return Err(Error::StreamWrite); } } Err(e) => { dbg!("send_bytes error: {:?}", e); - return Err(Error::UnixWrite); + return Err(Error::StreamWrite); } };