more formatting fixups

Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
This commit is contained in:
Wolfgang Bumiller 2020-08-28 08:56:54 +02:00
parent bbc94222f5
commit ba7c389142

View File

@ -3,27 +3,21 @@
//! Provides methods to read and write from websockets The reader and writer take a reader/writer //! Provides methods to read and write from websockets The reader and writer take a reader/writer
//! with AsyncRead/AsyncWrite respectively and provides the same //! with AsyncRead/AsyncWrite respectively and provides the same
use std::cmp::min;
use std::future::Future;
use std::io;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::cmp::min;
use std::io;
use std::future::Future;
use futures::select;
use anyhow::{bail, format_err, Error}; use anyhow::{bail, format_err, Error};
use futures::select;
use hyper::header::{
HeaderMap, HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE,
};
use hyper::{Body, Response, StatusCode};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use hyper::{Body, Response, StatusCode};
use hyper::header::{
HeaderMap,
HeaderValue,
UPGRADE,
CONNECTION,
SEC_WEBSOCKET_KEY,
SEC_WEBSOCKET_PROTOCOL,
SEC_WEBSOCKET_VERSION,
SEC_WEBSOCKET_ACCEPT,
};
use futures::future::FutureExt; use futures::future::FutureExt;
use futures::ready; use futures::ready;
@ -56,16 +50,16 @@ impl std::fmt::Display for WebSocketErrorKind {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct WebSocketError{ pub struct WebSocketError {
kind: WebSocketErrorKind, kind: WebSocketErrorKind,
message: String, message: String,
} }
impl WebSocketError { impl WebSocketError {
pub fn new(kind: WebSocketErrorKind, message: &str) -> Self { pub fn new(kind: WebSocketErrorKind, message: &str) -> Self {
Self{ Self {
kind, kind,
message: message.to_string() message: message.to_string(),
} }
} }
@ -114,13 +108,13 @@ impl OpCode {
fn mask_bytes(mask: Option<[u8; 4]>, data: &mut [u8]) { fn mask_bytes(mask: Option<[u8; 4]>, data: &mut [u8]) {
let mask = match mask { let mask = match mask {
Some([0,0,0,0]) | None => return, Some([0, 0, 0, 0]) | None => return,
Some(mask) => mask, Some(mask) => mask,
}; };
if data.len() < 32 { if data.len() < 32 {
for i in 0..data.len() { for i in 0..data.len() {
data[i] ^= mask[i%4]; data[i] ^= mask[i % 4];
} }
return; return;
} }
@ -202,7 +196,11 @@ pub fn create_frame(
)); ));
} }
let mask_bit = if mask.is_some() { 0b10000000 } else { 0b00000000 }; let mask_bit = if mask.is_some() {
0b10000000
} else {
0b00000000
};
let mut buf = Vec::new(); let mut buf = Vec::new();
buf.push(first_byte); buf.push(first_byte);
@ -263,21 +261,26 @@ impl<W: AsyncWrite + Unpin> WebSocketWriter<W> {
} }
} }
pub async fn send_control_frame(&mut self, mask: Option<[u8; 4]>, opcode: OpCode, data: &[u8]) -> Result<(), Error> { pub async fn send_control_frame(
&mut self,
mask: Option<[u8; 4]>,
opcode: OpCode,
data: &[u8],
) -> Result<(), Error> {
let frame = create_frame(mask, data, opcode).map_err(Error::from)?; let frame = create_frame(mask, data, opcode).map_err(Error::from)?;
self.writer.write_all(&frame).await.map_err(Error::from) self.writer.write_all(&frame).await.map_err(Error::from)
} }
} }
impl<W: AsyncWrite + Unpin> AsyncWrite for WebSocketWriter<W> { impl<W: AsyncWrite + Unpin> AsyncWrite for WebSocketWriter<W> {
fn poll_write( fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = Pin::get_mut(self); let this = Pin::get_mut(self);
let frametype = if this.text { OpCode::Text } else { OpCode::Binary }; let frametype = if this.text {
OpCode::Text
} else {
OpCode::Binary
};
if this.frame.is_none() { if this.frame.is_none() {
// create frame buf // create frame buf
@ -301,11 +304,11 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for WebSocketWriter<W> {
this.frame = None; this.frame = None;
return Poll::Ready(Ok(size)); return Poll::Ready(Ok(size));
} }
}, }
Err(err) => { Err(err) => {
eprintln!("error in writer: {}", err); eprintln!("error in writer: {}", err);
return Poll::Ready(Err(err)) return Poll::Ready(Err(err));
}, }
} }
} }
} }
@ -321,7 +324,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for WebSocketWriter<W> {
} }
} }
#[derive(Debug,PartialEq)] #[derive(Debug, PartialEq)]
/// Represents the header of a websocket Frame /// Represents the header of a websocket Frame
pub struct FrameHeader { pub struct FrameHeader {
/// True if the frame is either non-fragmented, or the last fragment /// True if the frame is either non-fragmented, or the last fragment
@ -485,11 +488,18 @@ pub struct WebSocketReader<R: AsyncRead> {
impl<R: AsyncReadExt> WebSocketReader<R> { impl<R: AsyncReadExt> WebSocketReader<R> {
/// Creates a new WebSocketReader with the given CallBack for control frames /// Creates a new WebSocketReader with the given CallBack for control frames
/// and a default buffer size of 4096. /// and a default buffer size of 4096.
pub fn new(reader: R, sender: mpsc::UnboundedSender<WebSocketReadResult>) -> WebSocketReader<R> { pub fn new(
reader: R,
sender: mpsc::UnboundedSender<WebSocketReadResult>,
) -> WebSocketReader<R> {
Self::with_capacity(reader, 4096, sender) Self::with_capacity(reader, 4096, sender)
} }
pub fn with_capacity(reader: R, capacity: usize, sender: mpsc::UnboundedSender<WebSocketReadResult>) -> WebSocketReader<R> { pub fn with_capacity(
reader: R,
capacity: usize,
sender: mpsc::UnboundedSender<WebSocketReadResult>,
) -> WebSocketReader<R> {
WebSocketReader { WebSocketReader {
reader: Some(reader), reader: Some(reader),
sender, sender,
@ -537,25 +547,32 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
}; };
let future = async move { let future = async move {
buffer.read_from_async(&mut reader) buffer
.read_from_async(&mut reader)
.await .await
.map(move |len| ReadResult { len, reader, buffer }) .map(move |len| ReadResult {
len,
reader,
buffer,
})
}; };
this.state = ReaderState::Receiving(future.boxed()); this.state = ReaderState::Receiving(future.boxed());
}, }
ReaderState::Receiving(ref mut future) => { ReaderState::Receiving(ref mut future) => match ready!(future.as_mut().poll(cx)) {
match ready!(future.as_mut().poll(cx)) { Ok(ReadResult {
Ok(ReadResult { len, reader, buffer }) => { len,
reader,
buffer,
}) => {
this.reader = Some(reader); this.reader = Some(reader);
this.read_buffer = Some(buffer); this.read_buffer = Some(buffer);
this.state = ReaderState::HaveData; this.state = ReaderState::HaveData;
if len == 0 { if len == 0 {
return Poll::Ready(Ok(0)); return Poll::Ready(Ok(0));
} }
},
Err(err) => return Poll::Ready(Err(err)),
} }
Err(err) => return Poll::Ready(Err(err)),
}, },
ReaderState::HaveData => { ReaderState::HaveData => {
let mut read_buffer = match this.read_buffer.take() { let mut read_buffer = match this.read_buffer.take() {
@ -572,23 +589,22 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
this.state = ReaderState::NoData; this.state = ReaderState::NoData;
this.read_buffer = Some(read_buffer); this.read_buffer = Some(read_buffer);
continue; continue;
}, }
Err(err) => { Err(err) => {
if let Err(err) = this.sender.send(Err(err.clone())) { if let Err(err) = this.sender.send(Err(err.clone())) {
return Poll::Ready(Err(io_err_other(err))); return Poll::Ready(Err(io_err_other(err)));
} }
return Poll::Ready(Err(io_err_other(err))); return Poll::Ready(Err(io_err_other(err)));
}, }
}; };
read_buffer.consume(header.header_len as usize); read_buffer.consume(header.header_len as usize);
header header
}, }
}; };
if header.is_control_frame() { if header.is_control_frame() {
if read_buffer.len() >= header.payload_len { if read_buffer.len() >= header.payload_len {
let mut data = read_buffer.remove_data(header.payload_len); let mut data = read_buffer.remove_data(header.payload_len);
mask_bytes(header.mask, &mut data); mask_bytes(header.mask, &mut data);
if let Err(err) = this.sender.send(Ok((header.frametype, data))) { if let Err(err) = this.sender.send(Ok((header.frametype, data))) {
@ -610,11 +626,14 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
} }
} }
let len = min(buf.len() - offset, min(header.payload_len, read_buffer.len())); let len = min(
buf.len() - offset,
min(header.payload_len, read_buffer.len()),
);
let mut data = read_buffer.remove_data(len); let mut data = read_buffer.remove_data(len);
mask_bytes(header.mask, &mut data); mask_bytes(header.mask, &mut data);
buf[offset..offset+len].copy_from_slice(&data); buf[offset..offset + len].copy_from_slice(&data);
offset += len; offset += len;
header.payload_len -= len; header.payload_len -= len;
@ -633,7 +652,7 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
if offset > 0 { if offset > 0 {
return Poll::Ready(Ok(offset)); return Poll::Ready(Ok(offset));
} }
}, }
} }
} }
} }
@ -701,7 +720,7 @@ impl WebSocket {
async fn handle_channel_message<W>( async fn handle_channel_message<W>(
result: WebSocketReadResult, result: WebSocketReadResult,
writer: &mut WebSocketWriter<W> writer: &mut WebSocketWriter<W>,
) -> Result<OpCode, Error> ) -> Result<OpCode, Error>
where where
W: AsyncWrite + Unpin + Send, W: AsyncWrite + Unpin + Send,
@ -718,9 +737,11 @@ impl WebSocket {
Ok((opcode, _)) => { Ok((opcode, _)) => {
// ignore other frames // ignore other frames
Ok(opcode) Ok(opcode)
}, }
Err(err) => { Err(err) => {
writer.send_control_frame(None, OpCode::Close, &err.generate_frame_payload()).await?; writer
.send_control_frame(None, OpCode::Close, &err.generate_frame_payload())
.await?;
Err(Error::from(err)) Err(Error::from(err))
} }
} }
@ -729,7 +750,8 @@ impl WebSocket {
async fn copy_to_websocket<R, W>( async fn copy_to_websocket<R, W>(
mut reader: &mut R, mut reader: &mut R,
mut writer: &mut WebSocketWriter<W>, mut writer: &mut WebSocketWriter<W>,
receiver: &mut mpsc::UnboundedReceiver<WebSocketReadResult>) -> Result<bool, Error> receiver: &mut mpsc::UnboundedReceiver<WebSocketReadResult>,
) -> Result<bool, Error>
where where
R: AsyncRead + Unpin + Send, R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send, W: AsyncWrite + Unpin + Send,
@ -738,7 +760,7 @@ impl WebSocket {
let mut eof = false; let mut eof = false;
loop { loop {
if !buf.is_full() { if !buf.is_full() {
let bytes = select!{ let bytes = select! {
res = buf.read_from_async(&mut reader).fuse() => res?, res = buf.read_from_async(&mut reader).fuse() => res?,
res = receiver.recv().fuse() => { res = receiver.recv().fuse() => {
let res = res.ok_or_else(|| format_err!("control channel closed"))?; let res = res.ok_or_else(|| format_err!("control channel closed"))?;
@ -777,7 +799,6 @@ impl WebSocket {
S: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
L: AsyncRead + AsyncWrite + Unpin + Send, L: AsyncRead + AsyncWrite + Unpin + Send,
{ {
let (usreader, uswriter) = tokio::io::split(upstream); let (usreader, uswriter) = tokio::io::split(upstream);
let (mut dsreader, mut dswriter) = tokio::io::split(downstream); let (mut dsreader, mut dswriter) = tokio::io::split(downstream);
@ -785,11 +806,10 @@ impl WebSocket {
let mut wsreader = WebSocketReader::new(usreader, tx); let mut wsreader = WebSocketReader::new(usreader, tx);
let mut wswriter = WebSocketWriter::new(None, self.text, uswriter); let mut wswriter = WebSocketWriter::new(None, self.text, uswriter);
let ws_future = tokio::io::copy(&mut wsreader, &mut dswriter); let ws_future = tokio::io::copy(&mut wsreader, &mut dswriter);
let term_future = Self::copy_to_websocket(&mut dsreader, &mut wswriter, &mut rx); let term_future = Self::copy_to_websocket(&mut dsreader, &mut wswriter, &mut rx);
let res = select!{ let res = select! {
res = ws_future.fuse() => match res { res = ws_future.fuse() => match res {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => Err(Error::from(err)), Err(err) => Err(Error::from(err)),