diff --git a/proxmox-rest-server/src/connection.rs b/proxmox-rest-server/src/connection.rs index 470021d7..2ca83fe2 100644 --- a/proxmox-rest-server/src/connection.rs +++ b/proxmox-rest-server/src/connection.rs @@ -2,7 +2,10 @@ //! //! Hyper building block. +use std::io; +use std::mem::ManuallyDrop; use std::net::SocketAddr; +use std::os::fd::FromRawFd; use std::os::unix::io::AsRawFd; use std::path::PathBuf; use std::pin::Pin; @@ -418,70 +421,79 @@ impl AcceptBuilder { secure_sender: ClientSender, insecure_sender: InsecureClientSender, ) { + const CLIENT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); + let peer = state.peer; - let client_initiates_handshake = { - #[cfg(feature = "rate-limited-stream")] - let socket_ref = state.socket.inner(); + #[cfg(feature = "rate-limited-stream")] + let socket_ref = state.socket.inner(); - #[cfg(not(feature = "rate-limited-stream"))] - let socket_ref = &state.socket; + #[cfg(not(feature = "rate-limited-stream"))] + let socket_ref = &state.socket; - match Self::wait_for_client_tls_handshake(socket_ref).await { - Ok(initiates_handshake) => initiates_handshake, - Err(err) => { - log::error!("[{peer}] error checking for TLS handshake: {err}"); - return; + let handshake_res = + Self::wait_for_client_tls_handshake(socket_ref, CLIENT_HANDSHAKE_TIMEOUT).await; + + match handshake_res { + Ok(true) => { + Self::do_accept_tls(state, flags, secure_sender).await; + } + Ok(false) => { + let insecure_stream = Box::pin(state.socket); + + if let Err(send_err) = insecure_sender.send(Ok(insecure_stream)).await { + log::error!("[{peer}] failed to accept connection - connection channel closed: {send_err}"); } } - }; - - if !client_initiates_handshake { - let insecure_stream = Box::pin(state.socket); - - if insecure_sender.send(Ok(insecure_stream)).await.is_err() && flags.is_debug { - log::error!("[{peer}] detected closed connection channel") + Err(err) => { + log::error!("[{peer}] failed to check for TLS handshake: {err}"); } - - return; } - - Self::do_accept_tls(state, flags, secure_sender).await } - async fn wait_for_client_tls_handshake(incoming_stream: &TcpStream) -> Result { - const MS_TIMEOUT: u64 = 1000; - const BYTES_BUF_SIZE: usize = 128; - - let mut buf = [0; BYTES_BUF_SIZE]; - let mut last_peek_size = 0; + async fn wait_for_client_tls_handshake( + incoming_stream: &TcpStream, + timeout: Duration, + ) -> Result { + const HANDSHAKE_BYTES_LEN: usize = 5; let future = async { - loop { - let peek_size = incoming_stream - .peek(&mut buf) - .await - .context("couldn't peek into incoming tcp stream")?; + incoming_stream + .async_io(tokio::io::Interest::READABLE, || { + let mut buf = [0; HANDSHAKE_BYTES_LEN]; - if contains_tls_handshake_fragment(&buf) { - return Ok(true); - } + // Convert to standard lib TcpStream so we can peek without interfering + // with tokio's internals. Wrap the stream in ManuallyDrop in order to prevent + // the destructor from being called, closing the connection and messing up + // invariants. + let raw_fd = incoming_stream.as_raw_fd(); + let std_stream = + unsafe { ManuallyDrop::new(std::net::TcpStream::from_raw_fd(raw_fd)) }; - // No more new data came in - if peek_size == last_peek_size { - return Ok(false); - } + let peek_res = std_stream.peek(&mut buf); - last_peek_size = peek_size; - - // explicitly yield to event loop; this future otherwise blocks ad infinitum - tokio::task::yield_now().await; - } + match peek_res { + // If we didn't get enough bytes, raise an EAGAIN / EWOULDBLOCK which tells + // tokio to await the readiness of the socket again. This should normally + // only be used if the socket isn't actually ready, but is fine to do here + // in our case. + // + // This means we will peek into the stream's queue until we got + // HANDSHAKE_BYTE_LEN bytes or an error. + Ok(peek_len) if peek_len < HANDSHAKE_BYTES_LEN => { + Err(io::ErrorKind::WouldBlock.into()) + } + // Either we got Ok(HANDSHAKE_BYTES_LEN) or some error. + res => res.map(|_| contains_tls_handshake_fragment(&buf)), + } + }) + .await + .context("couldn't peek into incoming TCP stream") }; - tokio::time::timeout(Duration::from_millis(MS_TIMEOUT), future) + tokio::time::timeout(timeout, future) .await - .unwrap_or(Ok(false)) + .context("timed out while waiting for client to initiate TLS handshake")? } }