diff --git a/Cargo.lock b/Cargo.lock index abfcbfae4..b8ba3d57b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -394,9 +394,9 @@ checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] name = "base64" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64ct" @@ -1990,9 +1990,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "1.2.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ "bytes", "futures-channel", @@ -2026,9 +2026,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.3" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" dependencies = [ "bytes", "futures-channel", @@ -3618,7 +3618,7 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e6cc1e89e689536eb5aeede61520e874df5a4707df811cd5da4aa5fbb2aae19" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "bytes", "encoding_rs", "futures-core", @@ -3789,7 +3789,7 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "rustls-pki-types", ] @@ -4845,7 +4845,7 @@ name = "tauri-plugin-authenticator" version = "0.0.0" dependencies = [ "authenticator", - "base64 0.22.0", + "base64 0.22.1", "byteorder", "bytes", "chrono", @@ -5029,8 +5029,11 @@ dependencies = [ name = "tauri-plugin-websocket" version = "0.0.0" dependencies = [ + "base64 0.22.1", "futures-util", "http 1.0.0", + "hyper", + "hyper-util", "log", "rand 0.8.5", "serde", @@ -5373,7 +5376,6 @@ dependencies = [ "tokio", "tower-layer", "tower-service", - "tracing", ] [[package]] diff --git a/plugins/websocket/Cargo.toml b/plugins/websocket/Cargo.toml index 5ce510e22..db1f185cb 100644 --- a/plugins/websocket/Cargo.toml +++ b/plugins/websocket/Cargo.toml @@ -20,3 +20,6 @@ rand = "0.8" futures-util = "0.3" tokio = { version = "1", features = ["net", "sync"] } tokio-tungstenite = { version = "0.23", features = ["native-tls"] } +hyper = { version = "1.4.1", features = ["client"] } +hyper-util = { version = "0.1.6", features = ["tokio", "http1"] } +base64 = "0.22.1" diff --git a/plugins/websocket/src/lib.rs b/plugins/websocket/src/lib.rs index e5d1007e4..a0261434d 100644 --- a/plugins/websocket/src/lib.rs +++ b/plugins/websocket/src/lib.rs @@ -1,16 +1,23 @@ +use base64::prelude::{Engine, BASE64_STANDARD}; use futures_util::{stream::SplitSink, SinkExt, StreamExt}; -use http::header::{HeaderName, HeaderValue}; +use http::{ + header::{HeaderName, HeaderValue}, + Request, +}; +use hyper::client::conn; +use hyper_util::rt::TokioIo; use serde::{ser::Serializer, Deserialize, Serialize}; use tauri::{ api::ipc::{format_callback, CallbackFn}, plugin::{Builder as PluginBuilder, TauriPlugin}, - Manager, Runtime, State, Window, + AppHandle, Manager, Runtime, State, Window, }; use tokio::{net::TcpStream, sync::Mutex}; use tokio_tungstenite::{ - connect_async_tls_with_config, + client_async_tls_with_config, connect_async_with_config, tungstenite::{ client::IntoClientRequest, + error::UrlError, protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig}, Message, }, @@ -22,7 +29,8 @@ use std::str::FromStr; type Id = u32; type WebSocket = WebSocketStream>; -type WebSocketWriter = SplitSink; +type WebSocketWriter = + SplitSink>, Message>; type Result = std::result::Result; #[derive(Debug, thiserror::Error)] @@ -35,6 +43,14 @@ enum Error { InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue), #[error(transparent)] InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName), + #[error(transparent)] + ProxyConnectionError(#[from] hyper::Error), + #[error("proxy returned status code: {0}")] + ProxyStatusError(u16), + #[error(transparent)] + ProxyIoError(std::io::Error), + #[error(transparent)] + ProxyHttpError(http::Error), } impl Serialize for Error { @@ -50,6 +66,26 @@ impl Serialize for Error { struct ConnectionManager(Mutex>); struct TlsConnector(Mutex>); +struct ProxyConfigurationInternal(Mutex>); + +#[derive(Clone)] +pub struct ProxyAuth { + pub username: String, + pub password: String, +} + +impl ProxyAuth { + pub fn encode(&self) -> String { + BASE64_STANDARD.encode(format!("{}:{}", self.username, self.password)) + } +} + +#[derive(Clone)] +pub struct ProxyConfiguration { + pub proxy_url: String, + pub proxy_port: u16, + pub auth: Option, +} #[derive(Deserialize)] #[serde(rename_all = "camelCase")] @@ -105,10 +141,6 @@ async fn connect( ) -> Result { let id = rand::random(); let mut request = url.into_client_request()?; - let tls_connector = match window.try_state::() { - Some(tls_connector) => tls_connector.0.lock().await.clone(), - None => None, - }; if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) { for (k, v) in headers { @@ -118,9 +150,32 @@ async fn connect( } } - let (ws_stream, _) = - connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector) - .await?; + #[cfg(any(feature = "rustls-tls", feature = "native-tls"))] + let tls_connector = match window.try_state::() { + Some(tls_connector) => tls_connector.0.lock().await.clone(), + None => None, + }; + #[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))] + let tls_connector = None; + + let proxy_config = match window.try_state::() { + Some(proxy_config) => proxy_config.0.lock().await.clone(), + None => None, + }; + + let ws_stream = if let Some(proxy_config) = proxy_config { + connect_using_proxy(request, config, proxy_config, tls_connector).await? + } else { + #[cfg(any(feature = "rustls-tls", feature = "native-tls"))] + let (ws_stream, _) = + connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector) + .await?; + #[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))] + let (ws_stream, _) = + connect_async_with_config(request, config.map(Into::into), false).await?; + + ws_stream + }; tauri::async_runtime::spawn(async move { let (write, read) = ws_stream.split(); @@ -168,6 +223,70 @@ async fn connect( Ok(id) } +async fn connect_using_proxy( + request: Request<()>, + config: Option, + proxy_config: ProxyConfiguration, + tls_connector: Option, +) -> Result { + let domain = domain(&request)?; + let port = request + .uri() + .port_u16() + .or_else(|| match request.uri().scheme_str() { + Some("wss") => Some(443), + Some("ws") => Some(80), + _ => None, + }) + .ok_or(Error::Websocket( + tokio_tungstenite::tungstenite::Error::Url(UrlError::UnsupportedUrlScheme), + ))?; + + let tcp = TcpStream::connect(format!( + "{}:{}", + proxy_config.proxy_url, proxy_config.proxy_port + )) + .await + .map_err(|original| Error::ProxyIoError(original))?; + let io = TokioIo::new(tcp); + + let (mut request_sender, proxy_connection) = + conn::http1::handshake::, String>(io).await?; + let proxy_connection_task = tokio::spawn(proxy_connection.without_shutdown()); + + let addr = format!("{domain}:{port}"); + let mut req_builder = Request::connect(addr); + + if let Some(auth) = proxy_config.auth { + req_builder = req_builder.header("Proxy-Authorization", format!("Basic {}", auth.encode())); + } + + let req = req_builder + .body("".to_string()) + .map_err(|orig| Error::ProxyHttpError(orig))?; + let res = request_sender.send_request(req).await?; + if res.status().as_u16() < 200 || res.status().as_u16() >= 300 { + return Err(Error::ProxyStatusError(res.status().as_u16())); + } + + // expect is fine since it would only rely panics from within the tokio task (or a cancellation which does not happen) + let proxy_connection = proxy_connection_task + .await + .expect("Panic in tokio task during websocket proxy initialization")?; + + let proxy_tcp_wrapper = proxy_connection.io; + let proxied_tcp_socket = proxy_tcp_wrapper.into_inner(); + let (ws_stream, _) = client_async_tls_with_config( + request, + proxied_tcp_socket, + config.map(Into::into), + tls_connector, + ) + .await?; + + Ok(ws_stream) +} + #[tauri::command] async fn send( manager: State<'_, ConnectionManager>, @@ -200,12 +319,14 @@ pub fn init() -> TauriPlugin { #[derive(Default)] pub struct Builder { tls_connector: Option, + proxy_configuration: Option, } impl Builder { pub fn new() -> Self { Self { tls_connector: None, + proxy_configuration: None, } } @@ -214,14 +335,60 @@ impl Builder { self } + pub fn proxy_configuration(mut self, proxy_configuration: ProxyConfiguration) -> Self { + self.proxy_configuration.replace(proxy_configuration); + self + } + pub fn build(self) -> TauriPlugin { PluginBuilder::new("websocket") .invoke_handler(tauri::generate_handler![connect, send]) .setup(|app| { app.manage(ConnectionManager::default()); app.manage(TlsConnector(Mutex::new(self.tls_connector))); + app.manage(ProxyConfigurationInternal(Mutex::new( + self.proxy_configuration, + ))); + Ok(()) }) .build() } } + +pub async fn reconfigure_proxy(app: &AppHandle, proxy_config: Option) { + if let Some(state) = app.try_state::() { + if let Some(proxy_config) = proxy_config { + state.0.lock().await.replace(proxy_config); + } else { + state.0.lock().await.take(); + } + } +} + +pub async fn reconfigure_tls_connector(app: &AppHandle, tls_connector: Option) { + if let Some(state) = app.try_state::() { + if let Some(tls_connector) = tls_connector { + state.0.lock().await.replace(tls_connector); + } else { + state.0.lock().await.take(); + } + } +} + +// Copied from tokio-tungstenite internal function (tokio-tungstenite/src/lib.rs) with the same name +// Get a domain from an URL. +#[inline] +fn domain( + request: &tokio_tungstenite::tungstenite::handshake::client::Request, +) -> tokio_tungstenite::tungstenite::Result { + match request.uri().host() { + // rustls expects IPv6 addresses without the surrounding [] brackets + #[cfg(feature = "__rustls-tls")] + Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()), + Some(d) => Ok(d.to_string()), + None => Err(tokio_tungstenite::tungstenite::Error::Url( + tokio_tungstenite::tungstenite::error::UrlError::NoHostName, + )), + } +}