use anyhow::{bail, format_err, Error}; use std::collections::VecDeque; use std::sync::Arc; use futures::future::FutureExt; use futures::select; use hyper::client::{Client, HttpConnector}; use hyper::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE}; use hyper::upgrade::Upgraded; use hyper::{Body, Request, StatusCode}; use openssl::ssl::{SslConnector, SslMethod}; use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::net::{UnixListener, UnixStream}; use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::LinesStream; use tokio_stream::StreamExt; use proxmox_http::client::HttpsConnector; use proxmox_http::websocket::{OpCode, WebSocket, WebSocketReader, WebSocketWriter}; #[derive(Serialize, Deserialize, Debug)] #[serde(rename_all = "kebab-case")] enum CmdType { Connect, Forward, NonControl, } type CmdData = Map; #[derive(Serialize, Deserialize, Debug)] #[serde(rename_all = "kebab-case")] struct ConnectCmdData { /// target URL for WS connection url: String, /// fingerprint of TLS certificate fingerprint: Option, /// addition headers such as authorization headers: Option>, } #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "kebab-case")] struct ForwardCmdData { /// target URL for WS connection url: String, /// addition headers such as authorization headers: Option>, /// fingerprint of TLS certificate fingerprint: Option, /// local UNIX socket path for forwarding unix: String, /// request ticket using these parameters ticket: Option>, } struct CtrlTunnel { sender: Option)>>, } impl CtrlTunnel { async fn read_cmd_loop(mut self) -> Result<(), Error> { let mut stdin_stream = LinesStream::new(BufReader::new(tokio::io::stdin()).lines()); while let Some(res) = stdin_stream.next().await { match res { Ok(line) => { let (cmd_type, data) = Self::parse_cmd(&line)?; match cmd_type { CmdType::Connect => self.handle_connect_cmd(data).await, CmdType::Forward => { let res = self.handle_forward_cmd(data).await; match &res { Ok(()) => println!("{}", serde_json::json!({"success": true})), Err(msg) => println!( "{}", serde_json::json!({"success": false, "msg": msg.to_string()}) ), }; res } CmdType::NonControl => self .handle_tunnel_cmd(data) .await .map(|res| println!("{}", res)), } } Err(err) => bail!("Failed to read from STDIN - {}", err), }?; } Ok(()) } fn parse_cmd(line: &str) -> Result<(CmdType, CmdData), Error> { let mut json: Map = serde_json::from_str(line)?; match json.remove("control") { Some(Value::Bool(true)) => { match json.remove("cmd").map(serde_json::from_value::) { None => bail!("input has 'control' flag, but no control 'cmd' set.."), Some(Err(e)) => bail!("failed to parse control cmd - {}", e), Some(Ok(cmd_type)) => Ok((cmd_type, json)), } } _ => Ok((CmdType::NonControl, json)), } } async fn websocket_connect( url: String, extra_headers: Vec<(String, String)>, fingerprint: Option, ) -> Result { let ws_key = proxmox_sys::linux::random_data(16)?; let ws_key = base64::encode(&ws_key); let mut req = Request::builder() .uri(url) .header(UPGRADE, "websocket") .header(SEC_WEBSOCKET_VERSION, "13") .header(SEC_WEBSOCKET_KEY, ws_key) .body(Body::empty())?; let headers = req.headers_mut(); for (name, value) in extra_headers { let name = hyper::header::HeaderName::from_bytes(name.as_bytes())?; let value = hyper::header::HeaderValue::from_str(&value)?; headers.insert(name, value); } let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls())?; if let Some(expected) = fingerprint { ssl_connector_builder.set_verify_callback( openssl::ssl::SslVerifyMode::PEER, move |_valid, ctx| { let cert = match ctx.current_cert() { Some(cert) => cert, None => { // should not happen eprintln!("SSL context lacks current certificate."); return false; } }; // skip CA certificates, we only care about the peer cert let depth = ctx.error_depth(); if depth != 0 { return true; } let fp = match cert.digest(openssl::hash::MessageDigest::sha256()) { Ok(fp) => fp, Err(err) => { // should not happen eprintln!("failed to calculate certificate FP - {}", err); return false; } }; let fp_string = hex::encode(&fp); let fp_string = fp_string .as_bytes() .chunks(2) .map(|v| std::str::from_utf8(v).unwrap()) .collect::>() .join(":"); let expected = expected.to_lowercase(); if expected == fp_string { true } else { eprintln!("certificate fingerprint does not match expected fingerprint!"); eprintln!("expected: {}", expected); eprintln!("encountered: {}", fp_string); false } }, ); } else { ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER); } let mut httpc = HttpConnector::new(); httpc.enforce_http(false); // we want https... httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0))); let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build(), 120); let client = Client::builder().build::<_, Body>(https); let res = client.request(req).await?; if res.status() != StatusCode::SWITCHING_PROTOCOLS { bail!("server didn't upgrade: {}", res.status()); } hyper::upgrade::on(res) .await .map_err(|err| format_err!("failed to upgrade - {}", err)) } async fn handle_connect_cmd(&mut self, mut data: CmdData) -> Result<(), Error> { let mut data: ConnectCmdData = data .remove("data") .ok_or_else(|| format_err!("'connect' command missing 'data'")) .map(serde_json::from_value)??; if self.sender.is_some() { bail!("already connected!"); } let upgraded = Self::websocket_connect( data.url.clone(), data.headers.take().unwrap_or_else(Vec::new), data.fingerprint.take(), ) .await?; let (tx, rx) = mpsc::unbounded_channel(); self.sender = Some(tx); tokio::spawn(async move { if let Err(err) = Self::handle_ctrl_tunnel(upgraded, rx).await { eprintln!("Tunnel to {} failed - {}", data.url, err); } }); Ok(()) } async fn handle_forward_cmd(&mut self, mut data: CmdData) -> Result<(), Error> { let data: ForwardCmdData = data .remove("data") .ok_or_else(|| format_err!("'forward' command missing 'data'")) .map(serde_json::from_value)??; if self.sender.is_none() && data.ticket.is_some() { bail!("dynamically requesting ticket requires cmd tunnel connection!"); } let unix_listener = UnixListener::bind(data.unix.clone())?; let data = Arc::new(data); let cmd_sender = self.sender.clone(); tokio::spawn(async move { let mut tasks: Vec> = Vec::new(); let data2 = data.clone(); loop { let data3 = data2.clone(); match unix_listener.accept().await { Ok((unix_stream, _)) => { eprintln!("accepted new connection on '{}'", data3.unix); let cmd_sender2 = cmd_sender.clone(); let task = tokio::spawn(async move { if let Err(err) = Self::handle_forward_tunnel( cmd_sender2.clone(), data3.clone(), unix_stream, ) .await { eprintln!("Tunnel for {} failed - {}", data3.unix, err); } }); tasks.push(task); } Err(err) => eprintln!( "Failed to accept unix connection on {} - {}", data3.unix, err ), }; } }); Ok(()) } async fn handle_tunnel_cmd(&mut self, data: CmdData) -> Result { match &mut self.sender { None => bail!("not connected!"), Some(sender) => { let data: Value = data.into(); let (tx, rx) = oneshot::channel::(); if let Some(cmd) = data.get("cmd") { eprintln!("-> sending command {} to remote", cmd); } else { eprintln!("-> sending data line to remote"); } sender.send((data, tx))?; let res = rx.await?; eprintln!("<- got reply"); Ok(res) } } } async fn handle_ctrl_tunnel( websocket: Upgraded, mut cmd_receiver: mpsc::UnboundedReceiver<(Value, oneshot::Sender)>, ) -> Result<(), Error> { let (tunnel_reader, tunnel_writer) = tokio::io::split(websocket); let (ws_close_tx, mut ws_close_rx) = mpsc::unbounded_channel(); let ws_reader = WebSocketReader::new(tunnel_reader, ws_close_tx); let mut ws_writer = WebSocketWriter::new(Some([0, 0, 0, 0]), tunnel_writer); let mut framed_reader = tokio_util::codec::FramedRead::new(ws_reader, tokio_util::codec::LinesCodec::new()); let mut resp_tx_queue: VecDeque> = VecDeque::new(); let mut shutting_down = false; loop { let mut close_future = ws_close_rx.recv().boxed().fuse(); let mut frame_future = framed_reader.next().boxed().fuse(); let mut cmd_future = cmd_receiver.recv().boxed().fuse(); select! { res = close_future => { let res = res.ok_or_else(|| format_err!("WS control channel closed"))?; eprintln!("WS: received control message: '{:?}'", res); shutting_down = true; }, res = frame_future => { match res { None if shutting_down => { eprintln!("WS closed"); break; }, None => bail!("WS closed unexpectedly"), Some(Ok(res)) => { resp_tx_queue .pop_front() .ok_or_else(|| format_err!("no response handler"))? .send(res) .map_err(|msg| format_err!("failed to send tunnel response '{}' back to requester - receiver already closed?", msg))?; }, Some(Err(err)) => { bail!("reading from control tunnel failed - WS receive failed: {}", err); }, } }, res = cmd_future => { if shutting_down { continue }; match res { None => { eprintln!("CMD channel closed, shutting down"); ws_writer.send_control_frame(Some([1,2,3,4]), OpCode::Close, &[]).await?; shutting_down = true; }, Some((msg, resp_tx)) => { resp_tx_queue.push_back(resp_tx); let line = format!("{}\n", msg); ws_writer.write_all(line.as_bytes()).await?; ws_writer.flush().await?; }, } }, }; } Ok(()) } async fn handle_forward_tunnel( cmd_sender: Option)>>, data: Arc, unix: UnixStream, ) -> Result<(), Error> { let data = match (&cmd_sender, &data.ticket) { (Some(cmd_sender), Some(_)) => Self::get_ticket(cmd_sender, data.clone()).await, _ => Ok(data.clone()), }?; let upgraded = Self::websocket_connect( data.url.clone(), data.headers.clone().unwrap_or_else(Vec::new), data.fingerprint.clone(), ) .await?; let ws = WebSocket { mask: Some([0, 0, 0, 0]), }; eprintln!("established new WS for forwarding '{}'", data.unix); ws.serve_connection(upgraded, unix).await?; eprintln!("done handling forwarded connection from '{}'", data.unix); Ok(()) } async fn get_ticket( cmd_sender: &mpsc::UnboundedSender<(Value, oneshot::Sender)>, cmd_data: Arc, ) -> Result, Error> { eprintln!("requesting WS ticket via tunnel"); let ticket_cmd = match cmd_data.ticket.clone() { Some(mut ticket_cmd) => { ticket_cmd.insert("cmd".to_string(), serde_json::json!("ticket")); ticket_cmd } None => bail!("can't get ticket without ticket parameters"), }; let (tx, rx) = oneshot::channel::(); cmd_sender.send((serde_json::json!(ticket_cmd), tx))?; let ticket = rx.await?; let mut ticket: Map = serde_json::from_str(&ticket)?; let ticket = ticket .remove("ticket") .ok_or_else(|| format_err!("failed to retrieve ticket via tunnel"))?; let ticket = ticket .as_str() .ok_or_else(|| format_err!("failed to format received ticket"))?; let ticket = utf8_percent_encode(ticket, NON_ALPHANUMERIC).to_string(); let mut data = cmd_data.clone(); let mut url = data.url.clone(); url.push_str("ticket="); url.push_str(&ticket); let mut d = Arc::make_mut(&mut data); d.url = url; Ok(data) } } #[tokio::main] async fn main() -> Result<(), Error> { let tunnel = CtrlTunnel { sender: None }; tunnel.read_cmd_loop().await }