diff --git a/proxmox/src/tools/websocket.rs b/proxmox/src/tools/websocket.rs index fc9a0c53..c6775f06 100644 --- a/proxmox/src/tools/websocket.rs +++ b/proxmox/src/tools/websocket.rs @@ -7,7 +7,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::cmp::min; -use std::io::{self, ErrorKind}; +use std::io; use std::future::Future; use futures::select; @@ -29,9 +29,65 @@ use hyper::header::{ use futures::future::FutureExt; use futures::ready; -use crate::io_format_err; +use crate::sys::error::io_err_other; use crate::tools::byte_buffer::ByteBuffer; +// see RFC6455 section 7.4.1 +#[derive(Debug, Clone, Copy)] +#[repr(u16)] +pub enum WebSocketErrorKind { + Normal = 1000, + ProtocolError = 1002, + InvalidData = 1003, + Other = 1008, + Unexpected = 1011, +} + +impl WebSocketErrorKind { + #[inline] + pub fn to_be_bytes(self) -> [u8; 2] { + (self as u16).to_be_bytes() + } +} + +impl std::fmt::Display for WebSocketErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { + write!(f, "{}", *self as u16) + } +} + +#[derive(Debug, Clone)] +pub struct WebSocketError{ + kind: WebSocketErrorKind, + message: String, +} + +impl WebSocketError { + pub fn new(kind: WebSocketErrorKind, message: &str) -> Self { + Self{ + kind, + message: message.to_string() + } + } + + pub fn generate_frame_payload(&self) -> Vec { + let msglen = self.message.len().min(125); + let code = self.kind.to_be_bytes(); + let mut data = Vec::with_capacity(msglen + 2); + data.extend_from_slice(&code); + data.extend_from_slice(&self.message.as_bytes()[..msglen]); + data + } +} + +impl std::fmt::Display for WebSocketError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { + write!(f, "{} (Code: {})", self.message, self.kind) + } +} + +impl std::error::Error for WebSocketError {} + #[repr(u8)] #[derive(Debug, PartialEq, PartialOrd, Copy, Clone)] /// Represents an OpCode of a websocket frame @@ -293,25 +349,23 @@ impl FrameHeader { /// Tries to parse a FrameHeader from bytes. /// /// When there are not enough bytes to completely parse the header, - /// returns Ok(Err(size)) where size determines how many bytes - /// are missing to parse further (this amount can change when more - /// information is available) + /// returns Ok(None) /// /// Example: /// ``` /// # use proxmox::tools::websocket::*; /// # use std::io; - /// # fn main() -> io::Result<()> { + /// # fn main() -> Result<(), WebSocketError> { /// let frame = create_frame(None, &[0,1,2,3], OpCode::Ping)?; /// let header = FrameHeader::try_from_bytes(&frame[..1])?; /// match header { - /// Ok(_) => unreachable!(), - /// Err(x) => assert_eq!(x, 1), + /// Some(_) => unreachable!(), + /// None => {}, /// } /// let header = FrameHeader::try_from_bytes(&frame[..2])?; /// match header { - /// Err(x) => unreachable!(), - /// Ok(header) => assert_eq!(header, FrameHeader{ + /// None => unreachable!(), + /// Some(header) => assert_eq!(header, FrameHeader{ /// fin: true, /// mask: None, /// frametype: OpCode::Ping, @@ -322,19 +376,19 @@ impl FrameHeader { /// # Ok(()) /// # } /// ``` - pub fn try_from_bytes(data: &[u8]) -> io::Result> { + pub fn try_from_bytes(data: &[u8]) -> Result, WebSocketError> { let len = data.len(); if len < 2 { - return Ok(Err(2 - len)); + return Ok(None); } let data = data; // we do not support extensions if data[0] & 0b01110000 > 0 { - return Err(io::Error::new( - ErrorKind::InvalidData, - "Extensions not supported", + return Err(WebSocketError::new( + WebSocketErrorKind::ProtocolError, + "Extensions not supported", )); } @@ -347,14 +401,17 @@ impl FrameHeader { 9 => OpCode::Ping, 10 => OpCode::Pong, other => { - return Err(io::Error::new(ErrorKind::InvalidData, format!("Unknown OpCode {}", other))); + return Err(WebSocketError::new( + WebSocketErrorKind::ProtocolError, + &format!("Unknown OpCode {}", other), + )); } }; if !fin && frametype.is_control() { - return Err(io::Error::new( - ErrorKind::InvalidData, - "Control frames cannot be fragmented", + return Err(WebSocketError::new( + WebSocketErrorKind::ProtocolError, + "Control frames cannot be fragmented", )); } @@ -368,14 +425,14 @@ impl FrameHeader { let mut payload_len: usize = (data[1] & 0b01111111).into(); if payload_len == 126 { if len < 4 { - return Ok(Err(4 - len)); + return Ok(None); } payload_len = u16::from_be_bytes([data[2], data[3]]) as usize; mask_offset += 2; payload_offset += 2; } else if payload_len == 127 { if len < 10 { - return Ok(Err(10 - len)); + return Ok(None); } payload_len = u64::from_be_bytes([ data[2], data[3], data[4], data[5], data[6], data[7], data[8], data[9], @@ -385,16 +442,16 @@ impl FrameHeader { } if payload_len > 125 && frametype.is_control() { - return Err(io::Error::new( - ErrorKind::InvalidData, - "Control frames cannot carry more than 125 bytes of data", + return Err(WebSocketError::new( + WebSocketErrorKind::ProtocolError, + "Control frames cannot carry more than 125 bytes of data", )); } let mask = match mask_bit { true => { if len < mask_offset + 4 { - return Ok(Err(mask_offset + 4 - len)); + return Ok(None); } let mut mask = [0u8; 4]; mask.copy_from_slice(&data[mask_offset as usize..payload_offset as usize]); @@ -403,7 +460,7 @@ impl FrameHeader { false => None, }; - Ok(Ok(FrameHeader { + Ok(Some(FrameHeader { fin, mask, frametype, @@ -413,6 +470,8 @@ impl FrameHeader { } } +type WebSocketReadResult = Result<(OpCode, Box<[u8]>), WebSocketError>; + /// Wraps a reader that implements AsyncRead and implements it itself. /// /// On read, reads the underlying reader and tries to decode the frames and @@ -422,7 +481,7 @@ impl FrameHeader { /// Has an internal Buffer for storing incomplete headers. pub struct WebSocketReader { reader: Option, - sender: mpsc::UnboundedSender<(OpCode, Box<[u8]>)>, + sender: mpsc::UnboundedSender, read_buffer: Option, header: Option, state: ReaderState, @@ -431,11 +490,11 @@ pub struct WebSocketReader { impl WebSocketReader { /// Creates a new WebSocketReader with the given CallBack for control frames /// and a default buffer size of 4096. - pub fn new(reader: R, sender: mpsc::UnboundedSender<(OpCode, Box<[u8]>)>) -> WebSocketReader { + pub fn new(reader: R, sender: mpsc::UnboundedSender) -> WebSocketReader { Self::with_capacity(reader, 4096, sender) } - pub fn with_capacity(reader: R, capacity: usize, sender: mpsc::UnboundedSender<(OpCode, Box<[u8]>)>) -> WebSocketReader { + pub fn with_capacity(reader: R, capacity: usize, sender: mpsc::UnboundedSender) -> WebSocketReader { WebSocketReader { reader: Some(reader), sender, @@ -512,13 +571,19 @@ impl AsyncRead for WebSocketReader let mut header = match this.header.take() { Some(header) => header, None => { - let header = match FrameHeader::try_from_bytes(&read_buffer[..])? { - Ok(header) => header, - Err(_) => { + let header = match FrameHeader::try_from_bytes(&read_buffer[..]) { + Ok(Some(header)) => header, + Ok(None) => { this.state = ReaderState::NoData; this.read_buffer = Some(read_buffer); continue; - } + }, + Err(err) => { + if let Err(err) = this.sender.send(Err(err.clone())) { + return Poll::Ready(Err(io_err_other(err))); + } + return Poll::Ready(Err(io_err_other(err))); + }, }; read_buffer.consume(header.header_len as usize); @@ -531,7 +596,7 @@ impl AsyncRead for WebSocketReader let mut data = read_buffer.remove_data(header.payload_len); mask_bytes(header.mask, &mut data); - if let Err(err) = this.sender.send((header.frametype, data)) { + if let Err(err) = this.sender.send(Ok((header.frametype, data))) { eprintln!("error sending control frame: {}", err); } @@ -639,10 +704,37 @@ impl WebSocket { Ok((Self { text }, response)) } + async fn handle_channel_message( + result: WebSocketReadResult, + writer: &mut WebSocketWriter + ) -> Result + where + W: AsyncWrite + Unpin + Send, + { + match result { + Ok((OpCode::Ping, msg)) => { + writer.send_control_frame(None, OpCode::Pong, &msg).await?; + Ok(OpCode::Pong) + } + Ok((OpCode::Close, msg)) => { + writer.send_control_frame(None, OpCode::Close, &msg).await?; + Ok(OpCode::Close) + } + Ok((opcode, _)) => { + // ignore other frames + Ok(opcode) + }, + Err(err) => { + writer.send_control_frame(None, OpCode::Close, &err.generate_frame_payload()).await?; + Err(Error::from(err)) + } + } + } + async fn copy_to_websocket( mut reader: &mut R, - writer: &mut WebSocketWriter, - receiver: &mut mpsc::UnboundedReceiver<(OpCode, Box<[u8]>)>) -> Result + mut writer: &mut WebSocketWriter, + receiver: &mut mpsc::UnboundedReceiver) -> Result where R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send, @@ -654,20 +746,10 @@ impl WebSocket { let bytes = select!{ res = buf.read_from_async(&mut reader).fuse() => res?, res = receiver.recv().fuse() => { - let (opcode, msg) = res.ok_or(format_err!("control channel closed"))?; - match opcode { - OpCode::Ping => { - writer.send_control_frame(None, OpCode::Pong, &msg).await?; - continue; - } - OpCode::Close => { - writer.send_control_frame(None, OpCode::Close, &msg).await?; - return Ok(true); - } - _ => { - // ignore other frames - continue; - } + let res = res.ok_or_else(|| format_err!("control channel closed"))?; + match Self::handle_channel_message(res, &mut writer).await? { + OpCode::Close => return Ok(true), + _ => { continue; }, } } }; @@ -720,7 +802,7 @@ impl WebSocket { res = term_future.fuse() => match res { Ok(sent_close) if !sent_close => { // status code 1000 => 0x03E8 - wswriter.send_control_frame(None, OpCode::Close, &[0x03, 0xE8]).await?; + wswriter.send_control_frame(None, OpCode::Close, &WebSocketErrorKind::Normal.to_be_bytes()).await?; Ok(()) } Ok(_) => Ok(()),