websocket: adapt for client connection

previously, this was only used for the server side handling of web
sockets. by making the mask part of the WebSocket struct and making some
of the fns associated, we can re-use this for client-side connections
such as in proxmox-websocket-tunnel.

Signed-off-by: Fabian Grünbichler <f.gruenbichler@proxmox.com>
This commit is contained in:
Fabian Grünbichler 2021-11-05 14:03:38 +01:00 committed by Thomas Lamprecht
parent e0df53e793
commit e848148f5c

View File

@ -650,7 +650,9 @@ impl<R: AsyncRead + Unpin + Send + 'static> AsyncRead for WebSocketReader<R> {
pub const MAGIC_WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; pub const MAGIC_WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Provides methods for connecting a WebSocket endpoint with another /// Provides methods for connecting a WebSocket endpoint with another
pub struct WebSocket; pub struct WebSocket {
pub mask: Option<[u8; 4]>,
}
impl WebSocket { impl WebSocket {
/// Returns a new WebSocket instance and the generates the correct /// Returns a new WebSocket instance and the generates the correct
@ -702,10 +704,13 @@ impl WebSocket {
let response = response.body(Body::empty())?; let response = response.body(Body::empty())?;
Ok((Self, response)) let mask = None;
Ok((Self { mask }, response))
} }
async fn handle_channel_message<W>( pub async fn handle_channel_message<W>(
&self,
result: WebSocketReadResult, result: WebSocketReadResult,
writer: &mut WebSocketWriter<W>, writer: &mut WebSocketWriter<W>,
) -> Result<OpCode, Error> ) -> Result<OpCode, Error>
@ -714,11 +719,11 @@ impl WebSocket {
{ {
match result { match result {
Ok((OpCode::Ping, msg)) => { Ok((OpCode::Ping, msg)) => {
writer.send_control_frame(None, OpCode::Pong, &msg).await?; writer.send_control_frame(self.mask, OpCode::Pong, &msg).await?;
Ok(OpCode::Pong) Ok(OpCode::Pong)
} }
Ok((OpCode::Close, msg)) => { Ok((OpCode::Close, msg)) => {
writer.send_control_frame(None, OpCode::Close, &msg).await?; writer.send_control_frame(self.mask, OpCode::Close, &msg).await?;
Ok(OpCode::Close) Ok(OpCode::Close)
} }
Ok((opcode, _)) => { Ok((opcode, _)) => {
@ -727,7 +732,7 @@ impl WebSocket {
} }
Err(err) => { Err(err) => {
writer writer
.send_control_frame(None, OpCode::Close, &err.generate_frame_payload()) .send_control_frame(self.mask, OpCode::Close, &err.generate_frame_payload())
.await?; .await?;
Err(Error::from(err)) Err(Error::from(err))
} }
@ -735,6 +740,7 @@ impl WebSocket {
} }
async fn copy_to_websocket<R, W>( async fn copy_to_websocket<R, W>(
&self,
mut reader: &mut R, mut reader: &mut R,
mut writer: &mut WebSocketWriter<W>, mut writer: &mut WebSocketWriter<W>,
receiver: &mut mpsc::UnboundedReceiver<WebSocketReadResult>, receiver: &mut mpsc::UnboundedReceiver<WebSocketReadResult>,
@ -743,7 +749,7 @@ impl WebSocket {
R: AsyncRead + Unpin + Send, R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send, W: AsyncWrite + Unpin + Send,
{ {
let mut buf = ByteBuffer::new(); let mut buf = ByteBuffer::with_capacity(16*1024);
let mut eof = false; let mut eof = false;
loop { loop {
if !buf.is_full() { if !buf.is_full() {
@ -751,7 +757,7 @@ impl WebSocket {
res = buf.read_from_async(&mut reader).fuse() => res?, res = buf.read_from_async(&mut reader).fuse() => res?,
res = receiver.recv().fuse() => { res = receiver.recv().fuse() => {
let res = res.ok_or_else(|| format_err!("control channel closed"))?; let res = res.ok_or_else(|| format_err!("control channel closed"))?;
match Self::handle_channel_message(res, &mut writer).await? { match self.handle_channel_message(res, &mut writer).await? {
OpCode::Close => return Ok(true), OpCode::Close => return Ok(true),
_ => { continue; }, _ => { continue; },
} }
@ -771,6 +777,7 @@ impl WebSocket {
} }
if eof && buf.is_empty() { if eof && buf.is_empty() {
writer.flush().await?;
return Ok(false); return Ok(false);
} }
} }
@ -791,10 +798,10 @@ impl WebSocket {
let (tx, mut rx) = mpsc::unbounded_channel(); let (tx, mut rx) = mpsc::unbounded_channel();
let mut wsreader = WebSocketReader::new(usreader, tx); let mut wsreader = WebSocketReader::new(usreader, tx);
let mut wswriter = WebSocketWriter::new(None, uswriter); let mut wswriter = WebSocketWriter::new(self.mask, uswriter);
let ws_future = tokio::io::copy(&mut wsreader, &mut dswriter); let ws_future = tokio::io::copy(&mut wsreader, &mut dswriter);
let term_future = Self::copy_to_websocket(&mut dsreader, &mut wswriter, &mut rx); let term_future = self.copy_to_websocket(&mut dsreader, &mut wswriter, &mut rx);
let res = select! { let res = select! {
res = ws_future.fuse() => match res { res = ws_future.fuse() => match res {
@ -804,7 +811,7 @@ impl WebSocket {
res = term_future.fuse() => match res { res = term_future.fuse() => match res {
Ok(sent_close) if !sent_close => { Ok(sent_close) if !sent_close => {
// status code 1000 => 0x03E8 // status code 1000 => 0x03E8
wswriter.send_control_frame(None, OpCode::Close, &WebSocketErrorKind::Normal.to_be_bytes()).await?; wswriter.send_control_frame(self.mask, OpCode::Close, &WebSocketErrorKind::Normal.to_be_bytes()).await?;
Ok(()) Ok(())
} }
Ok(_) => Ok(()), Ok(_) => Ok(()),