diff --git a/proxmox-http/src/websocket/mod.rs b/proxmox-http/src/websocket/mod.rs index 2c53d016..d72c5505 100644 --- a/proxmox-http/src/websocket/mod.rs +++ b/proxmox-http/src/websocket/mod.rs @@ -650,7 +650,9 @@ impl AsyncRead for WebSocketReader { pub const MAGIC_WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; /// Provides methods for connecting a WebSocket endpoint with another -pub struct WebSocket; +pub struct WebSocket { + pub mask: Option<[u8; 4]>, +} impl WebSocket { /// Returns a new WebSocket instance and the generates the correct @@ -702,10 +704,13 @@ impl WebSocket { let response = response.body(Body::empty())?; - Ok((Self, response)) + let mask = None; + + Ok((Self { mask }, response)) } - async fn handle_channel_message( + pub async fn handle_channel_message( + &self, result: WebSocketReadResult, writer: &mut WebSocketWriter, ) -> Result @@ -714,11 +719,11 @@ impl WebSocket { { match result { Ok((OpCode::Ping, msg)) => { - writer.send_control_frame(None, OpCode::Pong, &msg).await?; + writer.send_control_frame(self.mask, OpCode::Pong, &msg).await?; Ok(OpCode::Pong) } Ok((OpCode::Close, msg)) => { - writer.send_control_frame(None, OpCode::Close, &msg).await?; + writer.send_control_frame(self.mask, OpCode::Close, &msg).await?; Ok(OpCode::Close) } Ok((opcode, _)) => { @@ -727,7 +732,7 @@ impl WebSocket { } Err(err) => { writer - .send_control_frame(None, OpCode::Close, &err.generate_frame_payload()) + .send_control_frame(self.mask, OpCode::Close, &err.generate_frame_payload()) .await?; Err(Error::from(err)) } @@ -735,6 +740,7 @@ impl WebSocket { } async fn copy_to_websocket( + &self, mut reader: &mut R, mut writer: &mut WebSocketWriter, receiver: &mut mpsc::UnboundedReceiver, @@ -743,7 +749,7 @@ impl WebSocket { R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send, { - let mut buf = ByteBuffer::new(); + let mut buf = ByteBuffer::with_capacity(16*1024); let mut eof = false; loop { if !buf.is_full() { @@ -751,7 +757,7 @@ impl WebSocket { res = buf.read_from_async(&mut reader).fuse() => res?, res = receiver.recv().fuse() => { let res = res.ok_or_else(|| format_err!("control channel closed"))?; - match Self::handle_channel_message(res, &mut writer).await? { + match self.handle_channel_message(res, &mut writer).await? { OpCode::Close => return Ok(true), _ => { continue; }, } @@ -771,6 +777,7 @@ impl WebSocket { } if eof && buf.is_empty() { + writer.flush().await?; return Ok(false); } } @@ -791,10 +798,10 @@ impl WebSocket { let (tx, mut rx) = mpsc::unbounded_channel(); let mut wsreader = WebSocketReader::new(usreader, tx); - let mut wswriter = WebSocketWriter::new(None, uswriter); + let mut wswriter = WebSocketWriter::new(self.mask, uswriter); let ws_future = tokio::io::copy(&mut wsreader, &mut dswriter); - let term_future = Self::copy_to_websocket(&mut dsreader, &mut wswriter, &mut rx); + let term_future = self.copy_to_websocket(&mut dsreader, &mut wswriter, &mut rx); let res = select! { res = ws_future.fuse() => match res { @@ -804,7 +811,7 @@ impl WebSocket { res = term_future.fuse() => match res { Ok(sent_close) if !sent_close => { // status code 1000 => 0x03E8 - wswriter.send_control_frame(None, OpCode::Close, &WebSocketErrorKind::Normal.to_be_bytes()).await?; + wswriter.send_control_frame(self.mask, OpCode::Close, &WebSocketErrorKind::Normal.to_be_bytes()).await?; Ok(()) } Ok(_) => Ok(()),