mirror of
https://git.proxmox.com/git/proxmox-websocket-tunnel
synced 2025-06-14 21:26:05 +00:00

avoid allocations and skip the utf8 check on the hex string Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
442 lines
16 KiB
Rust
442 lines
16 KiB
Rust
use anyhow::{bail, format_err, Error};
|
|
|
|
use std::collections::VecDeque;
|
|
use std::sync::Arc;
|
|
|
|
use futures::future::FutureExt;
|
|
use futures::select;
|
|
|
|
use hyper::client::{Client, HttpConnector};
|
|
use hyper::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE};
|
|
use hyper::upgrade::Upgraded;
|
|
use hyper::{Body, Request, StatusCode};
|
|
|
|
use openssl::ssl::{SslConnector, SslMethod};
|
|
use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::{Map, Value};
|
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::net::{UnixListener, UnixStream};
|
|
use tokio::sync::{mpsc, oneshot};
|
|
use tokio_stream::wrappers::LinesStream;
|
|
use tokio_stream::StreamExt;
|
|
|
|
use proxmox_http::client::HttpsConnector;
|
|
use proxmox_http::websocket::{OpCode, WebSocket, WebSocketReader, WebSocketWriter};
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
#[serde(rename_all = "kebab-case")]
|
|
enum CmdType {
|
|
Connect,
|
|
Forward,
|
|
NonControl,
|
|
}
|
|
|
|
type CmdData = Map<String, Value>;
|
|
|
|
#[derive(Serialize, Deserialize, Debug)]
|
|
#[serde(rename_all = "kebab-case")]
|
|
struct ConnectCmdData {
|
|
/// target URL for WS connection
|
|
url: String,
|
|
|
|
/// fingerprint of TLS certificate
|
|
fingerprint: Option<String>,
|
|
|
|
/// addition headers such as authorization
|
|
headers: Option<Vec<(String, String)>>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
#[serde(rename_all = "kebab-case")]
|
|
struct ForwardCmdData {
|
|
/// target URL for WS connection
|
|
url: String,
|
|
|
|
/// addition headers such as authorization
|
|
headers: Option<Vec<(String, String)>>,
|
|
|
|
/// fingerprint of TLS certificate
|
|
fingerprint: Option<String>,
|
|
|
|
/// local UNIX socket path for forwarding
|
|
unix: String,
|
|
|
|
/// request ticket using these parameters
|
|
ticket: Option<Map<String, Value>>,
|
|
}
|
|
|
|
struct CtrlTunnel {
|
|
sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
|
|
}
|
|
|
|
impl CtrlTunnel {
|
|
async fn read_cmd_loop(mut self) -> Result<(), Error> {
|
|
let mut stdin_stream = LinesStream::new(BufReader::new(tokio::io::stdin()).lines());
|
|
while let Some(res) = stdin_stream.next().await {
|
|
match res {
|
|
Ok(line) => {
|
|
let (cmd_type, data) = Self::parse_cmd(&line)?;
|
|
match cmd_type {
|
|
CmdType::Connect => self.handle_connect_cmd(data).await,
|
|
CmdType::Forward => {
|
|
let res = self.handle_forward_cmd(data).await;
|
|
match &res {
|
|
Ok(()) => println!("{}", serde_json::json!({"success": true})),
|
|
Err(msg) => println!(
|
|
"{}",
|
|
serde_json::json!({"success": false, "msg": msg.to_string()})
|
|
),
|
|
};
|
|
res
|
|
}
|
|
CmdType::NonControl => self
|
|
.handle_tunnel_cmd(data)
|
|
.await
|
|
.map(|res| println!("{}", res)),
|
|
}
|
|
}
|
|
Err(err) => bail!("Failed to read from STDIN - {}", err),
|
|
}?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn parse_cmd(line: &str) -> Result<(CmdType, CmdData), Error> {
|
|
let mut json: Map<String, Value> = serde_json::from_str(line)?;
|
|
match json.remove("control") {
|
|
Some(Value::Bool(true)) => {
|
|
match json.remove("cmd").map(serde_json::from_value::<CmdType>) {
|
|
None => bail!("input has 'control' flag, but no control 'cmd' set.."),
|
|
Some(Err(e)) => bail!("failed to parse control cmd - {}", e),
|
|
Some(Ok(cmd_type)) => Ok((cmd_type, json)),
|
|
}
|
|
}
|
|
_ => Ok((CmdType::NonControl, json)),
|
|
}
|
|
}
|
|
|
|
async fn websocket_connect(
|
|
url: String,
|
|
extra_headers: Vec<(String, String)>,
|
|
fingerprint: Option<String>,
|
|
) -> Result<Upgraded, Error> {
|
|
let ws_key = proxmox_sys::linux::random_data(16)?;
|
|
let ws_key = base64::encode(&ws_key);
|
|
let mut req = Request::builder()
|
|
.uri(url)
|
|
.header(UPGRADE, "websocket")
|
|
.header(SEC_WEBSOCKET_VERSION, "13")
|
|
.header(SEC_WEBSOCKET_KEY, ws_key)
|
|
.body(Body::empty())?;
|
|
|
|
let headers = req.headers_mut();
|
|
for (name, value) in extra_headers {
|
|
let name = hyper::header::HeaderName::from_bytes(name.as_bytes())?;
|
|
let value = hyper::header::HeaderValue::from_str(&value)?;
|
|
headers.insert(name, value);
|
|
}
|
|
|
|
let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls())?;
|
|
if let Some(expected) = fingerprint {
|
|
ssl_connector_builder.set_verify_callback(
|
|
openssl::ssl::SslVerifyMode::PEER,
|
|
move |_valid, ctx| {
|
|
let cert = match ctx.current_cert() {
|
|
Some(cert) => cert,
|
|
None => {
|
|
// should not happen
|
|
eprintln!("SSL context lacks current certificate.");
|
|
return false;
|
|
}
|
|
};
|
|
|
|
// skip CA certificates, we only care about the peer cert
|
|
let depth = ctx.error_depth();
|
|
if depth != 0 {
|
|
return true;
|
|
}
|
|
|
|
use itertools::Itertools;
|
|
let fp = match cert.digest(openssl::hash::MessageDigest::sha256()) {
|
|
Ok(fp) => fp,
|
|
Err(err) => {
|
|
// should not happen
|
|
eprintln!("failed to calculate certificate FP - {}", err);
|
|
return false;
|
|
}
|
|
};
|
|
let fp_string = hex::encode(&fp);
|
|
let fp_string = fp_string
|
|
.as_bytes()
|
|
.chunks(2)
|
|
.map(|v| unsafe { std::str::from_utf8_unchecked(v) })
|
|
.join(":");
|
|
|
|
let expected = expected.to_lowercase();
|
|
if expected == fp_string {
|
|
true
|
|
} else {
|
|
eprintln!("certificate fingerprint does not match expected fingerprint!");
|
|
eprintln!("expected: {}", expected);
|
|
eprintln!("encountered: {}", fp_string);
|
|
false
|
|
}
|
|
},
|
|
);
|
|
} else {
|
|
ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER);
|
|
}
|
|
|
|
let mut httpc = HttpConnector::new();
|
|
httpc.enforce_http(false); // we want https...
|
|
httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0)));
|
|
let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build(), 120);
|
|
|
|
let client = Client::builder().build::<_, Body>(https);
|
|
let res = client.request(req).await?;
|
|
if res.status() != StatusCode::SWITCHING_PROTOCOLS {
|
|
bail!("server didn't upgrade: {}", res.status());
|
|
}
|
|
|
|
hyper::upgrade::on(res)
|
|
.await
|
|
.map_err(|err| format_err!("failed to upgrade - {}", err))
|
|
}
|
|
|
|
async fn handle_connect_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
|
|
let mut data: ConnectCmdData = data
|
|
.remove("data")
|
|
.ok_or_else(|| format_err!("'connect' command missing 'data'"))
|
|
.map(serde_json::from_value)??;
|
|
|
|
if self.sender.is_some() {
|
|
bail!("already connected!");
|
|
}
|
|
|
|
let upgraded = Self::websocket_connect(
|
|
data.url.clone(),
|
|
data.headers.take().unwrap_or_else(Vec::new),
|
|
data.fingerprint.take(),
|
|
)
|
|
.await?;
|
|
|
|
let (tx, rx) = mpsc::unbounded_channel();
|
|
self.sender = Some(tx);
|
|
tokio::spawn(async move {
|
|
if let Err(err) = Self::handle_ctrl_tunnel(upgraded, rx).await {
|
|
eprintln!("Tunnel to {} failed - {}", data.url, err);
|
|
}
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_forward_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
|
|
let data: ForwardCmdData = data
|
|
.remove("data")
|
|
.ok_or_else(|| format_err!("'forward' command missing 'data'"))
|
|
.map(serde_json::from_value)??;
|
|
|
|
if self.sender.is_none() && data.ticket.is_some() {
|
|
bail!("dynamically requesting ticket requires cmd tunnel connection!");
|
|
}
|
|
|
|
let unix_listener = UnixListener::bind(data.unix.clone())?;
|
|
let data = Arc::new(data);
|
|
let cmd_sender = self.sender.clone();
|
|
|
|
tokio::spawn(async move {
|
|
let mut tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new();
|
|
let data2 = data.clone();
|
|
|
|
loop {
|
|
let data3 = data2.clone();
|
|
|
|
match unix_listener.accept().await {
|
|
Ok((unix_stream, _)) => {
|
|
eprintln!("accepted new connection on '{}'", data3.unix);
|
|
let cmd_sender2 = cmd_sender.clone();
|
|
|
|
let task = tokio::spawn(async move {
|
|
if let Err(err) = Self::handle_forward_tunnel(
|
|
cmd_sender2.clone(),
|
|
data3.clone(),
|
|
unix_stream,
|
|
)
|
|
.await
|
|
{
|
|
eprintln!("Tunnel for {} failed - {}", data3.unix, err);
|
|
}
|
|
});
|
|
tasks.push(task);
|
|
}
|
|
Err(err) => eprintln!(
|
|
"Failed to accept unix connection on {} - {}",
|
|
data3.unix, err
|
|
),
|
|
};
|
|
}
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_tunnel_cmd(&mut self, data: CmdData) -> Result<String, Error> {
|
|
match &mut self.sender {
|
|
None => bail!("not connected!"),
|
|
Some(sender) => {
|
|
let data: Value = data.into();
|
|
let (tx, rx) = oneshot::channel::<String>();
|
|
if let Some(cmd) = data.get("cmd") {
|
|
eprintln!("-> sending command {} to remote", cmd);
|
|
} else {
|
|
eprintln!("-> sending data line to remote");
|
|
}
|
|
sender.send((data, tx))?;
|
|
let res = rx.await?;
|
|
eprintln!("<- got reply");
|
|
Ok(res)
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_ctrl_tunnel(
|
|
websocket: Upgraded,
|
|
mut cmd_receiver: mpsc::UnboundedReceiver<(Value, oneshot::Sender<String>)>,
|
|
) -> Result<(), Error> {
|
|
let (tunnel_reader, tunnel_writer) = tokio::io::split(websocket);
|
|
let (ws_close_tx, mut ws_close_rx) = mpsc::unbounded_channel();
|
|
let ws_reader = WebSocketReader::new(tunnel_reader, ws_close_tx);
|
|
let mut ws_writer = WebSocketWriter::new(Some([0, 0, 0, 0]), tunnel_writer);
|
|
|
|
let mut framed_reader =
|
|
tokio_util::codec::FramedRead::new(ws_reader, tokio_util::codec::LinesCodec::new());
|
|
|
|
let mut resp_tx_queue: VecDeque<oneshot::Sender<String>> = VecDeque::new();
|
|
let mut shutting_down = false;
|
|
|
|
loop {
|
|
let mut close_future = ws_close_rx.recv().boxed().fuse();
|
|
let mut frame_future = framed_reader.next().boxed().fuse();
|
|
let mut cmd_future = cmd_receiver.recv().boxed().fuse();
|
|
|
|
select! {
|
|
res = close_future => {
|
|
let res = res.ok_or_else(|| format_err!("WS control channel closed"))?;
|
|
eprintln!("WS: received control message: '{:?}'", res);
|
|
shutting_down = true;
|
|
},
|
|
res = frame_future => {
|
|
match res {
|
|
None if shutting_down => {
|
|
eprintln!("WS closed");
|
|
break;
|
|
},
|
|
None => bail!("WS closed unexpectedly"),
|
|
Some(Ok(res)) => {
|
|
resp_tx_queue
|
|
.pop_front()
|
|
.ok_or_else(|| format_err!("no response handler"))?
|
|
.send(res)
|
|
.map_err(|msg| format_err!("failed to send tunnel response '{}' back to requester - receiver already closed?", msg))?;
|
|
},
|
|
Some(Err(err)) => {
|
|
bail!("reading from control tunnel failed - WS receive failed: {}", err);
|
|
},
|
|
}
|
|
},
|
|
res = cmd_future => {
|
|
if shutting_down { continue };
|
|
match res {
|
|
None => {
|
|
eprintln!("CMD channel closed, shutting down");
|
|
ws_writer.send_control_frame(Some([1,2,3,4]), OpCode::Close, &[]).await?;
|
|
shutting_down = true;
|
|
},
|
|
Some((msg, resp_tx)) => {
|
|
resp_tx_queue.push_back(resp_tx);
|
|
|
|
let line = format!("{}\n", msg);
|
|
ws_writer.write_all(line.as_bytes()).await?;
|
|
ws_writer.flush().await?;
|
|
},
|
|
}
|
|
},
|
|
};
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_forward_tunnel(
|
|
cmd_sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
|
|
data: Arc<ForwardCmdData>,
|
|
unix: UnixStream,
|
|
) -> Result<(), Error> {
|
|
let data = match (&cmd_sender, &data.ticket) {
|
|
(Some(cmd_sender), Some(_)) => Self::get_ticket(cmd_sender, data.clone()).await,
|
|
_ => Ok(data.clone()),
|
|
}?;
|
|
|
|
let upgraded = Self::websocket_connect(
|
|
data.url.clone(),
|
|
data.headers.clone().unwrap_or_else(Vec::new),
|
|
data.fingerprint.clone(),
|
|
)
|
|
.await?;
|
|
|
|
let ws = WebSocket {
|
|
mask: Some([0, 0, 0, 0]),
|
|
};
|
|
eprintln!("established new WS for forwarding '{}'", data.unix);
|
|
ws.serve_connection(upgraded, unix).await?;
|
|
|
|
eprintln!("done handling forwarded connection from '{}'", data.unix);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn get_ticket(
|
|
cmd_sender: &mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>,
|
|
cmd_data: Arc<ForwardCmdData>,
|
|
) -> Result<Arc<ForwardCmdData>, Error> {
|
|
eprintln!("requesting WS ticket via tunnel");
|
|
let ticket_cmd = match cmd_data.ticket.clone() {
|
|
Some(mut ticket_cmd) => {
|
|
ticket_cmd.insert("cmd".to_string(), serde_json::json!("ticket"));
|
|
ticket_cmd
|
|
}
|
|
None => bail!("can't get ticket without ticket parameters"),
|
|
};
|
|
let (tx, rx) = oneshot::channel::<String>();
|
|
cmd_sender.send((serde_json::json!(ticket_cmd), tx))?;
|
|
let ticket = rx.await?;
|
|
let mut ticket: Map<String, Value> = serde_json::from_str(&ticket)?;
|
|
let ticket = ticket
|
|
.remove("ticket")
|
|
.ok_or_else(|| format_err!("failed to retrieve ticket via tunnel"))?;
|
|
|
|
let ticket = ticket
|
|
.as_str()
|
|
.ok_or_else(|| format_err!("failed to format received ticket"))?;
|
|
let ticket = utf8_percent_encode(ticket, NON_ALPHANUMERIC).to_string();
|
|
|
|
let mut data = cmd_data.clone();
|
|
let mut url = data.url.clone();
|
|
url.push_str("ticket=");
|
|
url.push_str(&ticket);
|
|
let mut d = Arc::make_mut(&mut data);
|
|
d.url = url;
|
|
Ok(data)
|
|
}
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Error> {
|
|
let tunnel = CtrlTunnel { sender: None };
|
|
tunnel.read_cmd_loop().await
|
|
}
|