//! Websocket helpers //! //! Provides methods to read and write from websockets The reader and writer take a reader/writer //! with AsyncRead/AsyncWrite respectively and provides the same use std::cmp::min; use std::future::Future; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use anyhow::{bail, format_err, Error}; use futures::select; use hyper::header::{ HeaderMap, HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE, }; use hyper::{Body, Response, StatusCode}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::sync::mpsc; use futures::future::FutureExt; use futures::ready; use proxmox_io::ByteBuffer; use proxmox_sys::error::io_err_other; // see RFC6455 section 7.4.1 #[derive(Debug, Clone, Copy)] #[repr(u16)] pub enum WebSocketErrorKind { Normal = 1000, ProtocolError = 1002, InvalidData = 1003, Other = 1008, Unexpected = 1011, } impl WebSocketErrorKind { #[inline] pub fn to_be_bytes(self) -> [u8; 2] { (self as u16).to_be_bytes() } } impl std::fmt::Display for WebSocketErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { write!(f, "{}", *self as u16) } } #[derive(Debug, Clone)] pub struct WebSocketError { kind: WebSocketErrorKind, message: String, } impl WebSocketError { pub fn new(kind: WebSocketErrorKind, message: &str) -> Self { Self { kind, message: message.to_string(), } } pub fn generate_frame_payload(&self) -> Vec { let msglen = self.message.len().min(125); let code = self.kind.to_be_bytes(); let mut data = Vec::with_capacity(msglen + 2); data.extend_from_slice(&code); data.extend_from_slice(&self.message.as_bytes()[..msglen]); data } } impl std::fmt::Display for WebSocketError { fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { write!(f, "{} (Code: {})", self.message, self.kind) } } impl std::error::Error for WebSocketError {} #[repr(u8)] #[derive(Debug, PartialEq, PartialOrd, Copy, Clone)] /// Represents an OpCode of a websocket frame pub enum OpCode { /// A fragmented frame Continuation = 0, /// A non-fragmented text frame Text = 1, /// A non-fragmented binary frame Binary = 2, /// A closing frame Close = 8, /// A ping frame Ping = 9, /// A pong frame Pong = 10, } impl OpCode { /// Tells whether it is a control frame or not pub fn is_control(self) -> bool { (self as u8 & 0b1000) > 0 } } fn mask_bytes(mask: Option<[u8; 4]>, data: &mut [u8]) { let mask = match mask { Some([0, 0, 0, 0]) | None => return, Some(mask) => mask, }; if data.len() < 32 { for i in 0..data.len() { data[i] ^= mask[i & 3]; } return; } let mut newmask: u32 = u32::from_le_bytes(mask); let (prefix, middle, suffix) = unsafe { data.align_to_mut::() }; for p in prefix { *p ^= newmask as u8; newmask = newmask.rotate_right(8); } for m in middle { *m ^= newmask; } for s in suffix { *s ^= newmask as u8; newmask = newmask.rotate_right(8); } } /// Can be used to create a complete WebSocket Frame. /// /// Takes an optional mask, the data and the frame type /// /// Examples: /// /// A normal Frame /// ``` /// # use proxmox_http::websocket::*; /// # use std::io; /// # fn main() -> Result<(), WebSocketError> { /// let data = vec![0,1,2,3,4]; /// let frame = create_frame(None, &data, OpCode::Text)?; /// assert_eq!(frame, vec![0b10000001, 5, 0, 1, 2, 3, 4]); /// # Ok(()) /// # } /// /// ``` /// /// A masked Frame /// ``` /// # use proxmox_http::websocket::*; /// # use std::io; /// # fn main() -> Result<(), WebSocketError> { /// let data = vec![0,1,2,3,4]; /// let frame = create_frame(Some([0u8, 1u8, 2u8, 3u8]), &data, OpCode::Text)?; /// assert_eq!(frame, vec![0b10000001, 0b10000101, 0, 1, 2, 3, 0, 0, 0, 0, 4]); /// # Ok(()) /// # } /// /// ``` /// /// A ping Frame /// ``` /// # use proxmox_http::websocket::*; /// # use std::io; /// # fn main() -> Result<(), WebSocketError> { /// let data = vec![0,1,2,3,4]; /// let frame = create_frame(None, &data, OpCode::Ping)?; /// assert_eq!(frame, vec![0b10001001, 0b00000101, 0, 1, 2, 3, 4]); /// # Ok(()) /// # } /// /// ``` pub fn create_frame( mask: Option<[u8; 4]>, data: &[u8], frametype: OpCode, ) -> Result, WebSocketError> { let first_byte = 0b10000000 | (frametype as u8); let len = data.len(); if (frametype as u8) & 0b00001000 > 0 && len > 125 { return Err(WebSocketError::new( WebSocketErrorKind::Unexpected, "Control frames cannot have data longer than 125 bytes", )); } let mask_bit = if mask.is_some() { 0b10000000 } else { 0b00000000 }; let mut buf = vec![first_byte]; if len < 126 { buf.push(mask_bit | (len as u8)); } else if len < u16::MAX as usize { buf.push(mask_bit | 126); buf.extend_from_slice(&(len as u16).to_be_bytes()); } else { buf.push(mask_bit | 127); buf.extend_from_slice(&(len as u64).to_be_bytes()); } if let Some(mask) = mask { buf.extend_from_slice(&mask); } let mut data = data.to_vec().into_boxed_slice(); mask_bytes(mask, &mut data); buf.append(&mut data.into_vec()); Ok(buf) } /// Wrap (encapsulate) an `AsyncWrite`er into a WebSocket transparently /// /// Send websocket frames to anything accepting AsyncWrite. /// /// Note: Every write to it gets encoded as a seperate websocket frame, without any fragmentation /// being enforced. /// /// Example usage: /// ``` /// # use proxmox_http::websocket::*; /// # use std::io; /// # use tokio::io::{AsyncWrite, AsyncWriteExt}; /// async fn code(writer: I) -> io::Result<()> { /// let mut ws = WebSocketWriter::new(None, writer); /// ws.write(&[1u8,2u8,3u8]).await?; /// Ok(()) /// } /// ``` pub struct WebSocketWriter { writer: W, mask: Option<[u8; 4]>, frame: Option<(Vec, usize, usize)>, } impl WebSocketWriter { /// Create a new WebSocketWriter which will use the given mask, if any, creating Binary frames pub fn new(mask: Option<[u8; 4]>, writer: W) -> WebSocketWriter { WebSocketWriter { writer, mask, frame: None, } } pub async fn send_control_frame( &mut self, mask: Option<[u8; 4]>, opcode: OpCode, data: &[u8], ) -> Result<(), Error> { let frame = create_frame(mask, data, opcode).map_err(Error::from)?; self.writer.write_all(&frame).await.map_err(Error::from) } } impl AsyncWrite for WebSocketWriter { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let this = Pin::get_mut(self); if this.frame.is_none() { // create frame buf let frame = match create_frame(this.mask, buf, OpCode::Binary) { Ok(f) => f, Err(e) => { return Poll::Ready(Err(io_err_other(e))); } }; this.frame = Some((frame, 0, buf.len())); } // we have a frame in any case, so unwrap is ok let (buf, pos, origsize) = this.frame.as_mut().unwrap(); loop { match ready!(Pin::new(&mut this.writer).poll_write(cx, &buf[*pos..])) { Ok(size) => { *pos += size; if *pos == buf.len() { let size = *origsize; this.frame = None; return Poll::Ready(Ok(size)); } } Err(err) => { eprintln!("error in writer: {}", err); return Poll::Ready(Err(err)); } } } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = Pin::get_mut(self); Pin::new(&mut this.writer).poll_flush(cx) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = Pin::get_mut(self); Pin::new(&mut this.writer).poll_shutdown(cx) } } #[derive(Debug, PartialEq)] /// Represents the header of a websocket Frame pub struct FrameHeader { /// True if the frame is either non-fragmented, or the last fragment pub fin: bool, /// The optional mask of the frame pub mask: Option<[u8; 4]>, /// The frametype pub frametype: OpCode, /// The length of the header (without payload). pub header_len: u8, /// The length of the payload. pub payload_len: usize, } impl FrameHeader { /// Returns true if the frame is a control frame. pub fn is_control_frame(&self) -> bool { self.frametype.is_control() } /// Tries to parse a FrameHeader from bytes. /// /// When there are not enough bytes to completely parse the header, /// returns Ok(None) /// /// Example: /// ``` /// # use proxmox_http::websocket::*; /// # use std::io; /// # fn main() -> Result<(), WebSocketError> { /// let frame = create_frame(None, &[0,1,2,3], OpCode::Ping)?; /// let header = FrameHeader::try_from_bytes(&frame[..1])?; /// match header { /// Some(_) => unreachable!(), /// None => {}, /// } /// let header = FrameHeader::try_from_bytes(&frame[..2])?; /// match header { /// None => unreachable!(), /// Some(header) => assert_eq!(header, FrameHeader{ /// fin: true, /// mask: None, /// frametype: OpCode::Ping, /// header_len: 2, /// payload_len: 4, /// }), /// } /// # Ok(()) /// # } /// ``` pub fn try_from_bytes(data: &[u8]) -> Result, WebSocketError> { let len = data.len(); if len < 2 { return Ok(None); } let data = data; // we do not support extensions if data[0] & 0b01110000 > 0 { return Err(WebSocketError::new( WebSocketErrorKind::ProtocolError, "Extensions not supported", )); } let fin = data[0] & 0b10000000 != 0; let frametype = match data[0] & 0b1111 { 0 => OpCode::Continuation, 1 => OpCode::Text, 2 => OpCode::Binary, 8 => OpCode::Close, 9 => OpCode::Ping, 10 => OpCode::Pong, other => { return Err(WebSocketError::new( WebSocketErrorKind::ProtocolError, &format!("Unknown OpCode {}", other), )); } }; if !fin && frametype.is_control() { return Err(WebSocketError::new( WebSocketErrorKind::ProtocolError, "Control frames cannot be fragmented", )); } let mask_bit = data[1] & 0b10000000 != 0; let mut mask_offset = 2; let mut payload_offset = 2; if mask_bit { payload_offset += 4; } let mut payload_len: usize = (data[1] & 0b01111111).into(); if payload_len == 126 { if len < 4 { return Ok(None); } payload_len = u16::from_be_bytes([data[2], data[3]]) as usize; mask_offset += 2; payload_offset += 2; } else if payload_len == 127 { if len < 10 { return Ok(None); } payload_len = u64::from_be_bytes([ data[2], data[3], data[4], data[5], data[6], data[7], data[8], data[9], ]) as usize; mask_offset += 8; payload_offset += 8; } if payload_len > 125 && frametype.is_control() { return Err(WebSocketError::new( WebSocketErrorKind::ProtocolError, "Control frames cannot carry more than 125 bytes of data", )); } let mask = if mask_bit { if len < mask_offset + 4 { return Ok(None); } let mut mask = [0u8; 4]; mask.copy_from_slice(&data[mask_offset as usize..payload_offset as usize]); Some(mask) } else { None }; Ok(Some(FrameHeader { fin, mask, frametype, payload_len, header_len: payload_offset, })) } } type WebSocketReadResult = Result<(OpCode, Box<[u8]>), WebSocketError>; /// Wraps a `AsyncRead`er for decoding WebSocket frames returning the inner payload. /// /// Polls the underlying reader, decodes the web socket frames while returning the inner data /// stream via `AsyncRead` itself. /// /// Any control frame encountered will get relayed to the 'sender' channel /// /// Incomplete headers get buffered internally. pub struct WebSocketReader { reader: Option, sender: mpsc::UnboundedSender, read_buffer: Option, header: Option, state: ReaderState, } impl WebSocketReader { /// Creates a new WebSocketReader with the given sender for control frames /// and a default buffer size of 4096. pub fn new( reader: R, sender: mpsc::UnboundedSender, ) -> WebSocketReader { Self::with_capacity(reader, 4096, sender) } pub fn with_capacity( reader: R, capacity: usize, sender: mpsc::UnboundedSender, ) -> WebSocketReader { WebSocketReader { reader: Some(reader), sender, read_buffer: Some(ByteBuffer::with_capacity(capacity)), header: None, state: ReaderState::NoData, } } } struct ReadResult { len: usize, reader: R, buffer: ByteBuffer, } enum ReaderState { NoData, Receiving(Pin>> + Send + 'static>>), HaveData, } unsafe impl Sync for ReaderState {} impl AsyncRead for WebSocketReader { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf, ) -> Poll> { let this = Pin::get_mut(self); loop { match &mut this.state { ReaderState::NoData => { let mut reader = match this.reader.take() { Some(reader) => reader, None => return Poll::Ready(Err(io_err_other("no reader"))), }; let mut buffer = match this.read_buffer.take() { Some(buffer) => buffer, None => return Poll::Ready(Err(io_err_other("no buffer"))), }; let future = async move { buffer .read_from_async(&mut reader) .await .map(move |len| ReadResult { len, reader, buffer, }) }; this.state = ReaderState::Receiving(future.boxed()); } ReaderState::Receiving(ref mut future) => match ready!(future.as_mut().poll(cx)) { Ok(ReadResult { len, reader, buffer, }) => { this.reader = Some(reader); this.read_buffer = Some(buffer); this.state = ReaderState::HaveData; if len == 0 { return Poll::Ready(Ok(())); } } Err(err) => return Poll::Ready(Err(err)), }, ReaderState::HaveData => { let mut read_buffer = match this.read_buffer.take() { Some(read_buffer) => read_buffer, None => return Poll::Ready(Err(io_err_other("no buffer"))), }; let mut header = match this.header.take() { Some(header) => header, None => { let header = match FrameHeader::try_from_bytes(&read_buffer[..]) { Ok(Some(header)) => header, Ok(None) => { this.state = ReaderState::NoData; this.read_buffer = Some(read_buffer); continue; } Err(err) => { if let Err(err) = this.sender.send(Err(err.clone())) { return Poll::Ready(Err(io_err_other(err))); } return Poll::Ready(Err(io_err_other(err))); } }; read_buffer.consume(header.header_len as usize); header } }; if header.is_control_frame() { if read_buffer.len() >= header.payload_len { let mut data = read_buffer.remove_data(header.payload_len); mask_bytes(header.mask, &mut data); if let Err(err) = this.sender.send(Ok((header.frametype, data))) { eprintln!("error sending control frame: {}", err); } this.state = if read_buffer.is_empty() { ReaderState::NoData } else { ReaderState::HaveData }; this.read_buffer = Some(read_buffer); } else { this.header = Some(header); this.read_buffer = Some(read_buffer); this.state = ReaderState::NoData; } continue; } let len = min(buf.remaining(), min(header.payload_len, read_buffer.len())); let mut data = read_buffer.remove_data(len); mask_bytes(header.mask, &mut data); buf.put_slice(&data); header.payload_len -= len; if header.payload_len > 0 { this.header = Some(header); } this.state = if read_buffer.is_empty() { ReaderState::NoData } else { ReaderState::HaveData }; this.read_buffer = Some(read_buffer); if len > 0 { return Poll::Ready(Ok(())); } } } } } } /// Global Identifier for WebSockets, see RFC6455 pub const MAGIC_WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; /// Provides methods for connecting one WebSocket endpoint with another pub struct WebSocket { pub mask: Option<[u8; 4]>, } impl WebSocket { /// Returns a new WebSocket instance and the correct WebSocket response derived from the /// upgrade request's headers pub fn new(headers: HeaderMap) -> Result<(Self, Response), Error> { let protocols = headers .get(UPGRADE) .ok_or_else(|| format_err!("missing Upgrade header"))? .to_str()?; let version = headers .get(SEC_WEBSOCKET_VERSION) .ok_or_else(|| format_err!("missing websocket version"))? .to_str()?; let key = headers .get(SEC_WEBSOCKET_KEY) .ok_or_else(|| format_err!("missing websocket key"))? .to_str()?; if protocols != "websocket" { bail!("invalid protocol name"); } if version != "13" { bail!("invalid websocket version"); } // we ignore extensions let mut sha1 = openssl::sha::Sha1::new(); let data = format!("{}{}", key, MAGIC_WEBSOCKET_GUID); sha1.update(data.as_bytes()); let response_key = base64::encode(sha1.finish()); let mut response = Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS) .header(UPGRADE, HeaderValue::from_static("websocket")) .header(CONNECTION, HeaderValue::from_static("Upgrade")) .header(SEC_WEBSOCKET_ACCEPT, response_key); // FIXME: remove compat in PBS 3.x // // We currently do not support any subprotocols and we always send binary frames, but for // backwards compatibility we need to reply the requested protocols if let Some(ws_proto) = headers.get(SEC_WEBSOCKET_PROTOCOL) { response = response.header(SEC_WEBSOCKET_PROTOCOL, ws_proto) } let response = response.body(Body::empty())?; Ok((Self { mask: None }, response)) } pub async fn handle_channel_message( &self, result: WebSocketReadResult, writer: &mut WebSocketWriter, ) -> Result where W: AsyncWrite + Unpin + Send, { match result { Ok((OpCode::Ping, msg)) => { writer .send_control_frame(self.mask, OpCode::Pong, &msg) .await?; Ok(OpCode::Pong) } Ok((OpCode::Close, msg)) => { writer .send_control_frame(self.mask, OpCode::Close, &msg) .await?; Ok(OpCode::Close) } Ok((opcode, _)) => { // ignore other frames Ok(opcode) } Err(err) => { writer .send_control_frame(self.mask, OpCode::Close, &err.generate_frame_payload()) .await?; Err(Error::from(err)) } } } async fn copy_to_websocket( &self, mut reader: &mut R, writer: &mut WebSocketWriter, receiver: &mut mpsc::UnboundedReceiver, ) -> Result where R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send, { let mut buf = ByteBuffer::with_capacity(16 * 1024); let mut eof = false; loop { if !buf.is_full() { let bytes = select! { res = buf.read_from_async(&mut reader).fuse() => res?, res = receiver.recv().fuse() => { let res = res.ok_or_else(|| format_err!("control channel closed"))?; match self.handle_channel_message(res, writer).await? { OpCode::Close => return Ok(true), _ => { continue; }, } } }; if bytes == 0 { eof = true; } } if buf.len() > 0 { let bytes = writer.write(&buf).await?; if bytes == 0 { eof = true; } buf.consume(bytes); } if eof && buf.is_empty() { writer.flush().await?; return Ok(false); } } } /// Takes two endpoints and connects them via a websocket, where the 'upstream' endpoint sends /// and receives WebSocket frames, while 'downstream' only expects and sends raw data. /// /// This method takes care of copying the data between endpoints, and sending correct responses /// for control frames (e.g. a Pont to a Ping). pub async fn serve_connection(&self, upstream: S, downstream: L) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, L: AsyncRead + AsyncWrite + Unpin + Send, { let (usreader, uswriter) = tokio::io::split(upstream); let (mut dsreader, mut dswriter) = tokio::io::split(downstream); let (tx, mut rx) = mpsc::unbounded_channel(); let mut wsreader = WebSocketReader::new(usreader, tx); let mut wswriter = WebSocketWriter::new(self.mask, uswriter); let ws_future = tokio::io::copy(&mut wsreader, &mut dswriter); let term_future = self.copy_to_websocket(&mut dsreader, &mut wswriter, &mut rx); select! { res = ws_future.fuse() => match res { Ok(_) => Ok(()), Err(err) => Err(Error::from(err)), }, res = term_future.fuse() => match res { Ok(sent_close) if !sent_close => { // status code 1000 => 0x03E8 wswriter .send_control_frame(self.mask, OpCode::Close, &WebSocketErrorKind::Normal.to_be_bytes()) .await?; Ok(()) } Ok(_) => Ok(()), Err(err) => Err(err), } } } }