more formatting fixups

Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
This commit is contained in:
Wolfgang Bumiller 2020-08-28 08:56:54 +02:00
parent bbc94222f5
commit ba7c389142

View File

@ -3,27 +3,21 @@
//! Provides methods to read and write from websockets The reader and writer take a reader/writer
//! with AsyncRead/AsyncWrite respectively and provides the same
use std::cmp::min;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::cmp::min;
use std::io;
use std::future::Future;
use futures::select;
use anyhow::{bail, format_err, Error};
use futures::select;
use hyper::header::{
HeaderMap, HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE,
};
use hyper::{Body, Response, StatusCode};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc;
use hyper::{Body, Response, StatusCode};
use hyper::header::{
HeaderMap,
HeaderValue,
UPGRADE,
CONNECTION,
SEC_WEBSOCKET_KEY,
SEC_WEBSOCKET_PROTOCOL,
SEC_WEBSOCKET_VERSION,
SEC_WEBSOCKET_ACCEPT,
};
use futures::future::FutureExt;
use futures::ready;
@ -56,16 +50,16 @@ impl std::fmt::Display for WebSocketErrorKind {
}
#[derive(Debug, Clone)]
pub struct WebSocketError{
pub struct WebSocketError {
kind: WebSocketErrorKind,
message: String,
}
impl WebSocketError {
pub fn new(kind: WebSocketErrorKind, message: &str) -> Self {
Self{
Self {
kind,
message: message.to_string()
message: message.to_string(),
}
}
@ -114,13 +108,13 @@ impl OpCode {
fn mask_bytes(mask: Option<[u8; 4]>, data: &mut [u8]) {
let mask = match mask {
Some([0,0,0,0]) | None => return,
Some([0, 0, 0, 0]) | None => return,
Some(mask) => mask,
};
if data.len() < 32 {
for i in 0..data.len() {
data[i] ^= mask[i%4];
data[i] ^= mask[i % 4];
}
return;
}
@ -197,12 +191,16 @@ pub fn create_frame(
let len = data.len();
if (frametype as u8) & 0b00001000 > 0 && len > 125 {
return Err(WebSocketError::new(
WebSocketErrorKind::Unexpected,
"Control frames cannot have data longer than 125 bytes",
WebSocketErrorKind::Unexpected,
"Control frames cannot have data longer than 125 bytes",
));
}
let mask_bit = if mask.is_some() { 0b10000000 } else { 0b00000000 };
let mask_bit = if mask.is_some() {
0b10000000
} else {
0b00000000
};
let mut buf = Vec::new();
buf.push(first_byte);
@ -263,21 +261,26 @@ impl<W: AsyncWrite + Unpin> WebSocketWriter<W> {
}
}
pub async fn send_control_frame(&mut self, mask: Option<[u8; 4]>, opcode: OpCode, data: &[u8]) -> Result<(), Error> {
pub async fn send_control_frame(
&mut self,
mask: Option<[u8; 4]>,
opcode: OpCode,
data: &[u8],
) -> Result<(), Error> {
let frame = create_frame(mask, data, opcode).map_err(Error::from)?;
self.writer.write_all(&frame).await.map_err(Error::from)
}
}
impl<W: AsyncWrite + Unpin> AsyncWrite for WebSocketWriter<W> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = Pin::get_mut(self);
let frametype = if this.text { OpCode::Text } else { OpCode::Binary };
let frametype = if this.text {
OpCode::Text
} else {
OpCode::Binary
};
if this.frame.is_none() {
// create frame buf
@ -301,11 +304,11 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for WebSocketWriter<W> {
this.frame = None;
return Poll::Ready(Ok(size));
}
},
}
Err(err) => {
eprintln!("error in writer: {}", err);
return Poll::Ready(Err(err))
},
return Poll::Ready(Err(err));
}
}
}
}
@ -321,7 +324,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for WebSocketWriter<W> {
}
}
#[derive(Debug,PartialEq)]
#[derive(Debug, PartialEq)]
/// Represents the header of a websocket Frame
pub struct FrameHeader {
/// True if the frame is either non-fragmented, or the last fragment
@ -383,8 +386,8 @@ impl FrameHeader {
// we do not support extensions
if data[0] & 0b01110000 > 0 {
return Err(WebSocketError::new(
WebSocketErrorKind::ProtocolError,
"Extensions not supported",
WebSocketErrorKind::ProtocolError,
"Extensions not supported",
));
}
@ -398,16 +401,16 @@ impl FrameHeader {
10 => OpCode::Pong,
other => {
return Err(WebSocketError::new(
WebSocketErrorKind::ProtocolError,
&format!("Unknown OpCode {}", other),
WebSocketErrorKind::ProtocolError,
&format!("Unknown OpCode {}", other),
));
}
};
if !fin && frametype.is_control() {
return Err(WebSocketError::new(
WebSocketErrorKind::ProtocolError,
"Control frames cannot be fragmented",
WebSocketErrorKind::ProtocolError,
"Control frames cannot be fragmented",
));
}
@ -439,8 +442,8 @@ impl FrameHeader {
if payload_len > 125 && frametype.is_control() {
return Err(WebSocketError::new(
WebSocketErrorKind::ProtocolError,
"Control frames cannot carry more than 125 bytes of data",
WebSocketErrorKind::ProtocolError,
"Control frames cannot carry more than 125 bytes of data",
));
}
@ -485,11 +488,18 @@ pub struct WebSocketReader<R: AsyncRead> {
impl<R: AsyncReadExt> WebSocketReader<R> {
/// Creates a new WebSocketReader with the given CallBack for control frames
/// and a default buffer size of 4096.
pub fn new(reader: R, sender: mpsc::UnboundedSender<WebSocketReadResult>) -> WebSocketReader<R> {
pub fn new(
reader: R,
sender: mpsc::UnboundedSender<WebSocketReadResult>,
) -> WebSocketReader<R> {
Self::with_capacity(reader, 4096, sender)
}
pub fn with_capacity(reader: R, capacity: usize, sender: mpsc::UnboundedSender<WebSocketReadResult>) -> WebSocketReader<R> {
pub fn with_capacity(
reader: R,
capacity: usize,
sender: mpsc::UnboundedSender<WebSocketReadResult>,
) -> WebSocketReader<R> {
WebSocketReader {
reader: Some(reader),
sender,
@ -537,25 +547,32 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
};
let future = async move {
buffer.read_from_async(&mut reader)
buffer
.read_from_async(&mut reader)
.await
.map(move |len| ReadResult { len, reader, buffer })
.map(move |len| ReadResult {
len,
reader,
buffer,
})
};
this.state = ReaderState::Receiving(future.boxed());
},
ReaderState::Receiving(ref mut future) => {
match ready!(future.as_mut().poll(cx)) {
Ok(ReadResult { len, reader, buffer }) => {
this.reader = Some(reader);
this.read_buffer = Some(buffer);
this.state = ReaderState::HaveData;
if len == 0 {
return Poll::Ready(Ok(0));
}
},
Err(err) => return Poll::Ready(Err(err)),
}
ReaderState::Receiving(ref mut future) => match ready!(future.as_mut().poll(cx)) {
Ok(ReadResult {
len,
reader,
buffer,
}) => {
this.reader = Some(reader);
this.read_buffer = Some(buffer);
this.state = ReaderState::HaveData;
if len == 0 {
return Poll::Ready(Ok(0));
}
}
Err(err) => return Poll::Ready(Err(err)),
},
ReaderState::HaveData => {
let mut read_buffer = match this.read_buffer.take() {
@ -572,30 +589,29 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
this.state = ReaderState::NoData;
this.read_buffer = Some(read_buffer);
continue;
},
}
Err(err) => {
if let Err(err) = this.sender.send(Err(err.clone())) {
return Poll::Ready(Err(io_err_other(err)));
}
return Poll::Ready(Err(io_err_other(err)));
},
}
};
read_buffer.consume(header.header_len as usize);
header
},
}
};
if header.is_control_frame() {
if read_buffer.len() >= header.payload_len {
let mut data = read_buffer.remove_data(header.payload_len);
mask_bytes(header.mask, &mut data);
if let Err(err) = this.sender.send(Ok((header.frametype, data))) {
eprintln!("error sending control frame: {}", err);
}
this.state = if read_buffer.is_empty() {
this.state = if read_buffer.is_empty() {
ReaderState::NoData
} else {
ReaderState::HaveData
@ -610,11 +626,14 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
}
}
let len = min(buf.len() - offset, min(header.payload_len, read_buffer.len()));
let len = min(
buf.len() - offset,
min(header.payload_len, read_buffer.len()),
);
let mut data = read_buffer.remove_data(len);
mask_bytes(header.mask, &mut data);
buf[offset..offset+len].copy_from_slice(&data);
buf[offset..offset + len].copy_from_slice(&data);
offset += len;
header.payload_len -= len;
@ -633,7 +652,7 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
if offset > 0 {
return Poll::Ready(Ok(offset));
}
},
}
}
}
}
@ -701,7 +720,7 @@ impl WebSocket {
async fn handle_channel_message<W>(
result: WebSocketReadResult,
writer: &mut WebSocketWriter<W>
writer: &mut WebSocketWriter<W>,
) -> Result<OpCode, Error>
where
W: AsyncWrite + Unpin + Send,
@ -718,9 +737,11 @@ impl WebSocket {
Ok((opcode, _)) => {
// ignore other frames
Ok(opcode)
},
}
Err(err) => {
writer.send_control_frame(None, OpCode::Close, &err.generate_frame_payload()).await?;
writer
.send_control_frame(None, OpCode::Close, &err.generate_frame_payload())
.await?;
Err(Error::from(err))
}
}
@ -729,7 +750,8 @@ impl WebSocket {
async fn copy_to_websocket<R, W>(
mut reader: &mut R,
mut writer: &mut WebSocketWriter<W>,
receiver: &mut mpsc::UnboundedReceiver<WebSocketReadResult>) -> Result<bool, Error>
receiver: &mut mpsc::UnboundedReceiver<WebSocketReadResult>,
) -> Result<bool, Error>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
@ -738,7 +760,7 @@ impl WebSocket {
let mut eof = false;
loop {
if !buf.is_full() {
let bytes = select!{
let bytes = select! {
res = buf.read_from_async(&mut reader).fuse() => res?,
res = receiver.recv().fuse() => {
let res = res.ok_or_else(|| format_err!("control channel closed"))?;
@ -777,7 +799,6 @@ impl WebSocket {
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
L: AsyncRead + AsyncWrite + Unpin + Send,
{
let (usreader, uswriter) = tokio::io::split(upstream);
let (mut dsreader, mut dswriter) = tokio::io::split(downstream);
@ -785,11 +806,10 @@ impl WebSocket {
let mut wsreader = WebSocketReader::new(usreader, tx);
let mut wswriter = WebSocketWriter::new(None, self.text, uswriter);
let ws_future = tokio::io::copy(&mut wsreader, &mut dswriter);
let term_future = Self::copy_to_websocket(&mut dsreader, &mut wswriter, &mut rx);
let res = select!{
let res = select! {
res = ws_future.fuse() => match res {
Ok(_) => Ok(()),
Err(err) => Err(Error::from(err)),