diff --git a/Cargo.lock b/Cargo.lock index 116c1a7..103de3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,12 +28,24 @@ dependencies = [ "winapi", ] +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + [[package]] name = "bitflags" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + [[package]] name = "cc" version = "1.0.73" @@ -46,6 +58,22 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "clap" +version = "3.2.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86447ad904c7fb335a790c9d7fe3d0d971dc523b8ccd1561a520de9a85302750" +dependencies = [ + "atty", + "bitflags", + "clap_lex 0.2.4", + "indexmap", + "strsim", + "termcolor", + "textwrap", + "yaml-rust", +] + [[package]] name = "clap" version = "4.0.11" @@ -55,7 +83,7 @@ dependencies = [ "atty", "bitflags", "clap_derive", - "clap_lex", + "clap_lex 0.3.0", "once_cell", "strsim", "termcolor", @@ -74,6 +102,15 @@ dependencies = [ "syn", ] +[[package]] +name = "clap_lex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5" +dependencies = [ + "os_str_bytes", +] + [[package]] name = "clap_lex" version = "0.3.0" @@ -115,6 +152,96 @@ dependencies = [ "instant", ] +[[package]] +name = "futures" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30bdd20c28fadd505d0fd6712cdfcb0d4b5648baf45faef7f852afb2399bb050" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf" + +[[package]] +name = "futures-executor" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", + "num_cpus", +] + +[[package]] +name = "futures-io" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68" + +[[package]] +name = "futures-macro" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b20ba5a92e727ba30e72834706623d94ac93a725410b6a6b6fbc1b07f7ba56" + +[[package]] +name = "futures-task" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6508c467c73851293f390476d4491cf4d227dbabcd4170f3bb6044959b294f1" + +[[package]] +name = "futures-util" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "getrandom" version = "0.2.7" @@ -126,6 +253,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "heck" version = "0.4.0" @@ -147,6 +280,16 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "indexmap" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e" +dependencies = [ + "autocfg", + "hashbrown", +] + [[package]] name = "instant" version = "0.1.12" @@ -170,7 +313,7 @@ dependencies = [ "libc", "libgpiod-sys", "thiserror", - "vmm-sys-util", + "vmm-sys-util 0.10.0", ] [[package]] @@ -181,6 +324,12 @@ dependencies = [ "cc", ] +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "log" version = "0.4.17" @@ -196,6 +345,16 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "num_cpus" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.15.0" @@ -208,6 +367,18 @@ version = "6.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff" +[[package]] +name = "pin-project-lite" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "ppv-lite86" version = "0.2.16" @@ -321,6 +492,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "slab" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" +dependencies = [ + "autocfg", +] + [[package]] name = "strsim" version = "0.10.0" @@ -361,6 +541,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "textwrap" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "949517c0cf1bf4ee812e2e07e08ab448e3ae0d23472aee8a06c985f0c8815b16" + [[package]] name = "thiserror" version = "1.0.37" @@ -393,6 +579,18 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "vhost" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b56bf8f178fc500fe14505fca8b00dec76fc38f2304f461c8d9d7547982311d" +dependencies = [ + "bitflags", + "libc", + "vm-memory 0.9.0", + "vmm-sys-util 0.10.0", +] + [[package]] name = "vhost" version = "0.5.0" @@ -401,50 +599,50 @@ checksum = "79243657c76e5c90dcbf60187c842614f6dfc7123972c55bb3bcc446792aca93" dependencies = [ "bitflags", "libc", - "vm-memory", - "vmm-sys-util", + "vm-memory 0.9.0", + "vmm-sys-util 0.10.0", ] [[package]] name = "vhost-device-gpio" version = "0.1.0" dependencies = [ - "clap", + "clap 4.0.11", "env_logger", "libc", "libgpiod", "log", "thiserror", - "vhost", - "vhost-user-backend", + "vhost 0.5.0", + "vhost-user-backend 0.7.0", "virtio-bindings", - "virtio-queue", - "vm-memory", - "vmm-sys-util", + "virtio-queue 0.6.1", + "vm-memory 0.9.0", + "vmm-sys-util 0.10.0", ] [[package]] name = "vhost-device-i2c" version = "0.1.0" dependencies = [ - "clap", + "clap 4.0.11", "env_logger", "libc", "log", "thiserror", - "vhost", - "vhost-user-backend", + "vhost 0.5.0", + "vhost-user-backend 0.7.0", "virtio-bindings", - "virtio-queue", - "vm-memory", - "vmm-sys-util", + "virtio-queue 0.6.1", + "vm-memory 0.9.0", + "vmm-sys-util 0.10.0", ] [[package]] name = "vhost-device-rng" version = "0.1.0" dependencies = [ - "clap", + "clap 4.0.11", "env_logger", "epoll", "libc", @@ -452,12 +650,27 @@ dependencies = [ "rand", "tempfile", "thiserror", - "vhost", - "vhost-user-backend", + "vhost 0.5.0", + "vhost-user-backend 0.7.0", "virtio-bindings", - "virtio-queue", - "vm-memory", - "vmm-sys-util", + "virtio-queue 0.6.1", + "vm-memory 0.9.0", + "vmm-sys-util 0.10.0", +] + +[[package]] +name = "vhost-user-backend" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8db00e93514caa8987bb8b536fe962c9b66b4068583abc4c531eb97988477cd" +dependencies = [ + "libc", + "log", + "vhost 0.3.0", + "virtio-bindings", + "virtio-queue 0.1.0", + "vm-memory 0.7.0", + "vmm-sys-util 0.9.0", ] [[package]] @@ -468,11 +681,29 @@ checksum = "6a0fc7d5f8e2943cd9f2ecd58be3f2078add863a49573d14dd9d64e1ab26544c" dependencies = [ "libc", "log", - "vhost", + "vhost 0.5.0", "virtio-bindings", - "virtio-queue", - "vm-memory", - "vmm-sys-util", + "virtio-queue 0.6.1", + "vm-memory 0.9.0", + "vmm-sys-util 0.10.0", +] + +[[package]] +name = "vhost-user-vsock" +version = "0.1.0" +dependencies = [ + "byteorder", + "clap 3.2.22", + "epoll", + "futures", + "log", + "thiserror", + "vhost 0.3.0", + "vhost-user-backend 0.1.0", + "virtio-bindings", + "virtio-queue 0.1.0", + "vm-memory 0.7.0", + "vmm-sys-util 0.9.0", ] [[package]] @@ -481,6 +712,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ff512178285488516ed85f15b5d0113a7cdb89e9e8a760b269ae4f02b84bd6b" +[[package]] +name = "virtio-queue" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90da9e627f6aaf667cc7b6548a28be332d3e1f058f4ceeb46ab6bcee5c4b74d" +dependencies = [ + "log", + "vm-memory 0.7.0", + "vmm-sys-util 0.10.0", +] + [[package]] name = "virtio-queue" version = "0.6.1" @@ -489,8 +731,19 @@ checksum = "435dd49c7b38419729afd43675850c7b5dc4728f2fabd70c7a9079a331e4f8c6" dependencies = [ "log", "virtio-bindings", - "vm-memory", - "vmm-sys-util", + "vm-memory 0.9.0", + "vmm-sys-util 0.10.0", +] + +[[package]] +name = "vm-memory" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339d4349c126fdcd87e034631d7274370cf19eb0e87b33166bcd956589fc72c5" +dependencies = [ + "arc-swap", + "libc", + "winapi", ] [[package]] @@ -504,6 +757,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "vmm-sys-util" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "733537bded03aaa93543f785ae997727b30d1d9f4a03b7861d23290474242e11" +dependencies = [ + "bitflags", + "libc", +] + [[package]] name = "vmm-sys-util" version = "0.10.0" @@ -550,3 +813,12 @@ name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "yaml-rust" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85" +dependencies = [ + "linked-hash-map", +] diff --git a/Cargo.toml b/Cargo.toml index faee2f0..258a969 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,5 @@ members = [ "gpio", "i2c", "rng", + "vsock", ] diff --git a/README.md b/README.md index 1c59c98..aa4315b 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Here is the list of device backends that we support: - [GPIO](https://github.com/rust-vmm/vhost-device/blob/main/gpio/README.md) - [I2C](https://github.com/rust-vmm/vhost-device/blob/main/i2c/README.md) - [RNG](https://github.com/rust-vmm/vhost-device/blob/main/rng/README.md) +- [VSOCK](https://github.com/rust-vmm/vhost-device/blob/main/vsock/README.md) ## Testing and Code Coverage diff --git a/vsock/Cargo.toml b/vsock/Cargo.toml new file mode 100644 index 0000000..54b488b --- /dev/null +++ b/vsock/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "vhost-user-vsock" +version = "0.1.0" +authors = ["Harshavardhan Unnibhavi "] +description = "A virtio-vsock device using the vhost-user protocol." +repository = "https://github.com/rust-vmm/vhost-device" +readme = "README.md" +keywords = ["vhost", "vsock"] +license = "Apache-2.0 OR BSD-3-Clause" +edition = "2018" + +[dependencies] +byteorder = "1" +clap = { version = ">=3.0", features = ["yaml"] } +epoll = "4.3.1" +futures = { version = "0.3", features = ["thread-pool"] } +log = "0.4.14" +thiserror = "1.0" +vhost = { version = "0.3", features = ["vhost-user-slave"] } +vhost-user-backend = "0.1" +virtio-bindings = ">=0.1" +virtio-queue = "0.1" +vm-memory = "0.7" +vmm-sys-util = "=0.9.0" + +[dev-dependencies] +virtio-queue = { version = "0.1", features = ["test-utils"] } diff --git a/vsock/README.md b/vsock/README.md new file mode 100644 index 0000000..2847b65 --- /dev/null +++ b/vsock/README.md @@ -0,0 +1,108 @@ +# vhost-user-vsock + +## Design + +The crate introduces a vhost-user-vsock device that enables communication between an +application running in the guest i.e inside a VM and an application running on the +host i.e outside the VM. The application running in the guest communicates over VM +sockets i.e over AF_VSOCK sockets. The application running on the host connects to a +unix socket on the host i.e communicates over AF_UNIX sockets. The main components of +the crate are split into various files as described below: + +- [packet.rs](src/packet.rs) + - Introduces the **VsockPacket** structure that represents a single vsock packet + processing methods. +- [rxops.rs](src/rxops.rs) + - Introduces various vsock operations that are enqueued into the rxqueue to be sent to the + guest. Exposes a **RxOps** structure. +- [rxqueue.rs](src/rxqueue.rs) + - rxqueue contains the pending rx operations corresponding to that connection. The queue is + represented as a bitmap as we handle connection-oriented connections. The module contains + various queue manipulation methods. Exposes a **RxQueue** structure. +- [thread_backend.rs](src/thread_backend.rs) + - Multiplexes connections between host and guest and calls into per connection methods that + are responsible for processing data and packets corresponding to the connection. Exposes a + **VsockThreadBackend** structure. +- [txbuf.rs](src/txbuf.rs) + - Module to buffer data that is sent from the guest to the host. The module exposes a **LocalTxBuf** + structure. +- [vhost_user_vsock_thread.rs](src/vhost_user_vsock_thread.rs) + - Module exposes a **VhostUserVsockThread** structure. It also handles new host initiated + connections and provides interfaces for registering host connections with the epoll fd. Also + provides interfaces for iterating through the rx and tx queues. +- [vsock_conn.rs](src/vsock_conn.rs) + - Module introduces a **VsockConnection** structure that represents a single vsock connection + between the guest and the host. It also processes packets according to their type. +- [vhu_vsock.rs](src/lib.rs) + - exposes the main vhost user vsock backend interface. + +## Usage + +Run the vhost-user-vsock device: +``` +vhost-user-vsock --guest-cid=4 --uds-path=/tmp/vm4.vsock --socket=/tmp/vhost4.socket +``` + +Run qemu: + +``` +qemu-system-x86_64 -drive file=/path/to/disk.qcow2 -enable-kvm -m 512M \ + -smp 2 -vga virtio -chardev socket,id=char0,reconnect=0,path=/tmp/vhost4.socket \ + -device vhost-user-vsock-pci,chardev=char0 \ + -object memory-backend-file,share=on,id=mem,size="512M",mem-path="/dev/hugepages" \ + -numa node,memdev=mem -mem-prealloc +``` + +### Guest listening + +#### iperf + +```sh +# https://github.com/stefano-garzarella/iperf-vsock +guest$ iperf3 --vsock -s +host$ iperf3 --vsock -c /tmp/vm4.vsock +``` + +#### netcat + +```sh +guest$ nc --vsock -l 1234 + +host$ nc -U /tmp/vm4.vsock +CONNECT 1234 +``` + +### Host listening + +#### iperf + +```sh +# https://github.com/stefano-garzarella/iperf-vsock +host$ iperf3 --vsock -s -B /tmp/vm4.vsock +guest$ iperf3 --vsock -c 2 +``` + +#### netcat + +```sh +host$ nc -l -U /tmp/vm4.vsock_1234 + +guest$ nc --vsock 2 1234 +``` + +```rust +use my_crate; + +... +``` + +## License + +**!!!NOTICE**: The BSD-3-Clause license is not included in this template. +The license needs to be manually added because the text of the license file +also includes the copyright. The copyright can be different for different +crates. If the crate contains code from CrosVM, the crate must add the +CrosVM copyright which can be found +[here](https://chromium.googlesource.com/chromiumos/platform/crosvm/+/master/LICENSE). +For crates developed from scratch, the copyright is different and depends on +the contributors. diff --git a/vsock/src/cli.yaml b/vsock/src/cli.yaml new file mode 100644 index 0000000..ef172a3 --- /dev/null +++ b/vsock/src/cli.yaml @@ -0,0 +1,36 @@ +name: vhost-device-vsock +version: "0.1.0" +author: "Harshavardhan Unnibhavi " +about: Virtio VSOCK backend daemon. + +settings: + - ArgRequiredElseHelp + +args: + # CID of the guest + - guest_cid: + long: guest-cid + value_name: INT + takes_value: true + help: Context identifier of the guest which uniquely identifies the device for its lifetime. Defaults to 3 if not specified. + # Socket paths + - uds_path: + long: uds-path + value_name: FILE + takes_value: true + help: Unix socket to which a host-side application connects to. + - socket: + long: socket + value_name: FILE + takes_value: true + help: Unix socket to which a hypervisor conencts to and sets up the control path with the device. + +groups: + - required_args: + args: + - guest_cid + args: + - uds_path + args: + - socket + required: true diff --git a/vsock/src/main.rs b/vsock/src/main.rs new file mode 100644 index 0000000..cc3b8b1 --- /dev/null +++ b/vsock/src/main.rs @@ -0,0 +1,87 @@ +mod packet; +mod rxops; +mod rxqueue; +mod thread_backend; +mod txbuf; +mod vhu_vsock; +mod vhu_vsock_thread; +mod vsock_conn; + +use clap::{load_yaml, App}; +use std::{ + convert::TryFrom, + process, + sync::{Arc, RwLock}, +}; +use vhost::{vhost_user, vhost_user::Listener}; +use vhost_user_backend::VhostUserDaemon; +use vhu_vsock::{VhostUserVsockBackend, VsockConfig}; +use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap}; + +/// This is the public API through which an external program starts the +/// vhost-user-vsock backend server. +pub(crate) fn start_backend_server(vsock_config: VsockConfig) { + loop { + let vsock_backend = Arc::new(RwLock::new( + VhostUserVsockBackend::new(vsock_config.clone()).unwrap(), + )); + + let listener = Listener::new(vsock_config.get_socket_path(), true).unwrap(); + + let mut vsock_daemon = VhostUserDaemon::new( + String::from("vhost-user-vsock"), + vsock_backend.clone(), + GuestMemoryAtomic::new(GuestMemoryMmap::new()), + ) + .unwrap(); + + let mut vring_workers = vsock_daemon.get_epoll_handlers(); + + if vring_workers.len() != vsock_backend.read().unwrap().threads.len() { + println!("Number of vring workers must be identical to number of backend threads"); + } + + for thread in vsock_backend.read().unwrap().threads.iter() { + thread + .lock() + .unwrap() + .set_vring_worker(Some(vring_workers.remove(0))); + } + if let Err(e) = vsock_daemon.start(listener) { + dbg!("Failed to start vsock daemon: {:?}", e); + process::exit(1); + } + + match vsock_daemon.wait() { + Ok(()) => { + println!("Stopping cleanly"); + process::exit(0); + } + Err(vhost_user_backend::Error::HandleRequest(vhost_user::Error::PartialMessage)) => { + println!("vhost-user connection closed with partial message. If the VM is shutting down, this is expected behavior; otherwise, it might be a bug."); + continue; + } + Err(e) => { + println!("Error running daemon: {:?}", e); + } + } + + vsock_backend + .read() + .unwrap() + .exit_event + .write(1) + .expect("Shutting down worker thread"); + + println!("Vsock daemon is finished"); + } +} + +fn main() { + let yaml = load_yaml!("cli.yaml"); + let vsock_args = App::from_yaml(yaml).get_matches(); + + let vsock_config = VsockConfig::try_from(vsock_args).unwrap(); + + start_backend_server(vsock_config); +} diff --git a/vsock/src/packet.rs b/vsock/src/packet.rs new file mode 100644 index 0000000..a321faf --- /dev/null +++ b/vsock/src/packet.rs @@ -0,0 +1,594 @@ +#![deny(missing_docs)] +use byteorder::{ByteOrder, LittleEndian}; +use thiserror::Error as ThisError; +use virtio_queue::DescriptorChain; +use vm_memory::{ + GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic, GuestMemoryLoadGuard, + GuestMemoryMmap, +}; + +pub(crate) type Result = std::result::Result; + +/// Below enum defines custom error types for vsock packet operations. +#[derive(Debug, PartialEq, ThisError)] +pub(crate) enum Error { + #[error("Descriptor not writable")] + UnwritableDescriptor, + #[error("Missing descriptor in queue")] + QueueMissingDescriptor, + #[error("Small header descriptor: {0}")] + HdrDescTooSmall(u32), + #[error("Chained guest memory error")] + GuestMemory, + #[error("Descriptor not readable")] + UnreadableDescriptor, + #[error("Extra descriptors in the descriptor chain")] + ExtraDescrInChain, + #[error("Data buffer size less than size in packet header")] + DataDescTooSmall, +} + +// TODO: Replace below with bindgen generated struct +// vsock packet header size when packed +pub const VSOCK_PKT_HDR_SIZE: usize = 44; + +// Offset into header for source cid +const HDROFF_SRC_CID: usize = 0; + +// Offset into header for destination cid +const HDROFF_DST_CID: usize = 8; + +// Offset into header for source port +const HDROFF_SRC_PORT: usize = 16; + +// Offset into header for destination port +const HDROFF_DST_PORT: usize = 20; + +// Offset into the header for data length +const HDROFF_LEN: usize = 24; + +// Offset into header for packet type +const HDROFF_TYPE: usize = 28; + +// Offset into header for operation kind +const HDROFF_OP: usize = 30; + +// Offset into header for additional flags +// only for VSOCK_OP_SHUTDOWN +const HDROFF_FLAGS: usize = 32; + +// Offset into header for tx buf alloc +const HDROFF_BUF_ALLOC: usize = 36; + +// Offset into header for forward count +const HDROFF_FWD_CNT: usize = 40; + +/// Vsock packet structure implemented as a wrapper around a virtq descriptor chain: +/// - chain head holds the packet header +/// - optional data descriptor, only present for data packets (VSOCK_OP_RW) +#[derive(Debug)] +pub struct VsockPacket { + hdr: *mut u8, + buf: Option<*mut u8>, + buf_size: usize, +} + +impl VsockPacket { + /// Create a vsock packet wrapper around a chain in the rx virtqueue. + /// Perform bounds checking before creating the wrapper. + pub(crate) fn from_rx_virtq_head( + chain: &mut DescriptorChain>>, + mem: GuestMemoryAtomic, + ) -> Result { + // head is at 0, next is at 1, max of two descriptors + // head contains the packet header + // next contains the optional packet data + let mut descr_vec = Vec::with_capacity(2); + + for descr in chain { + if !descr.is_write_only() { + return Err(Error::UnwritableDescriptor); + } + + descr_vec.push(descr); + } + + if descr_vec.len() < 2 { + // We expect a head and a data descriptor + return Err(Error::QueueMissingDescriptor); + } + + let head_descr = descr_vec[0]; + let data_descr = descr_vec[1]; + + if head_descr.len() < VSOCK_PKT_HDR_SIZE as u32 { + return Err(Error::HdrDescTooSmall(head_descr.len())); + } + + Ok(Self { + hdr: VsockPacket::guest_to_host_address( + &mem.memory(), + head_descr.addr(), + VSOCK_PKT_HDR_SIZE, + ) + .ok_or(Error::GuestMemory)? as *mut u8, + buf: Some( + VsockPacket::guest_to_host_address( + &mem.memory(), + data_descr.addr(), + data_descr.len() as usize, + ) + .ok_or(Error::GuestMemory)? as *mut u8, + ), + buf_size: data_descr.len() as usize, + }) + } + + /// Create a vsock packet wrapper around a chain in the tx virtqueue + /// Bounds checking before creating the wrapper. + pub(crate) fn from_tx_virtq_head( + chain: &mut DescriptorChain>>, + mem: GuestMemoryAtomic, + ) -> Result { + // head is at 0, next is at 1, max of two descriptors + // head contains the packet header + // next contains the optional packet data + let mut descr_vec = Vec::with_capacity(2); + // let mut num_descr = 0; + + for descr in chain { + if descr.is_write_only() { + return Err(Error::UnreadableDescriptor); + } + + descr_vec.push(descr); + } + + if descr_vec.len() > 2 { + return Err(Error::ExtraDescrInChain); + } + + let head_descr = descr_vec[0]; + + if head_descr.len() < VSOCK_PKT_HDR_SIZE as u32 { + return Err(Error::HdrDescTooSmall(head_descr.len())); + } + + let mut pkt = Self { + hdr: VsockPacket::guest_to_host_address( + &mem.memory(), + head_descr.addr(), + VSOCK_PKT_HDR_SIZE, + ) + .ok_or(Error::GuestMemory)? as *mut u8, + buf: None, + buf_size: 0, + }; + + // Zero length packet + if pkt.is_empty() { + return Ok(pkt); + } + + // There exists packet data as well + let data_descr = descr_vec[1]; + + // Data buffer should be as large as described in the header + if data_descr.len() < pkt.len() { + return Err(Error::DataDescTooSmall); + } + + pkt.buf_size = data_descr.len() as usize; + pkt.buf = Some( + VsockPacket::guest_to_host_address( + &mem.memory(), + data_descr.addr(), + data_descr.len() as usize, + ) + .ok_or(Error::GuestMemory)? as *mut u8, + ); + + Ok(pkt) + } + + /// Convert an absolute address in guest address space to a host + /// pointer and verify that the provided size defines a valid + /// range within a single memory region. + fn guest_to_host_address( + mem: &GuestMemoryLoadGuard, + addr: GuestAddress, + size: usize, + ) -> Option<*mut u8> { + if mem.check_range(addr, size) { + Some(mem.get_host_address(addr).unwrap()) + } else { + None + } + } + + /// In place byte slice access to vsock packet header. + pub fn hdr(&self) -> &[u8] { + // Safe as bound checks performed in from_*_virtq_head + unsafe { std::slice::from_raw_parts(self.hdr as *const u8, VSOCK_PKT_HDR_SIZE) } + } + + /// In place mutable slice access to vsock packet header. + pub fn hdr_mut(&mut self) -> &mut [u8] { + // Safe as bound checks performed in from_*_virtq_head + unsafe { std::slice::from_raw_parts_mut(self.hdr, VSOCK_PKT_HDR_SIZE) } + } + + /// Size of vsock packet data, found by accessing len field + /// of virtio_vsock_hdr struct. + pub fn len(&self) -> u32 { + LittleEndian::read_u32(&self.hdr()[HDROFF_LEN..]) + } + + /// Set the source cid. + pub fn set_src_cid(&mut self, cid: u64) -> &mut Self { + LittleEndian::write_u64(&mut self.hdr_mut()[HDROFF_SRC_CID..], cid); + self + } + + /// Set the destination cid. + pub fn set_dst_cid(&mut self, cid: u64) -> &mut Self { + LittleEndian::write_u64(&mut self.hdr_mut()[HDROFF_DST_CID..], cid); + self + } + + /// Set source port. + pub fn set_src_port(&mut self, port: u32) -> &mut Self { + LittleEndian::write_u32(&mut self.hdr_mut()[HDROFF_SRC_PORT..], port); + self + } + + /// Set destination port. + pub fn set_dst_port(&mut self, port: u32) -> &mut Self { + LittleEndian::write_u32(&mut self.hdr_mut()[HDROFF_DST_PORT..], port); + self + } + + /// Set type of connection. + pub fn set_type(&mut self, type_: u16) -> &mut Self { + LittleEndian::write_u16(&mut self.hdr_mut()[HDROFF_TYPE..], type_); + self + } + + /// Set size of tx buf. + pub fn set_buf_alloc(&mut self, buf_alloc: u32) -> &mut Self { + LittleEndian::write_u32(&mut self.hdr_mut()[HDROFF_BUF_ALLOC..], buf_alloc); + self + } + + /// Set amount of tx buf data written to stream. + pub fn set_fwd_cnt(&mut self, fwd_cnt: u32) -> &mut Self { + LittleEndian::write_u32(&mut self.hdr_mut()[HDROFF_FWD_CNT..], fwd_cnt); + self + } + + /// Set packet operation ID. + pub fn set_op(&mut self, op: u16) -> &mut Self { + LittleEndian::write_u16(&mut self.hdr_mut()[HDROFF_OP..], op); + self + } + + /// Check if the packet has no data. + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Get destination port from packet. + pub fn dst_port(&self) -> u32 { + LittleEndian::read_u32(&self.hdr()[HDROFF_DST_PORT..]) + } + + /// Get source port from packet. + pub fn src_port(&self) -> u32 { + LittleEndian::read_u32(&self.hdr()[HDROFF_SRC_PORT..]) + } + + /// Get source cid from packet. + pub fn src_cid(&self) -> u64 { + LittleEndian::read_u64(&self.hdr()[HDROFF_SRC_CID..]) + } + + /// Get destination cid from packet. + pub fn dst_cid(&self) -> u64 { + LittleEndian::read_u64(&self.hdr()[HDROFF_DST_CID..]) + } + + /// Get packet type. + pub fn pkt_type(&self) -> u16 { + LittleEndian::read_u16(&self.hdr()[HDROFF_TYPE..]) + } + + /// Get operation requested in the packet. + pub fn op(&self) -> u16 { + LittleEndian::read_u16(&self.hdr()[HDROFF_OP..]) + } + + /// Byte slice mutable access to vsock packet data buffer. + pub fn buf_mut(&mut self) -> Option<&mut [u8]> { + // Safe as bound checks performed while creating packet + self.buf + .map(|ptr| unsafe { std::slice::from_raw_parts_mut(ptr, self.buf_size) }) + } + + /// Byte slice access to vsock packet data buffer. + pub fn buf(&self) -> Option<&[u8]> { + // Safe as bound checks performed while creating packet + self.buf + .map(|ptr| unsafe { std::slice::from_raw_parts(ptr as *const u8, self.buf_size) }) + } + + /// Set data buffer length. + pub fn set_len(&mut self, len: u32) -> &mut Self { + LittleEndian::write_u32(&mut self.hdr_mut()[HDROFF_LEN..], len); + self + } + + /// Read buf alloc. + pub fn buf_alloc(&self) -> u32 { + LittleEndian::read_u32(&self.hdr()[HDROFF_BUF_ALLOC..]) + } + + /// Get fwd cnt from packet header. + pub fn fwd_cnt(&self) -> u32 { + LittleEndian::read_u32(&self.hdr()[HDROFF_FWD_CNT..]) + } + + /// Read flags from the packet header. + pub fn flags(&self) -> u32 { + LittleEndian::read_u32(&self.hdr()[HDROFF_FLAGS..]) + } + + /// Set packet header flag to flags. + pub fn set_flags(&mut self, flags: u32) -> &mut Self { + LittleEndian::write_u32(&mut self.hdr_mut()[HDROFF_FLAGS..], flags); + self + } + + /// Set OP specific flags. + pub fn set_flag(&mut self, flag: u32) -> &mut Self { + self.set_flags(self.flags() | flag); + self + } +} + +#[cfg(test)] +pub mod tests { + use crate::vhu_vsock::{VSOCK_OP_RW, VSOCK_TYPE_STREAM}; + + use super::*; + use virtio_queue::{ + defs::{VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE}, + mock::MockSplitQueue, + Descriptor, + }; + use vm_memory::{Address, Bytes, GuestAddress, GuestMemoryAtomic, GuestMemoryMmap}; + + pub struct HeadParams { + head_len: usize, + data_len: u32, + } + + impl HeadParams { + pub fn new(head_len: usize, data_len: u32) -> Self { + Self { head_len, data_len } + } + fn construct_head(&self) -> Vec { + let mut header = vec![0_u8; self.head_len]; + if self.head_len == VSOCK_PKT_HDR_SIZE { + LittleEndian::write_u32(&mut header[HDROFF_LEN..], self.data_len); + } + header + } + } + + pub fn prepare_desc_chain_vsock( + write_only: bool, + head_params: &HeadParams, + data_chain_len: u16, + head_data_len: u32, + ) -> ( + GuestMemoryAtomic, + DescriptorChain>, + ) { + let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x1000)]).unwrap(); + let virt_queue = MockSplitQueue::new(&mem, 16); + let mut next_addr = virt_queue.desc_table().total_size() + 0x100; + let mut flags = 0; + let mut head_flags; + + if write_only { + flags |= VIRTQ_DESC_F_WRITE; + } + + if data_chain_len > 0 { + head_flags = flags | VIRTQ_DESC_F_NEXT + } else { + head_flags = flags; + } + + // vsock packet header + // let header = vec![0 as u8; head_params.head_len]; + let header = head_params.construct_head(); + let head_desc = Descriptor::new(next_addr, head_params.head_len as u32, head_flags, 1); + mem.write(&header, head_desc.addr()).unwrap(); + virt_queue.desc_table().store(0, head_desc); + next_addr += head_params.head_len as u64; + + // Put the descriptor index 0 in the first available ring position. + mem.write_obj(0u16, virt_queue.avail_addr().unchecked_add(4)) + .unwrap(); + + // Set `avail_idx` to 1. + mem.write_obj(1u16, virt_queue.avail_addr().unchecked_add(2)) + .unwrap(); + + // chain len excludes the head + for i in 0..(data_chain_len) { + // last descr in chain + if i == data_chain_len - 1 { + head_flags &= !VIRTQ_DESC_F_NEXT; + } + // vsock data + let data = vec![0_u8; head_data_len as usize]; + let data_desc = Descriptor::new(next_addr, data.len() as u32, head_flags, i + 2); + mem.write(&data, data_desc.addr()).unwrap(); + virt_queue.desc_table().store(i + 1, data_desc); + next_addr += head_data_len as u64; + } + + // Create descriptor chain from pre-filled memory + ( + GuestMemoryAtomic::new(mem.clone()), + virt_queue + .create_queue(GuestMemoryAtomic::::new(mem.clone())) + .iter() + .unwrap() + .next() + .unwrap(), + ) + } + + #[test] + fn test_guest_to_host_address() { + let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 1000)]).unwrap(), + ); + assert!(VsockPacket::guest_to_host_address(&mem.memory(), GuestAddress(0), 1000).is_some()); + assert!(VsockPacket::guest_to_host_address(&mem.memory(), GuestAddress(0), 500).is_some()); + assert!( + VsockPacket::guest_to_host_address(&mem.memory(), GuestAddress(500), 500).is_some() + ); + assert!( + VsockPacket::guest_to_host_address(&mem.memory(), GuestAddress(501), 500).is_none() + ); + } + + #[test] + fn test_from_rx_virtq_head() { + // parameters for packet head construction + let head_params = HeadParams::new(VSOCK_PKT_HDR_SIZE, 10); + + // write only descriptor chain + let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 2, 10); + assert!(VsockPacket::from_rx_virtq_head(&mut descr_chain, mem).is_ok()); + + // read only descriptor chain + let (mem, mut descr_chain) = prepare_desc_chain_vsock(false, &head_params, 1, 10); + assert_eq!( + VsockPacket::from_rx_virtq_head(&mut descr_chain, mem).unwrap_err(), + Error::UnwritableDescriptor + ); + + // less than two descriptors + let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 0, 10); + assert_eq!( + VsockPacket::from_rx_virtq_head(&mut descr_chain, mem).unwrap_err(), + Error::QueueMissingDescriptor + ); + + // incorrect header length + let head_params = HeadParams::new(22, 10); + let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 1, 10); + assert_eq!( + VsockPacket::from_rx_virtq_head(&mut descr_chain, mem).unwrap_err(), + Error::HdrDescTooSmall(22) + ); + } + + #[test] + fn test_vsock_packet_header_ops() { + // parameters for head construction + let head_params = HeadParams::new(VSOCK_PKT_HDR_SIZE, 10); + + let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 2, 10); + let mut vsock_packet = VsockPacket::from_rx_virtq_head(&mut descr_chain, mem).unwrap(); + + // Check packet data's length + assert!(!vsock_packet.is_empty()); + assert_eq!(vsock_packet.len(), 10); + + // Set and get the source CID in the packet header + vsock_packet.set_src_cid(1); + assert_eq!(vsock_packet.src_cid(), 1); + + // Set and get the destination CID in the packet header + vsock_packet.set_dst_cid(1); + assert_eq!(vsock_packet.dst_cid(), 1); + + // Set and get the source port in the packet header + vsock_packet.set_src_port(5000); + assert_eq!(vsock_packet.src_port(), 5000); + + // Set and get the destination port in the packet header + vsock_packet.set_dst_port(5000); + assert_eq!(vsock_packet.dst_port(), 5000); + + // Set and get packet type + vsock_packet.set_type(VSOCK_TYPE_STREAM); + assert_eq!(vsock_packet.pkt_type(), VSOCK_TYPE_STREAM); + + // Set and get tx buffer size + vsock_packet.set_buf_alloc(10); + assert_eq!(vsock_packet.buf_alloc(), 10); + + // Set and get fwd_cnt of packet's data + vsock_packet.set_fwd_cnt(100); + assert_eq!(vsock_packet.fwd_cnt(), 100); + + // Set and get packet operation type + vsock_packet.set_op(VSOCK_OP_RW); + assert_eq!(vsock_packet.op(), VSOCK_OP_RW); + + // Set and get length of packet's data buffer + // this is a dummy test + vsock_packet.set_len(20); + assert_eq!(vsock_packet.len(), 20); + assert!(!vsock_packet.is_empty()); + + // Set and get packet's flags + vsock_packet.set_flags(1); + assert_eq!(vsock_packet.flags(), 1); + } + + #[test] + fn test_from_tx_virtq_head() { + // parameters for head construction + let head_params = HeadParams::new(VSOCK_PKT_HDR_SIZE, 0); + + // read only descriptor chain no data + let (mem, mut descr_chain) = prepare_desc_chain_vsock(false, &head_params, 0, 0); + assert!(VsockPacket::from_tx_virtq_head(&mut descr_chain, mem).is_ok()); + + // parameters for head construction + let head_params = HeadParams::new(VSOCK_PKT_HDR_SIZE, 10); + + // read only descriptor chain + let (mem, mut descr_chain) = prepare_desc_chain_vsock(false, &head_params, 1, 10); + assert!(VsockPacket::from_tx_virtq_head(&mut descr_chain, mem).is_ok()); + + // write only descriptor chain + let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 1, 10); + assert_eq!( + VsockPacket::from_tx_virtq_head(&mut descr_chain, mem).unwrap_err(), + Error::UnreadableDescriptor + ); + + // more than 2 descriptors in chain + let (mem, mut descr_chain) = prepare_desc_chain_vsock(false, &head_params, 2, 10); + assert_eq!( + VsockPacket::from_tx_virtq_head(&mut descr_chain, mem).unwrap_err(), + Error::ExtraDescrInChain + ); + + // length of data descriptor does not match the value in head + let (mem, mut descr_chain) = prepare_desc_chain_vsock(false, &head_params, 1, 5); + assert_eq!( + VsockPacket::from_tx_virtq_head(&mut descr_chain, mem).unwrap_err(), + Error::DataDescTooSmall + ); + } +} diff --git a/vsock/src/rxops.rs b/vsock/src/rxops.rs new file mode 100644 index 0000000..b79e62b --- /dev/null +++ b/vsock/src/rxops.rs @@ -0,0 +1,34 @@ +#[derive(Clone, Copy, PartialEq, Debug)] +pub enum RxOps { + /// VSOCK_OP_REQUEST + Request = 0, + /// VSOCK_OP_RW + Rw = 1, + /// VSOCK_OP_RESPONSE + Response = 2, + /// VSOCK_OP_CREDIT_UPDATE + CreditUpdate = 3, + /// VSOCK_OP_RST + Reset = 4, +} + +impl RxOps { + /// Convert enum value into bitmask. + pub fn bitmask(self) -> u8 { + 1u8 << (self as u8) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bitmask() { + assert_eq!(1, RxOps::Request.bitmask()); + assert_eq!(2, RxOps::Rw.bitmask()); + assert_eq!(4, RxOps::Response.bitmask()); + assert_eq!(8, RxOps::CreditUpdate.bitmask()); + assert_eq!(16, RxOps::Reset.bitmask()); + } +} diff --git a/vsock/src/rxqueue.rs b/vsock/src/rxqueue.rs new file mode 100644 index 0000000..9b033b9 --- /dev/null +++ b/vsock/src/rxqueue.rs @@ -0,0 +1,155 @@ +use super::rxops::RxOps; + +#[derive(Debug, PartialEq)] +pub struct RxQueue { + /// Bitmap of rx operations. + queue: u8, +} + +impl RxQueue { + /// New instance of RxQueue. + pub fn new() -> Self { + RxQueue { queue: 0_u8 } + } + + /// Enqueue a new rx operation into the queue. + pub fn enqueue(&mut self, op: RxOps) { + self.queue |= op.bitmask(); + } + + /// Dequeue an rx operation from the queue. + pub fn dequeue(&mut self) -> Option { + match self.peek() { + Some(req) => { + self.queue &= !req.bitmask(); + Some(req) + } + None => None, + } + } + + /// Peek into the queue to check if it contains an rx operation. + pub fn peek(&self) -> Option { + if self.contains(RxOps::Request.bitmask()) { + return Some(RxOps::Request); + } + if self.contains(RxOps::Rw.bitmask()) { + return Some(RxOps::Rw); + } + if self.contains(RxOps::Response.bitmask()) { + return Some(RxOps::Response); + } + if self.contains(RxOps::CreditUpdate.bitmask()) { + return Some(RxOps::CreditUpdate); + } + if self.contains(RxOps::Reset.bitmask()) { + Some(RxOps::Reset) + } else { + None + } + } + + /// Check if the queue contains a particular rx operation. + pub fn contains(&self, op: u8) -> bool { + (self.queue & op) != 0 + } + + /// Check if there are any pending rx operations in the queue. + pub fn pending_rx(&self) -> bool { + self.queue != 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_contains() { + let mut rxqueue = RxQueue::new(); + rxqueue.queue = 31; + + assert!(rxqueue.contains(RxOps::Request.bitmask())); + assert!(rxqueue.contains(RxOps::Rw.bitmask())); + assert!(rxqueue.contains(RxOps::Response.bitmask())); + assert!(rxqueue.contains(RxOps::CreditUpdate.bitmask())); + assert!(rxqueue.contains(RxOps::Reset.bitmask())); + + rxqueue.queue = 0; + assert!(!rxqueue.contains(RxOps::Request.bitmask())); + assert!(!rxqueue.contains(RxOps::Rw.bitmask())); + assert!(!rxqueue.contains(RxOps::Response.bitmask())); + assert!(!rxqueue.contains(RxOps::CreditUpdate.bitmask())); + assert!(!rxqueue.contains(RxOps::Reset.bitmask())); + } + + #[test] + fn test_enqueue() { + let mut rxqueue = RxQueue::new(); + + rxqueue.enqueue(RxOps::Request); + assert!(rxqueue.contains(RxOps::Request.bitmask())); + + rxqueue.enqueue(RxOps::Rw); + assert!(rxqueue.contains(RxOps::Rw.bitmask())); + + rxqueue.enqueue(RxOps::Response); + assert!(rxqueue.contains(RxOps::Response.bitmask())); + + rxqueue.enqueue(RxOps::CreditUpdate); + assert!(rxqueue.contains(RxOps::CreditUpdate.bitmask())); + + rxqueue.enqueue(RxOps::Reset); + assert!(rxqueue.contains(RxOps::Reset.bitmask())); + } + + #[test] + fn test_peek() { + let mut rxqueue = RxQueue::new(); + + rxqueue.queue = 31; + assert_eq!(rxqueue.peek(), Some(RxOps::Request)); + + rxqueue.queue = 30; + assert_eq!(rxqueue.peek(), Some(RxOps::Rw)); + + rxqueue.queue = 28; + assert_eq!(rxqueue.peek(), Some(RxOps::Response)); + + rxqueue.queue = 24; + assert_eq!(rxqueue.peek(), Some(RxOps::CreditUpdate)); + + rxqueue.queue = 16; + assert_eq!(rxqueue.peek(), Some(RxOps::Reset)); + } + + #[test] + fn test_dequeue() { + let mut rxqueue = RxQueue::new(); + rxqueue.queue = 31; + + assert_eq!(rxqueue.dequeue(), Some(RxOps::Request)); + assert!(!rxqueue.contains(RxOps::Request.bitmask())); + + assert_eq!(rxqueue.dequeue(), Some(RxOps::Rw)); + assert!(!rxqueue.contains(RxOps::Rw.bitmask())); + + assert_eq!(rxqueue.dequeue(), Some(RxOps::Response)); + assert!(!rxqueue.contains(RxOps::Response.bitmask())); + + assert_eq!(rxqueue.dequeue(), Some(RxOps::CreditUpdate)); + assert!(!rxqueue.contains(RxOps::CreditUpdate.bitmask())); + + assert_eq!(rxqueue.dequeue(), Some(RxOps::Reset)); + assert!(!rxqueue.contains(RxOps::Reset.bitmask())); + } + + #[test] + fn test_pending_rx() { + let mut rxqueue = RxQueue::new(); + assert!(!rxqueue.pending_rx()); + + rxqueue.queue = 1; + assert!(rxqueue.pending_rx()); + } +} diff --git a/vsock/src/thread_backend.rs b/vsock/src/thread_backend.rs new file mode 100644 index 0000000..18c5e37 --- /dev/null +++ b/vsock/src/thread_backend.rs @@ -0,0 +1,236 @@ +#![deny(missing_docs)] + +use super::{ + packet::*, + rxops::*, + vhu_vsock::{ + ConnMapKey, Error, Result, VSOCK_HOST_CID, VSOCK_OP_REQUEST, VSOCK_OP_RST, + VSOCK_TYPE_STREAM, + }, + vhu_vsock_thread::VhostUserVsockThread, + vsock_conn::*, +}; +use log::{info, warn}; +use std::{ + collections::{HashMap, HashSet, VecDeque}, + os::unix::{ + net::UnixStream, + prelude::{AsRawFd, FromRawFd, RawFd}, + }, +}; + +// TODO: convert UnixStream to Arc> +pub struct VsockThreadBackend { + /// Map of ConnMapKey objects indexed by raw file descriptors. + pub listener_map: HashMap, + /// Map of vsock connection objects indexed by ConnMapKey objects. + pub conn_map: HashMap>, + /// Queue of ConnMapKey objects indicating pending rx operations. + pub backend_rxq: VecDeque, + /// Map of host-side unix streams indexed by raw file descriptors. + pub stream_map: HashMap, + /// Host side socket for listening to new connections from the host. + host_socket_path: String, + /// epoll for registering new host-side connections. + epoll_fd: i32, + /// Set of allocated local ports. + pub local_port_set: HashSet, +} + +impl VsockThreadBackend { + /// New instance of VsockThreadBackend. + pub fn new(host_socket_path: String, epoll_fd: i32) -> Self { + Self { + listener_map: HashMap::new(), + conn_map: HashMap::new(), + backend_rxq: VecDeque::new(), + // Need this map to prevent connected stream from closing + // TODO: think of a better solution + stream_map: HashMap::new(), + host_socket_path, + epoll_fd, + local_port_set: HashSet::new(), + } + } + + /// Checks if there are pending rx requests in the backend rxq. + pub fn pending_rx(&self) -> bool { + !self.backend_rxq.is_empty() + } + + /// Deliver a vsock packet to the guest vsock driver. + /// + /// Returns: + /// - `Ok(())` if the packet was successfully filled in + /// - `Err(Error::EmptyBackendRxQ) if there was no available data + pub(crate) fn recv_pkt(&mut self, pkt: &mut VsockPacket) -> Result<()> { + // Pop an event from the backend_rxq + let key = self.backend_rxq.pop_front().ok_or(Error::EmptyBackendRxQ)?; + let conn = match self.conn_map.get_mut(&key) { + Some(conn) => conn, + None => { + // assume that the connection does not exist + return Ok(()); + } + }; + + if conn.rx_queue.peek() == Some(RxOps::Reset) { + // Handle RST events here + let conn = self.conn_map.remove(&key).unwrap(); + self.listener_map.remove(&conn.stream.as_raw_fd()); + self.stream_map.remove(&conn.stream.as_raw_fd()); + self.local_port_set.remove(&conn.local_port); + VhostUserVsockThread::epoll_unregister(conn.epoll_fd, conn.stream.as_raw_fd()) + .unwrap_or_else(|err| { + warn!( + "Could not remove epoll listener for fd {:?}: {:?}", + conn.stream.as_raw_fd(), + err + ) + }); + + // Initialize the packet header to contain a VSOCK_OP_RST operation + pkt.set_op(VSOCK_OP_RST) + .set_src_cid(VSOCK_HOST_CID) + .set_dst_cid(conn.guest_cid) + .set_src_port(conn.local_port) + .set_dst_port(conn.peer_port) + .set_len(0) + .set_type(VSOCK_TYPE_STREAM) + .set_flags(0) + .set_buf_alloc(0) + .set_fwd_cnt(0); + + return Ok(()); + } + + // Handle other packet types per connection + conn.recv_pkt(pkt)?; + + Ok(()) + } + + /// Deliver a guest generated packet to its destination in the backend. + /// + /// Absorbs unexpected packets, handles rest to respective connection + /// object. + /// + /// Returns: + /// - always `Ok(())` if packet has been consumed correctly + pub(crate) fn send_pkt(&mut self, pkt: &VsockPacket) -> Result<()> { + let key = ConnMapKey::new(pkt.dst_port(), pkt.src_port()); + + // TODO: Rst if packet has unsupported type + if pkt.pkt_type() != VSOCK_TYPE_STREAM { + info!("vsock: dropping packet of unknown type"); + return Ok(()); + } + + // TODO: Handle packets to other CIDs as well + if pkt.dst_cid() != VSOCK_HOST_CID { + info!( + "vsock: dropping packet for cid other than host: {:?}", + pkt.hdr() + ); + + return Ok(()); + } + + // TODO: Handle cases where connection does not exist and packet op + // is not VSOCK_OP_REQUEST + if !self.conn_map.contains_key(&key) { + // The packet contains a new connection request + if pkt.op() == VSOCK_OP_REQUEST { + self.handle_new_guest_conn(&pkt); + } else { + // TODO: send back RST + } + return Ok(()); + } + + if pkt.op() == VSOCK_OP_RST { + // Handle an RST packet from the guest here + let conn = self.conn_map.get(&key).unwrap(); + if conn.rx_queue.contains(RxOps::Reset.bitmask()) { + return Ok(()); + } + let conn = self.conn_map.remove(&key).unwrap(); + self.listener_map.remove(&conn.stream.as_raw_fd()); + self.stream_map.remove(&conn.stream.as_raw_fd()); + self.local_port_set.remove(&conn.local_port); + VhostUserVsockThread::epoll_unregister(conn.epoll_fd, conn.stream.as_raw_fd()) + .unwrap_or_else(|err| { + warn!( + "Could not remove epoll listener for fd {:?}: {:?}", + conn.stream.as_raw_fd(), + err + ) + }); + return Ok(()); + } + + // Forward this packet to its listening connection + let conn = self.conn_map.get_mut(&key).unwrap(); + conn.send_pkt(pkt)?; + + if conn.rx_queue.pending_rx() { + // Required if the connection object adds new rx operations + self.backend_rxq.push_back(key); + } + + Ok(()) + } + + /// Handle a new guest initiated connection, i.e from the peer, the guest driver. + /// + /// Attempts to connect to a host side unix socket listening on a path + /// corresponding to the destination port as follows: + /// - "{self.host_sock_path}_{local_port}"" + fn handle_new_guest_conn(&mut self, pkt: &VsockPacket) { + let port_path = format!("{}_{}", self.host_socket_path, pkt.dst_port()); + + UnixStream::connect(port_path) + .and_then(|stream| stream.set_nonblocking(true).map(|_| stream)) + .map_err(Error::UnixConnect) + .and_then(|stream| self.add_new_guest_conn(stream, pkt)) + .unwrap_or_else(|_| self.enq_rst()); + } + + /// Wrapper to add new connection to relevant HashMaps. + fn add_new_guest_conn(&mut self, stream: UnixStream, pkt: &VsockPacket) -> Result<()> { + let stream_fd = stream.as_raw_fd(); + self.listener_map + .insert(stream_fd, ConnMapKey::new(pkt.dst_port(), pkt.src_port())); + + let vsock_conn = VsockConnection::new_peer_init( + stream, + pkt.dst_cid(), + pkt.dst_port(), + pkt.src_cid(), + pkt.src_port(), + self.epoll_fd, + pkt.buf_alloc(), + ); + + self.conn_map + .insert(ConnMapKey::new(pkt.dst_port(), pkt.src_port()), vsock_conn); + self.backend_rxq + .push_back(ConnMapKey::new(pkt.dst_port(), pkt.src_port())); + self.stream_map + .insert(stream_fd, unsafe { UnixStream::from_raw_fd(stream_fd) }); + self.local_port_set.insert(pkt.dst_port()); + + VhostUserVsockThread::epoll_register( + self.epoll_fd, + stream_fd, + epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT, + )?; + Ok(()) + } + + /// Enqueue RST packets to be sent to guest. + fn enq_rst(&mut self) { + // TODO + dbg!("New guest conn error: Enqueue RST"); + } +} diff --git a/vsock/src/txbuf.rs b/vsock/src/txbuf.rs new file mode 100644 index 0000000..c429585 --- /dev/null +++ b/vsock/src/txbuf.rs @@ -0,0 +1,210 @@ +use super::vhu_vsock::{Error, Result, CONN_TX_BUF_SIZE}; +use std::{io::Write, num::Wrapping}; + +#[derive(Debug)] +pub struct LocalTxBuf { + /// Buffer holding data to be forwarded to a host-side application + buf: Vec, + /// Index into buffer from which data can be consumed from the buffer + head: Wrapping, + /// Index into buffer from which data can be added to the buffer + tail: Wrapping, +} + +impl LocalTxBuf { + /// Create a new instance of LocalTxBuf. + pub fn new() -> Self { + Self { + buf: vec![0; CONN_TX_BUF_SIZE as usize], + head: Wrapping(0), + tail: Wrapping(0), + } + } + + /// Check if the buf is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Add new data to the tx buffer, push all or none. + /// Returns LocalTxBufFull error if space not sufficient. + pub(crate) fn push(&mut self, data_buf: &[u8]) -> Result<()> { + if CONN_TX_BUF_SIZE as usize - self.len() < data_buf.len() { + // Tx buffer is full + return Err(Error::LocalTxBufFull); + } + + // Get index into buffer at which data can be inserted + let tail_idx = self.tail.0 as usize % CONN_TX_BUF_SIZE as usize; + + // Check if we can fit the data buffer between head and end of buffer + let len = std::cmp::min(CONN_TX_BUF_SIZE as usize - tail_idx, data_buf.len()); + self.buf[tail_idx..tail_idx + len].copy_from_slice(&data_buf[..len]); + + // Check if there is more data to be wrapped around + if len < data_buf.len() { + self.buf[..(data_buf.len() - len)].copy_from_slice(&data_buf[len..]); + } + + // Increment tail by the amount of data that has been added to the buffer + self.tail += Wrapping(data_buf.len() as u32); + + Ok(()) + } + + /// Flush buf data to stream. + pub(crate) fn flush_to(&mut self, stream: &mut S) -> Result { + if self.is_empty() { + // No data to be flushed + return Ok(0); + } + + // Get index into buffer from which data can be read + let head_idx = self.head.0 as usize % CONN_TX_BUF_SIZE as usize; + + // First write from head to end of buffer + let len = std::cmp::min(CONN_TX_BUF_SIZE as usize - head_idx, self.len()); + let written = stream + .write(&self.buf[head_idx..(head_idx + len)]) + .map_err(Error::LocalTxBufFlush)?; + + // Increment head by amount of data that has been flushed to the stream + self.head += Wrapping(written as u32); + + // If written length is less than the expected length we can try again in the future + if written < len { + return Ok(written); + } + + // The head index has wrapped around the end of the buffer, we call self again + Ok(written + self.flush_to(stream).unwrap_or(0)) + } + + /// Return amount of data in the buffer. + fn len(&self) -> usize { + (self.tail - self.head).0 as usize + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_txbuf_len() { + let mut loc_tx_buf = LocalTxBuf::new(); + + // Zero length tx buf + assert_eq!(loc_tx_buf.len(), 0); + + // finite length tx buf + loc_tx_buf.head = Wrapping(0); + loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE); + assert_eq!(loc_tx_buf.len(), CONN_TX_BUF_SIZE as usize); + + loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE / 2); + assert_eq!(loc_tx_buf.len(), (CONN_TX_BUF_SIZE / 2) as usize); + + loc_tx_buf.head = Wrapping(256); + assert_eq!(loc_tx_buf.len(), 32512); + } + + #[test] + fn test_txbuf_is_empty() { + let mut loc_tx_buf = LocalTxBuf::new(); + + // empty tx buffer + assert!(loc_tx_buf.is_empty()); + + // non empty tx buffer + loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE); + assert!(!loc_tx_buf.is_empty()); + } + + #[test] + fn test_txbuf_push() { + let mut loc_tx_buf = LocalTxBuf::new(); + let data = [0; CONN_TX_BUF_SIZE as usize]; + + // push data into empty tx buffer + let res_push = loc_tx_buf.push(&data); + assert!(res_push.is_ok()); + assert_eq!(loc_tx_buf.head, Wrapping(0)); + assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE)); + + // push data into full tx buffer + let res_push = loc_tx_buf.push(&data); + assert!(res_push.is_err()); + + // head and tail wrap at full + loc_tx_buf.head = Wrapping(CONN_TX_BUF_SIZE); + let res_push = loc_tx_buf.push(&data); + assert!(res_push.is_ok()); + assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE * 2)); + + // only tail wraps at full + let data = vec![1; 4]; + let mut cmp_data = vec![1; 4]; + cmp_data.append(&mut vec![0; (CONN_TX_BUF_SIZE - 4) as usize]); + loc_tx_buf.head = Wrapping(4); + loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE); + let res_push = loc_tx_buf.push(&data); + assert!(res_push.is_ok()); + assert_eq!(loc_tx_buf.head, Wrapping(4)); + assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE + 4)); + assert_eq!(loc_tx_buf.buf, cmp_data); + } + + #[test] + fn test_txbuf_flush_to() { + let mut loc_tx_buf = LocalTxBuf::new(); + + // data to be flushed + let data = vec![1; CONN_TX_BUF_SIZE as usize]; + + // target to which data is flushed + let mut cmp_vec = Vec::with_capacity(data.len()); + + // flush no data + let res_flush = loc_tx_buf.flush_to(&mut cmp_vec); + assert!(res_flush.is_ok()); + assert_eq!(res_flush.unwrap(), 0); + + // flush data of CONN_TX_BUF_SIZE amount + let res_push = loc_tx_buf.push(&data); + assert_eq!(res_push.is_ok(), true); + let res_flush = loc_tx_buf.flush_to(&mut cmp_vec); + if let Ok(n) = res_flush { + assert_eq!(loc_tx_buf.head, Wrapping(n as u32)); + assert_eq!(loc_tx_buf.tail, Wrapping(CONN_TX_BUF_SIZE)); + assert_eq!(n, cmp_vec.len()); + assert_eq!(cmp_vec, data[..n]); + } + + // wrapping head flush + let mut data = vec![0; (CONN_TX_BUF_SIZE / 2) as usize]; + data.append(&mut vec![1; (CONN_TX_BUF_SIZE / 2) as usize]); + loc_tx_buf.head = Wrapping(0); + loc_tx_buf.tail = Wrapping(0); + let res_push = loc_tx_buf.push(&data); + assert!(res_push.is_ok()); + cmp_vec.clear(); + loc_tx_buf.head = Wrapping(CONN_TX_BUF_SIZE / 2); + loc_tx_buf.tail = Wrapping(CONN_TX_BUF_SIZE + (CONN_TX_BUF_SIZE / 2)); + let res_flush = loc_tx_buf.flush_to(&mut cmp_vec); + if let Ok(n) = res_flush { + assert_eq!( + loc_tx_buf.head, + Wrapping(CONN_TX_BUF_SIZE + (CONN_TX_BUF_SIZE / 2)) + ); + assert_eq!( + loc_tx_buf.tail, + Wrapping(CONN_TX_BUF_SIZE + (CONN_TX_BUF_SIZE / 2)) + ); + assert_eq!(n, cmp_vec.len()); + let mut data = vec![1; (CONN_TX_BUF_SIZE / 2) as usize]; + data.append(&mut vec![0; (CONN_TX_BUF_SIZE / 2) as usize]); + assert_eq!(cmp_vec, data[..n]); + } + } +} diff --git a/vsock/src/vhu_vsock.rs b/vsock/src/vhu_vsock.rs new file mode 100644 index 0000000..eb50c39 --- /dev/null +++ b/vsock/src/vhu_vsock.rs @@ -0,0 +1,346 @@ +use super::vhu_vsock_thread::*; +use clap::ArgMatches; +use core::slice; +use std::convert::TryFrom; +use std::{io, mem, result, sync::Mutex, u16, u32, u64, u8}; +use thiserror::Error as ThisError; +use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; +use vhost_user_backend::{VhostUserBackendMut, VringRwLock}; +use virtio_bindings::bindings::{ + virtio_blk::__u64, virtio_net::VIRTIO_F_NOTIFY_ON_EMPTY, virtio_net::VIRTIO_F_VERSION_1, + virtio_ring::VIRTIO_RING_F_EVENT_IDX, +}; +use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap}; +use vmm_sys_util::{ + epoll::EventSet, + eventfd::{EventFd, EFD_NONBLOCK}, +}; + +const NUM_QUEUES: usize = 2; +const QUEUE_SIZE: usize = 256; + +// New descriptors pending on the rx queue +const RX_QUEUE_EVENT: u16 = 0; +// New descriptors are pending on the tx queue. +const TX_QUEUE_EVENT: u16 = 1; +// New descriptors are pending on the event queue. +const EVT_QUEUE_EVENT: u16 = 2; +// Notification coming from the backend. +pub const BACKEND_EVENT: u16 = 3; + +// Vsock connection TX buffer capacity +// TODO: Make this value configurable +pub const CONN_TX_BUF_SIZE: u32 = 64 * 1024; + +// CID of the host +pub const VSOCK_HOST_CID: u64 = 2; + +// Connection oriented packet +pub const VSOCK_TYPE_STREAM: u16 = 1; + +// Vsock packet operation ID +// +// Connection request +pub const VSOCK_OP_REQUEST: u16 = 1; +// Connection response +pub const VSOCK_OP_RESPONSE: u16 = 2; +// Connection reset +pub const VSOCK_OP_RST: u16 = 3; +// Shutdown connection +pub const VSOCK_OP_SHUTDOWN: u16 = 4; +// Data read/write +pub const VSOCK_OP_RW: u16 = 5; +// Flow control credit update +pub const VSOCK_OP_CREDIT_UPDATE: u16 = 6; +// Flow control credit request +pub const VSOCK_OP_CREDIT_REQUEST: u16 = 7; + +// Vsock packet flags +// +// VSOCK_OP_SHUTDOWN: Packet sender will receive no more data +pub const VSOCK_FLAGS_SHUTDOWN_RCV: u32 = 1; +// VSOCK_OP_SHUTDOWN: Packet sender will send no more data +pub const VSOCK_FLAGS_SHUTDOWN_SEND: u32 = 2; + +// Queue mask to select vrings. +const QUEUE_MASK: u64 = 0b11; + +pub(crate) type Result = std::result::Result; + +/// Below enum defines custom error types. +#[derive(Debug, ThisError)] +pub(crate) enum Error { + #[error("Failed to handle event other than EPOLLIN event")] + HandleEventNotEpollIn, + #[error("Failed to handle unknown event")] + HandleUnknownEvent, + #[error("Failed to accept new local socket connection")] + UnixAccept(std::io::Error), + #[error("Failed to create an epoll fd")] + EpollFdCreate(std::io::Error), + #[error("Failed to add to epoll")] + EpollAdd(std::io::Error), + #[error("Failed to modify evset associated with epoll")] + EpollModify(std::io::Error), + #[error("Failed to read from unix stream")] + UnixRead(std::io::Error), + #[error("Failed to convert byte array to string")] + ConvertFromUtf8(std::str::Utf8Error), + #[error("Invalid vsock connection request from host")] + InvalidPortRequest, + #[error("Unable to convert string to integer")] + ParseInteger(std::num::ParseIntError), + #[error("Error reading stream port")] + ReadStreamPort(Box), + #[error("Failed to de-register fd from epoll")] + EpollRemove(std::io::Error), + #[error("No memory configured")] + NoMemoryConfigured, + #[error("Unable to iterate queue")] + IterateQueue, + #[error("No rx request available")] + NoRequestRx, + #[error("Unable to create thread pool")] + CreateThreadPool(std::io::Error), + #[error("Packet missing data buffer")] + PktBufMissing, + #[error("Failed to connect to unix socket")] + UnixConnect(std::io::Error), + #[error("Unable to write to unix stream")] + UnixWrite, + #[error("Unable to push data to local tx buffer")] + LocalTxBufFull, + #[error("Unable to flush data from local tx buffer")] + LocalTxBufFlush(std::io::Error), + #[error("No free local port available for new host inititated connection")] + NoFreeLocalPort, + #[error("Backend rx queue is empty")] + EmptyBackendRxQ, + #[error("Invalid socket path as cmd line argument")] + SocketPathMissing, + #[error("Invalid UDS path as cmd line argument")] + UDSPathMissing, +} + +impl std::convert::From for std::io::Error { + fn from(e: Error) -> Self { + std::io::Error::new(io::ErrorKind::Other, e) + } +} + +#[derive(Debug, Clone)] +/// This structure is the public API through which an external program +/// is allowed to configure the backend. +pub(crate) struct VsockConfig { + guest_cid: u64, + socket: String, + uds_path: String, +} + +impl VsockConfig { + /// Create a new instance of the VsockConfig struct, containing the + /// parameters to be fed into the vsock-backend server. + pub fn new(guest_cid: u64, socket: String, uds_path: String) -> Self { + Self { + guest_cid, + socket, + uds_path, + } + } + + /// Return the guest's current CID. + pub fn get_guest_cid(&self) -> u64 { + self.guest_cid + } + + /// Return the path of the unix domain socket which is listening to + /// requests from the host side application. + pub fn get_uds_path(&self) -> String { + String::from(&self.uds_path) + } + + /// Return the path of the unix domain socket which is listening to + /// requests from the guest. + pub fn get_socket_path(&self) -> String { + String::from(&self.socket) + } +} + +impl TryFrom for VsockConfig { + type Error = Error; + + fn try_from(cmd_args: ArgMatches) -> Result { + let guest_cid = cmd_args + .value_of("guest_cid") + .unwrap_or("3") + .parse::() + .unwrap_or(3); + let socket = cmd_args + .value_of("socket") + .ok_or(Error::SocketPathMissing)? + .to_string(); + let uds_path = cmd_args + .value_of("uds_path") + .ok_or(Error::UDSPathMissing)? + .to_string(); + + Ok(VsockConfig::new(guest_cid, socket, uds_path)) + } +} + +/// A local port and peer port pair used to retrieve +/// the corresponding connection. +#[derive(Hash, PartialEq, Eq, Debug, Clone)] +pub struct ConnMapKey { + local_port: u32, + peer_port: u32, +} + +impl ConnMapKey { + pub fn new(local_port: u32, peer_port: u32) -> Self { + Self { + local_port, + peer_port, + } + } +} + +pub struct VhostUserVsockBackend { + guest_cid: __u64, + pub threads: Vec>, + queues_per_thread: Vec, + pub exit_event: EventFd, +} + +impl VhostUserVsockBackend { + pub(crate) fn new(vsock_config: VsockConfig) -> Result { + let thread = Mutex::new( + VhostUserVsockThread::new(vsock_config.get_uds_path(), vsock_config.get_guest_cid()) + .unwrap(), + ); + let queues_per_thread = vec![QUEUE_MASK]; + + Ok(Self { + guest_cid: vsock_config.get_guest_cid(), + threads: vec![thread], + queues_per_thread, + exit_event: EventFd::new(EFD_NONBLOCK).expect("Creating exit eventfd"), + }) + } +} + +impl VhostUserBackendMut for VhostUserVsockBackend { + fn num_queues(&self) -> usize { + NUM_QUEUES + } + + fn max_queue_size(&self) -> usize { + QUEUE_SIZE + } + + fn features(&self) -> u64 { + 1 << VIRTIO_F_VERSION_1 + | 1 << VIRTIO_F_NOTIFY_ON_EMPTY + | 1 << VIRTIO_RING_F_EVENT_IDX + | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() + } + + fn protocol_features(&self) -> VhostUserProtocolFeatures { + VhostUserProtocolFeatures::CONFIG + } + + fn set_event_idx(&mut self, enabled: bool) { + for thread in self.threads.iter() { + thread.lock().unwrap().event_idx = enabled; + } + } + + fn update_memory( + &mut self, + atomic_mem: GuestMemoryAtomic, + ) -> result::Result<(), io::Error> { + for thread in self.threads.iter() { + thread.lock().unwrap().mem = Some(atomic_mem.clone()); + } + Ok(()) + } + + fn handle_event( + &mut self, + device_event: u16, + evset: EventSet, + vrings: &[VringRwLock], + thread_id: usize, + ) -> result::Result { + let vring_rx = &vrings[0]; + let vring_tx = &vrings[1]; + + if evset == EventSet::OUT { + dbg!("received epollout"); + } + + if evset != EventSet::IN { + return Err(Error::HandleEventNotEpollIn.into()); + } + + let mut thread = self.threads[thread_id].lock().unwrap(); + let evt_idx = thread.event_idx; + + match device_event { + RX_QUEUE_EVENT => {} + TX_QUEUE_EVENT => { + thread.process_tx(vring_tx, evt_idx)?; + } + EVT_QUEUE_EVENT => {} + BACKEND_EVENT => { + thread.process_backend_evt(evset); + thread.process_tx(vring_tx, evt_idx)?; + } + _ => { + return Err(Error::HandleUnknownEvent.into()); + } + } + + if device_event != EVT_QUEUE_EVENT && thread.thread_backend.pending_rx() { + thread.process_rx(vring_rx, evt_idx)?; + } + + Ok(false) + } + + fn get_config(&self, _offset: u32, _size: u32) -> Vec { + let buf = unsafe { + slice::from_raw_parts( + &self.guest_cid as *const __u64 as *const _, + mem::size_of::<__u64>(), + ) + }; + buf.to_vec() + } + + fn queues_per_thread(&self) -> Vec { + self.queues_per_thread.clone() + } + + fn exit_event(&self, _thread_index: usize) -> Option { + Some(self.exit_event.try_clone().expect("Cloning exit eventfd")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_vsock_config_setup() { + let vsock_config = VsockConfig::new( + 3, + "/tmp/vhost4.socket".to_string(), + "/tmp/vm4.vsock".to_string(), + ); + + assert_eq!(vsock_config.get_guest_cid(), 3); + assert_eq!(vsock_config.get_socket_path(), "/tmp/vhost4.socket"); + assert_eq!(vsock_config.get_uds_path(), "/tmp/vm4.vsock"); + } +} diff --git a/vsock/src/vhu_vsock_thread.rs b/vsock/src/vhu_vsock_thread.rs new file mode 100644 index 0000000..e77e2f1 --- /dev/null +++ b/vsock/src/vhu_vsock_thread.rs @@ -0,0 +1,553 @@ +use super::{ + packet::*, + rxops::*, + thread_backend::*, + vhu_vsock::{ConnMapKey, Error, Result, VhostUserVsockBackend, BACKEND_EVENT, VSOCK_HOST_CID}, + vsock_conn::*, +}; +use futures::executor::{ThreadPool, ThreadPoolBuilder}; +use log::warn; +use std::{ + fs::File, + io, + io::Read, + num::Wrapping, + os::unix::{ + net::{UnixListener, UnixStream}, + prelude::{AsRawFd, FromRawFd, RawFd}, + }, + sync::{Arc, RwLock}, +}; +use vhost_user_backend::{VringEpollHandler, VringRwLock, VringT}; +use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap}; +use vmm_sys_util::{ + epoll::EventSet, + eventfd::{EventFd, EFD_NONBLOCK}, +}; + +type ArcVhostBknd = Arc>; + +pub struct VhostUserVsockThread { + /// Guest memory map. + pub mem: Option>, + /// VIRTIO_RING_F_EVENT_IDX. + pub event_idx: bool, + /// Host socket raw file descriptor. + host_sock: RawFd, + /// Listener listening for new connections on the host. + host_listener: UnixListener, + /// Used to kill the thread. + pub kill_evt: EventFd, + /// Instance of VringWorker. + vring_worker: Option>>, + /// epoll fd to which new host connections are added. + epoll_file: File, + /// VsockThreadBackend instance. + pub thread_backend: VsockThreadBackend, + /// CID of the guest. + guest_cid: u64, + /// Thread pool to handle event idx. + pool: ThreadPool, + /// host side port on which application listens. + local_port: Wrapping, +} + +impl VhostUserVsockThread { + /// Create a new instance of VhostUserVsockThread. + pub(crate) fn new(uds_path: String, guest_cid: u64) -> Result { + // TODO: better error handling + if let Ok(()) = std::fs::remove_file(uds_path.clone()) {} + let host_sock = UnixListener::bind(&uds_path) + .and_then(|sock| sock.set_nonblocking(true).map(|_| sock)) + .unwrap(); + + let epoll_fd = epoll::create(true).map_err(Error::EpollFdCreate)?; + let epoll_file = unsafe { File::from_raw_fd(epoll_fd) }; + + let host_raw_fd = host_sock.as_raw_fd(); + + let thread = VhostUserVsockThread { + mem: None, + event_idx: false, + host_sock: host_sock.as_raw_fd(), + host_listener: host_sock, + kill_evt: EventFd::new(EFD_NONBLOCK).unwrap(), + vring_worker: None, + epoll_file, + thread_backend: VsockThreadBackend::new(uds_path, epoll_fd), + guest_cid, + pool: ThreadPoolBuilder::new() + .pool_size(1) + .create() + .map_err(Error::CreateThreadPool)?, + local_port: Wrapping(0), + }; + + VhostUserVsockThread::epoll_register(epoll_fd, host_raw_fd, epoll::Events::EPOLLIN)?; + + Ok(thread) + } + + /// Register a file with an epoll to listen for events in evset. + pub(crate) fn epoll_register(epoll_fd: RawFd, fd: RawFd, evset: epoll::Events) -> Result<()> { + epoll::ctl( + epoll_fd, + epoll::ControlOptions::EPOLL_CTL_ADD, + fd, + epoll::Event::new(evset, fd as u64), + ) + .map_err(Error::EpollAdd)?; + + Ok(()) + } + + /// Remove a file from the epoll. + pub(crate) fn epoll_unregister(epoll_fd: RawFd, fd: RawFd) -> Result<()> { + epoll::ctl( + epoll_fd, + epoll::ControlOptions::EPOLL_CTL_DEL, + fd, + epoll::Event::new(epoll::Events::empty(), 0), + ) + .map_err(Error::EpollRemove)?; + + Ok(()) + } + + /// Modify the events we listen to for the fd in the epoll. + pub(crate) fn epoll_modify(epoll_fd: RawFd, fd: RawFd, evset: epoll::Events) -> Result<()> { + epoll::ctl( + epoll_fd, + epoll::ControlOptions::EPOLL_CTL_MOD, + fd, + epoll::Event::new(evset, fd as u64), + ) + .map_err(Error::EpollModify)?; + + Ok(()) + } + + /// Return raw file descriptor of the epoll file. + fn get_epoll_fd(&self) -> RawFd { + self.epoll_file.as_raw_fd() + } + + /// Set self's VringWorker. + pub fn set_vring_worker( + &mut self, + vring_worker: Option>>, + ) { + self.vring_worker = vring_worker; + self.vring_worker + .as_ref() + .unwrap() + .register_listener(self.get_epoll_fd(), EventSet::IN, u64::from(BACKEND_EVENT)) + .unwrap(); + } + + /// Process a BACKEND_EVENT received by VhostUserVsockBackend. + pub fn process_backend_evt(&mut self, _evset: EventSet) { + let mut epoll_events = vec![epoll::Event::new(epoll::Events::empty(), 0); 32]; + 'epoll: loop { + match epoll::wait(self.epoll_file.as_raw_fd(), 0, epoll_events.as_mut_slice()) { + Ok(ev_cnt) => { + for evt in epoll_events.iter().take(ev_cnt) { + self.handle_event( + evt.data as RawFd, + epoll::Events::from_bits(evt.events).unwrap(), + ); + } + } + Err(e) => { + if e.kind() == io::ErrorKind::Interrupted { + continue; + } + warn!("failed to consume new epoll event"); + } + } + break 'epoll; + } + } + + /// Handle a BACKEND_EVENT by either accepting a new connection or + /// forwarding a request to the appropriate connection object. + fn handle_event(&mut self, fd: RawFd, evset: epoll::Events) { + if fd == self.host_sock { + // This is a new connection initiated by an application running on the host + self.host_listener + .accept() + .map_err(Error::UnixAccept) + .and_then(|(stream, _)| { + stream + .set_nonblocking(true) + .map(|_| stream) + .map_err(Error::UnixAccept) + }) + .and_then(|stream| self.add_stream_listener(stream)) + .unwrap_or_else(|err| { + warn!("Unable to accept new local connection: {:?}", err); + }); + } else { + // Check if the stream represented by fd has already established a + // connection with the application running in the guest + if let std::collections::hash_map::Entry::Vacant(_) = + self.thread_backend.listener_map.entry(fd) + { + // New connection from the host + if evset != epoll::Events::EPOLLIN { + // Has to be EPOLLIN as it was not connected previously + return; + } + let mut unix_stream = self.thread_backend.stream_map.remove(&fd).unwrap(); + + // Local peer is sending a "connect PORT\n" command + let peer_port = Self::read_local_stream_port(&mut unix_stream).unwrap(); + + // Allocate a local port number + let local_port = match self.allocate_local_port() { + Ok(lp) => lp, + Err(_) => { + return; + } + }; + + // Insert the fd into the backend's maps + self.thread_backend + .listener_map + .insert(fd, ConnMapKey::new(local_port, peer_port)); + + // Create a new connection object an enqueue a connection request + // packet to be sent to the guest + let conn_map_key = ConnMapKey::new(local_port, peer_port); + let mut new_vsock_conn = VsockConnection::new_local_init( + unix_stream, + VSOCK_HOST_CID, + local_port, + self.guest_cid, + peer_port, + self.get_epoll_fd(), + ); + new_vsock_conn.rx_queue.enqueue(RxOps::Request); + new_vsock_conn.set_peer_port(peer_port); + + // Add connection object into the backend's maps + self.thread_backend + .conn_map + .insert(conn_map_key, new_vsock_conn); + + self.thread_backend + .backend_rxq + .push_back(ConnMapKey::new(local_port, peer_port)); + + // Re-register the fd to listen for EPOLLIN and EPOLLOUT events + Self::epoll_modify( + self.get_epoll_fd(), + fd, + epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT, + ) + .unwrap(); + } else { + // Previously connected connection + let key = self.thread_backend.listener_map.get(&fd).unwrap(); + let vsock_conn = self.thread_backend.conn_map.get_mut(&key).unwrap(); + + if evset == epoll::Events::EPOLLOUT { + // Flush any remaining data from the tx buffer + match vsock_conn.tx_buf.flush_to(&mut vsock_conn.stream) { + Ok(cnt) => { + vsock_conn.fwd_cnt += Wrapping(cnt as u32); + vsock_conn.rx_queue.enqueue(RxOps::CreditUpdate); + self.thread_backend.backend_rxq.push_back(ConnMapKey::new( + vsock_conn.local_port, + vsock_conn.peer_port, + )); + } + Err(e) => { + dbg!("Error: {:?}", e); + } + } + return; + } + + // Unregister stream from the epoll, register when connection is + // established with the guest + Self::epoll_unregister(self.epoll_file.as_raw_fd(), fd).unwrap(); + + // Enqueue a read request + vsock_conn.rx_queue.enqueue(RxOps::Rw); + self.thread_backend + .backend_rxq + .push_back(ConnMapKey::new(vsock_conn.local_port, vsock_conn.peer_port)); + } + } + } + + /// Allocate a new local port number. + fn allocate_local_port(&mut self) -> Result { + // TODO: Improve space efficiency of this operation + // TODO: Reuse the conn_map HashMap + // TODO: Test this. + let mut alloc_local_port = self.local_port.0; + loop { + if !self + .thread_backend + .local_port_set + .contains(&alloc_local_port) + { + // The port set doesn't contain the newly allocated port number. + self.local_port = Wrapping(alloc_local_port + 1); + self.thread_backend.local_port_set.insert(alloc_local_port); + return Ok(alloc_local_port); + } else { + if alloc_local_port == self.local_port.0 { + // We have exhausted our search and wrapped back to the current port number + return Err(Error::NoFreeLocalPort); + } + alloc_local_port += 1; + } + } + } + + /// Read `CONNECT PORT_NUM\n` from the connected stream. + fn read_local_stream_port(stream: &mut UnixStream) -> Result { + let mut buf = [0u8; 32]; + + // Minimum number of bytes we should be able to read + // Corresponds to 'CONNECT 0\n' + const MIN_READ_LEN: usize = 10; + + // Read in the minimum number of bytes we can read + stream + .read_exact(&mut buf[..MIN_READ_LEN]) + .map_err(Error::UnixRead)?; + + let mut read_len = MIN_READ_LEN; + while buf[read_len - 1] != b'\n' && read_len < buf.len() { + stream + .read_exact(&mut buf[read_len..read_len + 1]) + .map_err(Error::UnixRead)?; + read_len += 1; + } + + let mut word_iter = std::str::from_utf8(&buf[..read_len]) + .map_err(Error::ConvertFromUtf8)? + .split_whitespace(); + + word_iter + .next() + .ok_or(Error::InvalidPortRequest) + .and_then(|word| { + if word.to_lowercase() == "connect" { + Ok(()) + } else { + Err(Error::InvalidPortRequest) + } + }) + .and_then(|_| word_iter.next().ok_or(Error::InvalidPortRequest)) + .and_then(|word| word.parse::().map_err(Error::ParseInteger)) + .map_err(|e| Error::ReadStreamPort(Box::new(e))) + } + + /// Add a stream to epoll to listen for EPOLLIN events. + fn add_stream_listener(&mut self, stream: UnixStream) -> Result<()> { + let stream_fd = stream.as_raw_fd(); + self.thread_backend.stream_map.insert(stream_fd, stream); + VhostUserVsockThread::epoll_register( + self.get_epoll_fd(), + stream_fd, + epoll::Events::EPOLLIN, + )?; + + // self.register_listener(stream_fd, BACKEND_EVENT); + Ok(()) + } + + /// Iterate over the rx queue and process rx requests. + fn process_rx_queue(&mut self, vring: &VringRwLock) -> Result { + let mut used_any = false; + let atomic_mem = match &self.mem { + Some(m) => m, + None => return Err(Error::NoMemoryConfigured), + }; + + let mut vring_mut = vring.get_mut(); + + let queue = vring_mut.get_queue_mut(); + + while let Some(mut avail_desc) = queue.iter().map_err(|_| Error::IterateQueue)?.next() { + used_any = true; + let atomic_mem = atomic_mem.clone(); + + let head_idx = avail_desc.head_index(); + let used_len = + match VsockPacket::from_rx_virtq_head(&mut avail_desc, atomic_mem.clone()) { + Ok(mut pkt) => { + if self.thread_backend.recv_pkt(&mut pkt).is_ok() { + pkt.hdr().len() + pkt.len() as usize + } else { + queue.iter().unwrap().go_to_previous_position(); + break; + } + } + Err(e) => { + warn!("vsock: RX queue error: {:?}", e); + 0 + } + }; + + let vring = vring.clone(); + let event_idx = self.event_idx; + + self.pool.spawn_ok(async move { + // TODO: Understand why doing the following in the pool works + if event_idx { + if vring.add_used(head_idx, used_len as u32).is_err() { + warn!("Could not return used descriptors to ring"); + } + match vring.needs_notification() { + Err(_) => { + warn!("Could not check if queue needs to be notified"); + vring.signal_used_queue().unwrap(); + } + Ok(needs_notification) => { + if needs_notification { + vring.signal_used_queue().unwrap(); + } + } + } + } else { + if vring.add_used(head_idx, used_len as u32).is_err() { + warn!("Could not return used descriptors to ring"); + } + vring.signal_used_queue().unwrap(); + } + }); + + if !self.thread_backend.pending_rx() { + break; + } + } + Ok(used_any) + } + + /// Wrapper to process rx queue based on whether event idx is enabled or not. + pub(crate) fn process_rx(&mut self, vring: &VringRwLock, event_idx: bool) -> Result { + if event_idx { + // To properly handle EVENT_IDX we need to keep calling + // process_rx_queue until it stops finding new requests + // on the queue, as vm-virtio's Queue implementation + // only checks avail_index once + loop { + if !self.thread_backend.pending_rx() { + break; + } + vring.disable_notification().unwrap(); + + self.process_rx_queue(vring)?; + if !vring.enable_notification().unwrap() { + break; + } + // TODO: This may not be required because of + // previous pending_rx check + // if !work { + // break; + // } + } + } else { + self.process_rx_queue(vring)?; + } + Ok(false) + } + + /// Process tx queue and send requests to the backend for processing. + fn process_tx_queue(&mut self, vring: &VringRwLock) -> Result { + let mut used_any = false; + + let atomic_mem = match &self.mem { + Some(m) => m, + None => return Err(Error::NoMemoryConfigured), + }; + + while let Some(mut avail_desc) = vring + .get_mut() + .get_queue_mut() + .iter() + .map_err(|_| Error::IterateQueue)? + .next() + { + used_any = true; + let atomic_mem = atomic_mem.clone(); + + let head_idx = avail_desc.head_index(); + let pkt = match VsockPacket::from_tx_virtq_head(&mut avail_desc, atomic_mem.clone()) { + Ok(pkt) => pkt, + Err(e) => { + dbg!("vsock: error reading TX packet: {:?}", e); + continue; + } + }; + + if self.thread_backend.send_pkt(&pkt).is_err() { + vring + .get_mut() + .get_queue_mut() + .iter() + .unwrap() + .go_to_previous_position(); + break; + } + + // TODO: Check if the protocol requires read length to be correct + let used_len = 0; + + let vring = vring.clone(); + let event_idx = self.event_idx; + + self.pool.spawn_ok(async move { + if event_idx { + if vring.add_used(head_idx, used_len as u32).is_err() { + warn!("Could not return used descriptors to ring"); + } + match vring.needs_notification() { + Err(_) => { + warn!("Could not check if queue needs to be notified"); + vring.signal_used_queue().unwrap(); + } + Ok(needs_notification) => { + if needs_notification { + vring.signal_used_queue().unwrap(); + } + } + } + } else { + if vring.add_used(head_idx, used_len as u32).is_err() { + warn!("Could not return used descriptors to ring"); + } + vring.signal_used_queue().unwrap(); + } + }); + } + + Ok(used_any) + } + + /// Wrapper to process tx queue based on whether event idx is enabled or not. + pub(crate) fn process_tx(&mut self, vring_lock: &VringRwLock, event_idx: bool) -> Result { + if event_idx { + // To properly handle EVENT_IDX we need to keep calling + // process_rx_queue until it stops finding new requests + // on the queue, as vm-virtio's Queue implementation + // only checks avail_index once + loop { + vring_lock.disable_notification().unwrap(); + self.process_tx_queue(vring_lock)?; + if !vring_lock.enable_notification().unwrap() { + break; + } + } + } else { + self.process_tx_queue(vring_lock)?; + } + Ok(false) + } +} diff --git a/vsock/src/vsock_conn.rs b/vsock/src/vsock_conn.rs new file mode 100644 index 0000000..cc0170c --- /dev/null +++ b/vsock/src/vsock_conn.rs @@ -0,0 +1,605 @@ +use super::{ + packet::*, + rxops::*, + rxqueue::*, + txbuf::*, + vhu_vsock::{ + Error, Result, CONN_TX_BUF_SIZE, VSOCK_FLAGS_SHUTDOWN_RCV, VSOCK_FLAGS_SHUTDOWN_SEND, + VSOCK_OP_CREDIT_REQUEST, VSOCK_OP_CREDIT_UPDATE, VSOCK_OP_REQUEST, VSOCK_OP_RESPONSE, + VSOCK_OP_RST, VSOCK_OP_RW, VSOCK_OP_SHUTDOWN, VSOCK_TYPE_STREAM, + }, + vhu_vsock_thread::VhostUserVsockThread, +}; +use log::info; +use std::{ + io::{ErrorKind, Read, Write}, + num::Wrapping, + os::unix::prelude::{AsRawFd, RawFd}, +}; + +#[derive(Debug)] +pub struct VsockConnection { + /// Host-side stream corresponding to this vsock connection. + pub stream: S, + /// Specifies if the stream is connected to a listener on the host. + pub connect: bool, + /// Port at which a guest application is listening to. + pub peer_port: u32, + /// Queue holding pending rx operations per connection. + pub rx_queue: RxQueue, + /// CID of the host. + local_cid: u64, + /// Port on the host at which a host-side application listens to. + pub local_port: u32, + /// CID of the guest. + pub guest_cid: u64, + /// Total number of bytes written to stream from tx buffer. + pub fwd_cnt: Wrapping, + /// Total number of bytes previously forwarded to stream. + last_fwd_cnt: Wrapping, + /// Size of buffer the guest has allocated for this connection. + peer_buf_alloc: u32, + /// Number of bytes the peer has forwarded to a connection. + peer_fwd_cnt: Wrapping, + /// The total number of bytes sent to the guest vsock driver. + rx_cnt: Wrapping, + /// epoll fd to which this connection's stream has to be added. + pub epoll_fd: RawFd, + /// Local tx buffer. + pub tx_buf: LocalTxBuf, +} + +impl VsockConnection { + /// Create a new vsock connection object for locally i.e host-side + /// inititated connections. + pub fn new_local_init( + stream: S, + local_cid: u64, + local_port: u32, + guest_cid: u64, + guest_port: u32, + epoll_fd: RawFd, + ) -> Self { + Self { + stream, + connect: false, + peer_port: guest_port, + rx_queue: RxQueue::new(), + local_cid, + local_port, + guest_cid, + fwd_cnt: Wrapping(0), + last_fwd_cnt: Wrapping(0), + peer_buf_alloc: 0, + peer_fwd_cnt: Wrapping(0), + rx_cnt: Wrapping(0), + epoll_fd, + tx_buf: LocalTxBuf::new(), + } + } + + /// Create a new vsock connection object for connections initiated by + /// an application running in the guest. + pub fn new_peer_init( + stream: S, + local_cid: u64, + local_port: u32, + guest_cid: u64, + guest_port: u32, + epoll_fd: RawFd, + peer_buf_alloc: u32, + ) -> Self { + let mut rx_queue = RxQueue::new(); + rx_queue.enqueue(RxOps::Response); + Self { + stream, + connect: false, + peer_port: guest_port, + rx_queue, + local_cid, + local_port, + guest_cid, + fwd_cnt: Wrapping(0), + last_fwd_cnt: Wrapping(0), + peer_buf_alloc, + peer_fwd_cnt: Wrapping(0), + rx_cnt: Wrapping(0), + epoll_fd, + tx_buf: LocalTxBuf::new(), + } + } + + /// Set the peer port to the guest side application's port. + pub fn set_peer_port(&mut self, peer_port: u32) { + self.peer_port = peer_port; + } + + /// Process a vsock packet that is meant for this connection. + /// Forward data to the host-side application if the vsock packet + /// contains a RW operation. + pub(crate) fn recv_pkt(&mut self, pkt: &mut VsockPacket) -> Result<()> { + // Initialize all fields in the packet header + self.init_pkt(pkt); + + match self.rx_queue.dequeue() { + Some(RxOps::Request) => { + // Send a connection request to the guest-side application + pkt.set_op(VSOCK_OP_REQUEST); + Ok(()) + } + Some(RxOps::Rw) => { + if !self.connect { + // There is no host-side application listening for this + // packet, hence send back an RST. + pkt.set_op(VSOCK_OP_RST); + return Ok(()); + } + + // Check if peer has space for receiving data + if self.need_credit_update_from_peer() { + self.last_fwd_cnt = self.fwd_cnt; + pkt.set_op(VSOCK_OP_CREDIT_REQUEST); + return Ok(()); + } + let buf = pkt.buf_mut().ok_or(Error::PktBufMissing)?; + + // Perform a credit check to find the maximum read size. The read + // data must fit inside a packet buffer and be within peer's + // available buffer space + let max_read_len = std::cmp::min(buf.len(), self.peer_avail_credit()); + + // Read data from the stream directly into the buffer + if let Ok(read_cnt) = self.stream.read(&mut buf[..max_read_len]) { + if read_cnt == 0 { + // If no data was read then the stream was closed down unexpectedly. + // Send a shutdown packet to the guest-side application. + pkt.set_op(VSOCK_OP_SHUTDOWN) + .set_flag(VSOCK_FLAGS_SHUTDOWN_RCV) + .set_flag(VSOCK_FLAGS_SHUTDOWN_SEND); + } else { + // If data was read, then set the length field in the packet header + // to the amount of data that was read. + pkt.set_op(VSOCK_OP_RW).set_len(read_cnt as u32); + + // Re-register the stream file descriptor for read and write events + VhostUserVsockThread::epoll_register( + self.epoll_fd, + self.stream.as_raw_fd(), + epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT, + )?; + } + + // Update the rx_cnt with the amount of data in the vsock packet. + self.rx_cnt += Wrapping(pkt.len()); + self.last_fwd_cnt = self.fwd_cnt; + } + Ok(()) + } + Some(RxOps::Response) => { + // A response has been received to a newly initiated host-side connection + self.connect = true; + pkt.set_op(VSOCK_OP_RESPONSE); + Ok(()) + } + Some(RxOps::CreditUpdate) => { + // Request credit update from the guest. + if !self.rx_queue.pending_rx() { + // Waste an rx buffer if no rx is pending + pkt.set_op(VSOCK_OP_CREDIT_UPDATE); + self.last_fwd_cnt = self.fwd_cnt; + } + Ok(()) + } + _ => Err(Error::NoRequestRx), + } + } + + /// Deliver a guest generated packet to this connection. + /// + /// Returns: + /// - always `Ok(())` to indicate that the packet has been consumed + pub(crate) fn send_pkt(&mut self, pkt: &VsockPacket) -> Result<()> { + // Update peer credit information + self.peer_buf_alloc = pkt.buf_alloc(); + self.peer_fwd_cnt = Wrapping(pkt.fwd_cnt()); + + match pkt.op() { + VSOCK_OP_RESPONSE => { + // Confirmation for a host initiated connection + // TODO: Handle stream write error in a better manner + let response = format!("OK {}\n", self.peer_port); + self.stream.write_all(response.as_bytes()).unwrap(); + self.connect = true; + } + VSOCK_OP_RW => { + // Data has to be written to the host-side stream + if pkt.buf().is_none() { + info!( + "Dropping empty packet from guest (lp={}, pp={})", + self.local_port, self.peer_port + ); + return Ok(()); + } + + let buf_slice = &pkt.buf().unwrap()[..(pkt.len() as usize)]; + if let Err(err) = self.send_bytes(buf_slice) { + // TODO: Terminate this connection + dbg!("err:{:?}", err); + return Ok(()); + } + } + VSOCK_OP_CREDIT_UPDATE => { + // Already updated the credit + + // Re-register the stream file descriptor for read and write events + if VhostUserVsockThread::epoll_modify( + self.epoll_fd, + self.stream.as_raw_fd(), + epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT, + ) + .is_err() + { + VhostUserVsockThread::epoll_register( + self.epoll_fd, + self.stream.as_raw_fd(), + epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT, + ) + .unwrap(); + }; + } + VSOCK_OP_CREDIT_REQUEST => { + // Send back this connection's credit information + self.rx_queue.enqueue(RxOps::CreditUpdate); + } + VSOCK_OP_SHUTDOWN => { + // Shutdown this connection + let recv_off = pkt.flags() & VSOCK_FLAGS_SHUTDOWN_RCV != 0; + let send_off = pkt.flags() & VSOCK_FLAGS_SHUTDOWN_SEND != 0; + + if recv_off && send_off && self.tx_buf.is_empty() { + self.rx_queue.enqueue(RxOps::Reset); + } + } + _ => {} + } + + Ok(()) + } + + /// Write data to the host-side stream. + /// + /// Returns: + /// - Ok(cnt) where cnt is the number of bytes written to the stream + /// - Err(Error::UnixWrite) if there was an error writing to the stream + fn send_bytes(&mut self, buf: &[u8]) -> Result<()> { + if !self.tx_buf.is_empty() { + // Data is already present in the buffer and the backend + // is waiting for a EPOLLOUT event to flush it + return self.tx_buf.push(buf); + } + + // Write data to the stream + let written_count = match self.stream.write(buf) { + Ok(cnt) => cnt, + Err(e) => { + if e.kind() == ErrorKind::WouldBlock { + 0 + } else { + println!("send_bytes error: {:?}", e); + return Err(Error::UnixWrite); + } + } + }; + + // Increment forwarded count by number of bytes written to the stream + self.fwd_cnt += Wrapping(written_count as u32); + // TODO: https://github.com/torvalds/linux/commit/c69e6eafff5f725bc29dcb8b52b6782dca8ea8a2 + self.rx_queue.enqueue(RxOps::CreditUpdate); + + if written_count != buf.len() { + return self.tx_buf.push(&buf[written_count..]); + } + + Ok(()) + } + + /// Initialize all header fields in the vsock packet. + fn init_pkt<'a>(&self, pkt: &'a mut VsockPacket) -> &'a mut VsockPacket { + // Zero out the packet header + for b in pkt.hdr_mut() { + *b = 0; + } + + pkt.set_src_cid(self.local_cid) + .set_dst_cid(self.guest_cid) + .set_src_port(self.local_port) + .set_dst_port(self.peer_port) + .set_type(VSOCK_TYPE_STREAM) + .set_buf_alloc(CONN_TX_BUF_SIZE) + .set_fwd_cnt(self.fwd_cnt.0) + } + + /// Get max number of bytes we can send to peer without overflowing + /// the peer's buffer. + fn peer_avail_credit(&self) -> usize { + (Wrapping(self.peer_buf_alloc as u32) - (self.rx_cnt - self.peer_fwd_cnt)).0 as usize + } + + /// Check if we need a credit update from the peer before sending + /// more data to it. + fn need_credit_update_from_peer(&self) -> bool { + self.peer_avail_credit() == 0 + } +} + +#[cfg(test)] +mod tests { + use byteorder::{ByteOrder, LittleEndian}; + + use super::*; + use crate::packet::tests::{prepare_desc_chain_vsock, HeadParams}; + use crate::vhu_vsock::VSOCK_HOST_CID; + use std::io::Result as IoResult; + + struct VsockDummySocket { + data: Vec, + } + + impl VsockDummySocket { + fn new() -> Self { + Self { data: Vec::new() } + } + } + + impl Write for VsockDummySocket { + fn write(&mut self, buf: &[u8]) -> std::result::Result { + self.data.clear(); + self.data.extend_from_slice(buf); + + Ok(buf.len()) + } + fn flush(&mut self) -> IoResult<()> { + Ok(()) + } + } + + impl Read for VsockDummySocket { + fn read(&mut self, buf: &mut [u8]) -> IoResult { + buf[..self.data.len()].copy_from_slice(&self.data); + Ok(self.data.len()) + } + } + + impl AsRawFd for VsockDummySocket { + fn as_raw_fd(&self) -> RawFd { + -1 + } + } + + #[test] + fn test_vsock_conn_init() { + // new locally inititated connection + let dummy_file = VsockDummySocket::new(); + let mut vsock_conn_local = + VsockConnection::new_local_init(dummy_file, VSOCK_HOST_CID, 5000, 3, 5001, -1); + + assert!(!vsock_conn_local.connect); + assert_eq!(vsock_conn_local.peer_port, 5001); + assert_eq!(vsock_conn_local.rx_queue, RxQueue::new()); + assert_eq!(vsock_conn_local.local_cid, VSOCK_HOST_CID); + assert_eq!(vsock_conn_local.local_port, 5000); + assert_eq!(vsock_conn_local.guest_cid, 3); + + // set peer port + vsock_conn_local.set_peer_port(5002); + assert_eq!(vsock_conn_local.peer_port, 5002); + + // New connection initiated by the peer/guest + let dummy_file = VsockDummySocket::new(); + let mut vsock_conn_peer = + VsockConnection::new_peer_init(dummy_file, VSOCK_HOST_CID, 5000, 3, 5001, -1, 65536); + + assert!(!vsock_conn_peer.connect); + assert_eq!(vsock_conn_peer.peer_port, 5001); + assert_eq!(vsock_conn_peer.rx_queue.dequeue().unwrap(), RxOps::Response); + assert!(!vsock_conn_peer.rx_queue.pending_rx()); + assert_eq!(vsock_conn_peer.local_cid, VSOCK_HOST_CID); + assert_eq!(vsock_conn_peer.local_port, 5000); + assert_eq!(vsock_conn_peer.guest_cid, 3); + assert_eq!(vsock_conn_peer.peer_buf_alloc, 65536); + } + + #[test] + fn test_vsock_conn_credit() { + // new locally inititated connection + let dummy_file = VsockDummySocket::new(); + let mut vsock_conn_local = + VsockConnection::new_local_init(dummy_file, VSOCK_HOST_CID, 5000, 3, 5001, -1); + + assert_eq!(vsock_conn_local.peer_avail_credit(), 0); + assert!(vsock_conn_local.need_credit_update_from_peer()); + + vsock_conn_local.peer_buf_alloc = 65536; + assert_eq!(vsock_conn_local.peer_avail_credit(), 65536); + assert!(!vsock_conn_local.need_credit_update_from_peer()); + + vsock_conn_local.rx_cnt = Wrapping(32768); + assert_eq!(vsock_conn_local.peer_avail_credit(), 32768); + assert!(!vsock_conn_local.need_credit_update_from_peer()); + + vsock_conn_local.rx_cnt = Wrapping(65536); + assert_eq!(vsock_conn_local.peer_avail_credit(), 0); + assert!(vsock_conn_local.need_credit_update_from_peer()); + } + + #[test] + fn test_vsock_conn_init_pkt() { + // parameters for packet head construction + let head_params = HeadParams::new(VSOCK_PKT_HDR_SIZE, 10); + + // new locally inititated connection + let dummy_file = VsockDummySocket::new(); + let vsock_conn_local = + VsockConnection::new_local_init(dummy_file, VSOCK_HOST_CID, 5000, 3, 5001, -1); + + // write only descriptor chain + let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 2, 10); + let mut vsock_pkt = VsockPacket::from_rx_virtq_head(&mut descr_chain, mem).unwrap(); + + // initialize a vsock packet for the guest + vsock_conn_local.init_pkt(&mut vsock_pkt); + + assert_eq!(vsock_pkt.src_cid(), VSOCK_HOST_CID); + assert_eq!(vsock_pkt.dst_cid(), 3); + assert_eq!(vsock_pkt.src_port(), 5000); + assert_eq!(vsock_pkt.dst_port(), 5001); + assert_eq!(vsock_pkt.pkt_type(), VSOCK_TYPE_STREAM); + assert_eq!(vsock_pkt.buf_alloc(), CONN_TX_BUF_SIZE); + assert_eq!(vsock_pkt.fwd_cnt(), 0); + } + + #[test] + fn test_vsock_conn_recv_pkt() { + // parameters for packet head construction + let head_params = HeadParams::new(VSOCK_PKT_HDR_SIZE, 5); + + // new locally inititated connection + let dummy_file = VsockDummySocket::new(); + let mut vsock_conn_local = + VsockConnection::new_local_init(dummy_file, VSOCK_HOST_CID, 5000, 3, 5001, -1); + + // write only descriptor chain + let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 1, 5); + let mut vsock_pkt = VsockPacket::from_rx_virtq_head(&mut descr_chain, mem).unwrap(); + + // VSOCK_OP_REQUEST: new local conn request + vsock_conn_local.rx_queue.enqueue(RxOps::Request); + let vsock_op_req = vsock_conn_local.recv_pkt(&mut vsock_pkt); + assert!(vsock_op_req.is_ok()); + assert!(!vsock_conn_local.rx_queue.pending_rx()); + assert_eq!(vsock_pkt.op(), VSOCK_OP_REQUEST); + + // VSOCK_OP_RST: reset if connection not established + vsock_conn_local.rx_queue.enqueue(RxOps::Rw); + let vsock_op_rst = vsock_conn_local.recv_pkt(&mut vsock_pkt); + assert!(vsock_op_rst.is_ok()); + assert!(!vsock_conn_local.rx_queue.pending_rx()); + assert_eq!(vsock_pkt.op(), VSOCK_OP_RST); + + // VSOCK_OP_CREDIT_UPDATE: need credit update from peer/guest + vsock_conn_local.connect = true; + vsock_conn_local.rx_queue.enqueue(RxOps::Rw); + vsock_conn_local.fwd_cnt = Wrapping(1024); + let vsock_op_credit_update = vsock_conn_local.recv_pkt(&mut vsock_pkt); + assert!(vsock_op_credit_update.is_ok()); + assert!(!vsock_conn_local.rx_queue.pending_rx()); + assert_eq!(vsock_pkt.op(), VSOCK_OP_CREDIT_REQUEST); + assert_eq!(vsock_conn_local.last_fwd_cnt, Wrapping(1024)); + + // VSOCK_OP_SHUTDOWN: zero data read from stream/file + vsock_conn_local.peer_buf_alloc = 65536; + vsock_conn_local.rx_queue.enqueue(RxOps::Rw); + let vsock_op_zero_read_shutdown = vsock_conn_local.recv_pkt(&mut vsock_pkt); + assert!(vsock_op_zero_read_shutdown.is_ok()); + assert!(!vsock_conn_local.rx_queue.pending_rx()); + assert_eq!(vsock_conn_local.rx_cnt, Wrapping(0)); + assert_eq!(vsock_conn_local.last_fwd_cnt, Wrapping(1024)); + assert_eq!(vsock_pkt.op(), VSOCK_OP_SHUTDOWN); + assert_eq!( + vsock_pkt.flags(), + VSOCK_FLAGS_SHUTDOWN_RCV | VSOCK_FLAGS_SHUTDOWN_SEND + ); + + // VSOCK_OP_RW: finite data read from stream/file + vsock_conn_local.stream.write_all(b"hello").unwrap(); + vsock_conn_local.rx_queue.enqueue(RxOps::Rw); + let vsock_op_zero_read = vsock_conn_local.recv_pkt(&mut vsock_pkt); + // below error due to epoll add + assert!(vsock_op_zero_read.is_err()); + assert_eq!(vsock_pkt.op(), VSOCK_OP_RW); + assert!(!vsock_conn_local.rx_queue.pending_rx()); + assert_eq!(vsock_pkt.len(), 5); + assert_eq!(vsock_pkt.buf().unwrap(), b"hello"); + + // VSOCK_OP_RESPONSE: response from a locally initiated connection + vsock_conn_local.rx_queue.enqueue(RxOps::Response); + let vsock_op_response = vsock_conn_local.recv_pkt(&mut vsock_pkt); + assert!(vsock_op_response.is_ok()); + assert!(!vsock_conn_local.rx_queue.pending_rx()); + assert_eq!(vsock_pkt.op(), VSOCK_OP_RESPONSE); + assert!(vsock_conn_local.connect); + + // VSOCK_OP_CREDIT_UPDATE: guest needs credit update + vsock_conn_local.rx_queue.enqueue(RxOps::CreditUpdate); + let vsock_op_credit_update = vsock_conn_local.recv_pkt(&mut vsock_pkt); + assert!(!vsock_conn_local.rx_queue.pending_rx()); + assert!(vsock_op_credit_update.is_ok()); + assert_eq!(vsock_pkt.op(), VSOCK_OP_CREDIT_UPDATE); + assert_eq!(vsock_conn_local.last_fwd_cnt, Wrapping(1024)); + + // non-existent request + let vsock_op_error = vsock_conn_local.recv_pkt(&mut vsock_pkt); + assert!(vsock_op_error.is_err()); + } + + #[test] + fn test_vsock_conn_send_pkt() { + // parameters for packet head construction + let head_params = HeadParams::new(VSOCK_PKT_HDR_SIZE, 5); + + // new locally inititated connection + let dummy_file = VsockDummySocket::new(); + let mut vsock_conn_local = + VsockConnection::new_local_init(dummy_file, VSOCK_HOST_CID, 5000, 3, 5001, -1); + + // write only descriptor chain + let (mem, mut descr_chain) = prepare_desc_chain_vsock(false, &head_params, 1, 5); + let mut vsock_pkt = VsockPacket::from_tx_virtq_head(&mut descr_chain, mem).unwrap(); + + // peer credit information + const HDROFF_BUF_ALLOC: usize = 36; + const HDROFF_FWD_CNT: usize = 40; + LittleEndian::write_u32(&mut vsock_pkt.hdr_mut()[HDROFF_BUF_ALLOC..], 65536); + LittleEndian::write_u32(&mut vsock_pkt.hdr_mut()[HDROFF_FWD_CNT..], 1024); + + // check if peer credit information is updated currently + let credit_check = vsock_conn_local.send_pkt(&vsock_pkt); + assert!(credit_check.is_ok()); + assert_eq!(vsock_conn_local.peer_buf_alloc, 65536); + assert_eq!(vsock_conn_local.peer_fwd_cnt, Wrapping(1024)); + + // VSOCK_OP_RESPONSE + vsock_pkt.set_op(VSOCK_OP_RESPONSE); + let peer_response = vsock_conn_local.send_pkt(&vsock_pkt); + assert!(peer_response.is_ok()); + assert!(vsock_conn_local.connect); + let mut resp_buf = vec![0; 8]; + vsock_conn_local.stream.read_exact(&mut resp_buf).unwrap(); + assert_eq!(resp_buf, b"OK 5001\n"); + + // VSOCK_OP_RW + vsock_pkt.set_op(VSOCK_OP_RW); + vsock_pkt.buf_mut().unwrap().copy_from_slice(b"hello"); + let rw_response = vsock_conn_local.send_pkt(&vsock_pkt); + assert!(rw_response.is_ok()); + let mut resp_buf = vec![0; 5]; + vsock_conn_local.stream.read_exact(&mut resp_buf).unwrap(); + assert_eq!(resp_buf, b"hello"); + + // VSOCK_OP_CREDIT_REQUEST + vsock_pkt.set_op(VSOCK_OP_CREDIT_REQUEST); + let credit_response = vsock_conn_local.send_pkt(&vsock_pkt); + assert!(credit_response.is_ok()); + assert_eq!( + vsock_conn_local.rx_queue.peek().unwrap(), + RxOps::CreditUpdate + ); + + // VSOCK_OP_SHUTDOWN + vsock_pkt.set_op(VSOCK_OP_SHUTDOWN); + vsock_pkt.set_flags(VSOCK_FLAGS_SHUTDOWN_RCV | VSOCK_FLAGS_SHUTDOWN_SEND); + let shutdown_response = vsock_conn_local.send_pkt(&vsock_pkt); + assert!(shutdown_response.is_ok()); + assert!(vsock_conn_local.rx_queue.contains(RxOps::Reset.bitmask())); + } +}