diff --git a/Cargo.toml b/Cargo.toml index 3effc75a..1fb4a41f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async_wasi" readme = "README.md" repository = "https://github.com/WasmEdge/mysql_async_wasi" -version = "0.30.1" +version = "0.31.2" exclude = ["test/*"] edition = "2018" categories = ["asynchronous", "database"] @@ -20,12 +20,13 @@ futures-core = "0.3" futures-util = "0.3" futures-sink = "0.3" lazy_static = "1" -lru = "0.7.0" -mysql_common = { version = "0.29.0", default-features = false } +lru = "0.8.1" +mysql_common = { version = "0.29.2", default-features = false } once_cell = "1.7.2" pem = "1.0.1" percent-encoding = "2.1.0" pin-project = "1.0.2" +priority-queue = "1" serde = "1" serde_json = "1" thiserror = "1.0.4" @@ -52,6 +53,40 @@ wasmedge_wasi_socket = "0.4.2" # rand = "0.8.0" [target.'cfg(target_os="wasi")'.dev-dependencies] +tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] } +tokio-util = { version = "0.7.2", features = ["codec", "io"] } + + +[dependencies.tokio-rustls] +version = "0.23.4" +optional = true + +[dependencies.tokio-native-tls] +version = "0.3.0" +optional = true + +[dependencies.native-tls] +version = "0.2" +optional = true + +[dependencies.rustls] +version = "0.20.0" +features = ["dangerous_configuration"] +optional = true + +[dependencies.rustls-pemfile] +version = "1.0.1" +optional = true + +[dependencies.webpki] +version = "0.22.0" +optional = true + +[dependencies.webpki-roots] +version = "0.22.1" +optional = true + +[dev-dependencies] tempfile = "3.1.0" tokio_wasi = { version = "1", features = [ "io-util", "fs", "net", "time", "rt", "macros"] } rand = "0.8.0" @@ -63,6 +98,25 @@ default = [ "mysql_common/time03", "mysql_common/uuid", "mysql_common/frunk", + # "native-tls-tls", +] +default-rustls = [ + "flate2/zlib", + "mysql_common/bigdecimal03", + "mysql_common/rust_decimal", + "mysql_common/time03", + "mysql_common/uuid", + "mysql_common/frunk", + "rustls-tls", +] +minimal = ["flate2/zlib"] +native-tls-tls = ["native-tls", "tokio-native-tls"] +rustls-tls = [ + "rustls", + "tokio-rustls", + "webpki", + "webpki-roots", + "rustls-pemfile", ] nightly = [] zlib = ["flate2/zlib"] diff --git a/README.md b/README.md index d217f83a..7cd7b068 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,85 @@ Note: We do not yet support SSL / TLS connections to databases in this WebAssemb mysql_async_wasi = "" ``` +## Crate Features + +Default feature set is wide – it includes all default [`mysql_common` features][myslqcommonfeatures] +as well as `native-tls`-based TLS support. + +### List Of Features + +* `minimal` – enables only necessary features (at the moment the only necessary feature + is `flate2` backend). Enables: + + - `flate2/zlib" + + **Example:** + + ```toml + [dependencies] + mysql_async = { version = "*", default-features = false, features = ["minimal"]} + ``` + + **Note:* it is possible to use another `flate2` backend by directly choosing it: + + ```toml + [dependencies] + mysql_async = { version = "*", default-features = false } + flate2 = { version = "*", default-features = false, features = ["rust_backend"] } + ``` + +* `default` – enables the following set of crate's and dependencies' features: + + - `native-tls-tls` + - `flate2/zlib" + - `mysql_common/bigdecimal03` + - `mysql_common/rust_decimal` + - `mysql_common/time03` + - `mysql_common/uuid` + - `mysql_common/frunk` + +* `default-rustls` – same as default but with `rustls-tls` instead of `native-tls-tls`. + + **Example:** + + ```toml + [dependencies] + mysql_async = { version = "*", default-features = false, features = ["default-rustls"] } + ``` + +* `native-tls-tls` – enables `native-tls`-based TLS support _(conflicts with `rustls-tls`)_ + + **Example:** + + ```toml + [dependencies] + mysql_async = { version = "*", default-features = false, features = ["native-tls-tls"] } + +* `rustls-tls` – enables `native-tls`-based TLS support _(conflicts with `native-tls-tls`)_ + + **Example:** + + ```toml + [dependencies] + mysql_async = { version = "*", default-features = false, features = ["rustls-tls"] } + +[myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features + +## TLS/SSL Support + +SSL support comes in two flavors: + +1. Based on native-tls – this is the default option, that usually works without pitfalls + (see the `native-tls-tls` crate feature). + +2. Based on rustls – TLS backend written in Rust (see the `rustls-tls` crate feature). + + Please also note a few things about rustls: + - it will fail if you'll try to connect to the server by its IP address, + hostname is required; + - it, most likely, won't work on windows, at least with default server certs, + generated by the MySql installer. + ## Example ```rust diff --git a/src/conn/mod.rs b/src/conn/mod.rs index a91babb3..6d865b9b 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -10,13 +10,13 @@ use futures_util::FutureExt; pub use mysql_common::named_params; use mysql_common::{ - constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8_GENERAL_CI}, + constants::DEFAULT_MAX_ALLOWED_PACKET, crypto, io::ParseBuf, packets::{ binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, - ResultSetTerminator, SslRequest, + OldEofPacket, ResultSetTerminator, }, proto::MySerialize, }; @@ -415,12 +415,14 @@ impl Conn { /// Returns true if io stream is encrypted. fn is_secure(&self) -> bool { - #[cfg(not(target_os = "wasi"))] + #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] if let Some(ref stream) = self.inner.stream { stream.is_secure() } else { false } + + #[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))] false } @@ -492,10 +494,24 @@ impl Conn { .get_capabilities() .contains(CapabilityFlags::CLIENT_SSL) { + if !self + .inner + .capabilities + .contains(CapabilityFlags::CLIENT_SSL) + { + return Err(DriverError::NoClientSslFlagFromServer.into()); + } + + let collation = if self.inner.version >= (5, 5, 3) { + UTF8MB4_GENERAL_CI + } else { + UTF8_GENERAL_CI + }; + let ssl_request = SslRequest::new( self.inner.capabilities, DEFAULT_MAX_ALLOWED_PACKET as u32, - UTF8_GENERAL_CI as u8, + collation as u8, ); self.write_struct(&ssl_request).await?; let conn = self; @@ -681,9 +697,18 @@ impl Conn { /// Returns `true` for ProgressReport packet. fn handle_packet(&mut self, packet: &PooledBuf) -> Result { let ok_packet = if self.has_pending_result() { - ParseBuf(&*packet) - .parse::>(self.capabilities()) - .map(|x| x.into_inner()) + if self + .capabilities() + .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) + { + ParseBuf(&*packet) + .parse::>(self.capabilities()) + .map(|x| x.into_inner()) + } else { + ParseBuf(&*packet) + .parse::>(self.capabilities()) + .map(|x| x.into_inner()) + } } else { ParseBuf(&*packet) .parse::>(self.capabilities()) @@ -1046,7 +1071,7 @@ impl Conn { mod test { use bytes::Bytes; use futures_util::stream::{self, StreamExt}; - use mysql_common::binlog::events::EventData; + use mysql_common::{binlog::events::EventData, constants::MAX_PAYLOAD_LEN}; use tokio::time::timeout; use std::time::Duration; @@ -1435,15 +1460,15 @@ mod test { #[tokio::test] async fn should_perform_queries() -> super::Result<()> { - let long_string = ::std::iter::repeat('A') - .take(18 * 1024 * 1024) - .collect::(); let mut conn = Conn::new(get_opts()).await?; - let result: Vec<(String, u8)> = conn - .query(format!(r"SELECT '{}', 231", long_string)) - .await?; + for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) { + let long_string = ::std::iter::repeat('A').take(x).collect::(); + let result: Vec<(String, u8)> = conn + .query(format!(r"SELECT '{}', 231", long_string)) + .await?; + assert_eq!((long_string, 231_u8), result[0]); + } conn.disconnect().await?; - assert_eq!((long_string, 231_u8), result[0]); Ok(()) } diff --git a/src/conn/pool/futures/disconnect_pool.rs b/src/conn/pool/futures/disconnect_pool.rs index c409e18e..4e5c4f4d 100644 --- a/src/conn/pool/futures/disconnect_pool.rs +++ b/src/conn/pool/futures/disconnect_pool.rs @@ -16,7 +16,7 @@ use futures_core::ready; use tokio::sync::mpsc::UnboundedSender; use crate::{ - conn::pool::{Inner, Pool}, + conn::pool::{Inner, Pool, QUEUE_END_ID}, error::Error, Conn, }; @@ -50,7 +50,7 @@ impl Future for DisconnectPool { self.pool_inner.close.store(true, atomic::Ordering::Release); let mut exchange = self.pool_inner.exchange.lock().unwrap(); exchange.spawn_futures_if_needed(&self.pool_inner); - exchange.waiting.push_back(cx.waker().clone()); + exchange.waiting.push(cx.waker().clone(), QUEUE_END_ID); drop(exchange); if self.pool_inner.closed.load(atomic::Ordering::Acquire) { diff --git a/src/conn/pool/futures/get_conn.rs b/src/conn/pool/futures/get_conn.rs index 429a016a..854950ab 100644 --- a/src/conn/pool/futures/get_conn.rs +++ b/src/conn/pool/futures/get_conn.rs @@ -16,7 +16,10 @@ use std::{ use futures_core::ready; use crate::{ - conn::{pool::Pool, Conn}, + conn::{ + pool::{Pool, QueueId}, + Conn, + }, error::*, }; @@ -58,6 +61,7 @@ impl GetConnInner { #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct GetConn { + pub(crate) queue_id: Option, pub(crate) pool: Option, pub(crate) inner: GetConnInner, } @@ -65,6 +69,7 @@ pub struct GetConn { impl GetConn { pub(crate) fn new(pool: &Pool) -> GetConn { GetConn { + queue_id: None, pool: Some(pool.clone()), inner: GetConnInner::New, } @@ -91,23 +96,26 @@ impl Future for GetConn { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { match self.inner { - GetConnInner::New => match ready!(Pin::new(self.pool_mut()).poll_new_conn(cx))? - .inner - .take() - { - GetConnInner::Connecting(conn_fut) => { - self.inner = GetConnInner::Connecting(conn_fut); - } - GetConnInner::Checking(conn_fut) => { - self.inner = GetConnInner::Checking(conn_fut); - } - GetConnInner::Done => unreachable!( - "Pool::poll_new_conn never gives out already-consumed GetConns" - ), - GetConnInner::New => { - unreachable!("Pool::poll_new_conn never gives out GetConnInner::New") + GetConnInner::New => { + let queued = self.queue_id.is_some(); + let queue_id = *self.queue_id.get_or_insert_with(QueueId::next); + let next = + ready!(Pin::new(self.pool_mut()).poll_new_conn(cx, queued, queue_id))?; + match next { + GetConnInner::Connecting(conn_fut) => { + self.inner = GetConnInner::Connecting(conn_fut); + } + GetConnInner::Checking(conn_fut) => { + self.inner = GetConnInner::Checking(conn_fut); + } + GetConnInner::Done => unreachable!( + "Pool::poll_new_conn never gives out already-consumed GetConns" + ), + GetConnInner::New => { + unreachable!("Pool::poll_new_conn never gives out GetConnInner::New") + } } - }, + } GetConnInner::Done => { unreachable!("GetConn::poll polled after returning Async::Ready"); } @@ -158,6 +166,11 @@ impl Drop for GetConn { // We drop a connection before it can be resolved, a.k.a. cancelling it. // Make sure we maintain the necessary invariants towards the pool. if let Some(pool) = self.pool.take() { + // Remove the waker from the pool's waitlist in case this task was + // woken by another waker, like from tokio::time::timeout. + if let Some(queue_id) = self.queue_id { + pool.unqueue(queue_id); + } if let GetConnInner::Connecting(..) = self.inner.take() { pool.cancel_connection(); } diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 64792c5e..9fa107be 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -7,11 +7,14 @@ // modified, or distributed except according to those terms. use futures_util::FutureExt; +use priority_queue::PriorityQueue; use tokio::sync::mpsc; use std::{ + cmp::{Ordering, Reverse}, collections::VecDeque, convert::TryFrom, + hash::{Hash, Hasher}, pin::Pin, str::FromStr, sync::{atomic, Arc, Mutex}, @@ -62,7 +65,7 @@ impl From for IdlingConn { /// This is fine as long as we never do expensive work while holding the lock! #[derive(Debug)] struct Exchange { - waiting: VecDeque, + waiting: Waitlist, available: VecDeque, exist: usize, // only used to spawn the recycler the first time we're in async context @@ -87,6 +90,87 @@ impl Exchange { } } +#[derive(Default, Debug)] +struct Waitlist { + queue: PriorityQueue, +} + +impl Waitlist { + fn push(&mut self, w: Waker, queue_id: QueueId) { + self.queue.push( + QueuedWaker { + queue_id, + waker: Some(w), + }, + queue_id, + ); + } + + fn pop(&mut self) -> Option { + match self.queue.pop() { + Some((qw, _)) => Some(qw.waker.unwrap()), + None => None, + } + } + + fn remove(&mut self, id: QueueId) { + let tmp = QueuedWaker { + queue_id: id, + waker: None, + }; + self.queue.remove(&tmp); + } + + fn is_empty(&self) -> bool { + self.queue.is_empty() + } +} + +const QUEUE_END_ID: QueueId = QueueId(Reverse(u64::MAX)); + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub(crate) struct QueueId(Reverse); + +impl QueueId { + fn next() -> Self { + static NEXT_QUEUE_ID: atomic::AtomicU64 = atomic::AtomicU64::new(0); + let id = NEXT_QUEUE_ID.fetch_add(1, atomic::Ordering::SeqCst); + QueueId(Reverse(id)) + } +} + +#[derive(Debug)] +struct QueuedWaker { + queue_id: QueueId, + waker: Option, +} + +impl Eq for QueuedWaker {} + +impl PartialEq for QueuedWaker { + fn eq(&self, other: &Self) -> bool { + self.queue_id == other.queue_id + } +} + +impl Ord for QueuedWaker { + fn cmp(&self, other: &Self) -> Ordering { + self.queue_id.cmp(&other.queue_id) + } +} + +impl PartialOrd for QueuedWaker { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Hash for QueuedWaker { + fn hash(&self, state: &mut H) { + self.queue_id.hash(state) + } +} + /// Connection pool data. #[derive(Debug)] pub struct Inner { @@ -131,7 +215,7 @@ impl Pool { closed: false.into(), exchange: Mutex::new(Exchange { available: VecDeque::with_capacity(pool_opts.constraints().max()), - waiting: VecDeque::new(), + waiting: Waitlist::default(), exist: 0, recycler: Some((rx, pool_opts)), }), @@ -181,7 +265,7 @@ impl Pool { let mut exchange = self.inner.exchange.lock().unwrap(); if exchange.available.len() < self.opts.pool_opts().active_bound() { exchange.available.push_back(conn.into()); - if let Some(w) = exchange.waiting.pop_front() { + if let Some(w) = exchange.waiting.pop() { w.wake(); } return; @@ -216,17 +300,27 @@ impl Pool { let mut exchange = self.inner.exchange.lock().unwrap(); exchange.exist -= 1; // we just enabled the creation of a new connection! - if let Some(w) = exchange.waiting.pop_front() { + if let Some(w) = exchange.waiting.pop() { w.wake(); } } /// Poll the pool for an available connection. - fn poll_new_conn(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_new_conn_inner(cx) - } - - fn poll_new_conn_inner(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_new_conn( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + queued: bool, + queue_id: QueueId, + ) -> Poll> { + self.poll_new_conn_inner(cx, queued, queue_id) + } + + fn poll_new_conn_inner( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + queued: bool, + queue_id: QueueId, + ) -> Poll> { let mut exchange = self.inner.exchange.lock().unwrap(); // NOTE: this load must happen while we hold the lock, @@ -238,18 +332,21 @@ impl Pool { exchange.spawn_futures_if_needed(&self.inner); + // Check if others are waiting and we're not queued. + if !exchange.waiting.is_empty() && !queued { + exchange.waiting.push(cx.waker().clone(), queue_id); + return Poll::Pending; + } + while let Some(IdlingConn { mut conn, .. }) = exchange.available.pop_back() { if !conn.expired() { - return Poll::Ready(Ok(GetConn { - pool: Some(self.clone()), - inner: GetConnInner::Checking( - async move { - conn.stream_mut()?.check().await?; - Ok(conn) - } - .boxed(), - ), - })); + return Poll::Ready(Ok(GetConnInner::Checking( + async move { + conn.stream_mut()?.check().await?; + Ok(conn) + } + .boxed(), + ))); } else { self.send_to_recycler(conn); } @@ -261,16 +358,20 @@ impl Pool { // we are allowed to make a new connection, so we will! exchange.exist += 1; - return Poll::Ready(Ok(GetConn { - pool: Some(self.clone()), - inner: GetConnInner::Connecting(Conn::new(self.opts.clone()).boxed()), - })); + return Poll::Ready(Ok(GetConnInner::Connecting( + Conn::new(self.opts.clone()).boxed(), + ))); } - // no go -- we have to wait - exchange.waiting.push_back(cx.waker().clone()); + // Polled, but no conn available? Back into the queue. + exchange.waiting.push(cx.waker().clone(), queue_id); Poll::Pending } + + fn unqueue(&self, queue_id: QueueId) { + let mut exchange = self.inner.exchange.lock().unwrap(); + exchange.waiting.remove(queue_id); + } } impl Drop for Conn { @@ -301,12 +402,20 @@ mod test { try_join, FutureExt, }; use mysql_common::row::Row; - use tokio::time::sleep; + use tokio::time::{sleep, timeout}; - use std::time::Duration; + use std::{ + cmp::Reverse, + task::{RawWaker, RawWakerVTable, Waker}, + time::Duration, + }; use crate::{ - conn::pool::Pool, opts::PoolOpts, prelude::*, test_misc::get_opts, PoolConstraints, TxOpts, + conn::pool::{Pool, QueueId, Waitlist, QUEUE_END_ID}, + opts::PoolOpts, + prelude::*, + test_misc::get_opts, + PoolConstraints, TxOpts, }; macro_rules! conn_ex_field { @@ -762,6 +871,27 @@ mod test { Ok(()) } + #[tokio::test] + async fn should_remove_waker_of_cancelled_task() { + let pool_constraints = PoolConstraints::new(1, 1).unwrap(); + let pool_opts = PoolOpts::default().with_constraints(pool_constraints); + + let pool = Pool::new(get_opts().pool_opts(pool_opts)); + let only_conn = pool.get_conn().await.unwrap(); + + let join_handle = tokio::spawn(timeout(Duration::from_secs(1), pool.get_conn())); + + sleep(Duration::from_secs(2)).await; + + match join_handle.await.unwrap() { + Err(_elapsed) => (), + _ => panic!("unexpected Ok()"), + } + drop(only_conn); + + assert_eq!(0, pool.inner.exchange.lock().unwrap().waiting.queue.len()); + } + #[tokio::test] async fn should_work_if_pooled_connection_operation_is_cancelled() -> super::Result<()> { let pool = Pool::new(get_opts()); @@ -806,6 +936,40 @@ mod test { Ok(()) } + #[test] + fn waitlist_integrity() { + const DATA: *const () = &(); + const NOOP_CLONE_FN: unsafe fn(*const ()) -> RawWaker = |_| RawWaker::new(DATA, &RW_VTABLE); + const NOOP_FN: unsafe fn(*const ()) = |_| {}; + static RW_VTABLE: RawWakerVTable = + RawWakerVTable::new(NOOP_CLONE_FN, NOOP_FN, NOOP_FN, NOOP_FN); + let w = unsafe { Waker::from_raw(RawWaker::new(DATA, &RW_VTABLE)) }; + + let mut waitlist = Waitlist::default(); + assert_eq!(0, waitlist.queue.len()); + + waitlist.push(w.clone(), QueueId(Reverse(4))); + waitlist.push(w.clone(), QueueId(Reverse(2))); + waitlist.push(w.clone(), QueueId(Reverse(8))); + waitlist.push(w.clone(), QUEUE_END_ID); + waitlist.push(w.clone(), QueueId(Reverse(10))); + + waitlist.remove(QueueId(Reverse(8))); + + assert_eq!(4, waitlist.queue.len()); + + let (_, id) = waitlist.queue.pop().unwrap(); + assert_eq!(2, id.0 .0); + let (_, id) = waitlist.queue.pop().unwrap(); + assert_eq!(4, id.0 .0); + let (_, id) = waitlist.queue.pop().unwrap(); + assert_eq!(10, id.0 .0); + let (_, id) = waitlist.queue.pop().unwrap(); + assert_eq!(QUEUE_END_ID, id); + + assert_eq!(0, waitlist.queue.len()); + } + #[cfg(feature = "nightly")] mod bench { use futures_util::future::{FutureExt, TryFutureExt}; diff --git a/src/conn/pool/recycler.rs b/src/conn/pool/recycler.rs index b60066d1..2a704dbc 100644 --- a/src/conn/pool/recycler.rs +++ b/src/conn/pool/recycler.rs @@ -76,7 +76,7 @@ impl Future for Recycler { $self.discard.push($conn.close_conn().boxed()); } else { exchange.available.push_back($conn.into()); - if let Some(w) = exchange.waiting.pop_front() { + if let Some(w) = exchange.waiting.pop() { w.wake(); } } @@ -163,7 +163,7 @@ impl Future for Recycler { let mut exchange = self.inner.exchange.lock().unwrap(); exchange.exist -= self.discarded; for _ in 0..self.discarded { - if let Some(w) = exchange.waiting.pop_front() { + if let Some(w) = exchange.waiting.pop() { w.wake(); } } @@ -197,7 +197,7 @@ impl Future for Recycler { if self.inner.closed.load(Ordering::Acquire) { // `DisconnectPool` might still wait to be woken up. let mut exchange = self.inner.exchange.lock().unwrap(); - while let Some(w) = exchange.waiting.pop_front() { + while let Some(w) = exchange.waiting.pop() { w.wake(); } // we're about to exit, so there better be no outstanding connections diff --git a/src/error.rs b/src/error/mod.rs similarity index 95% rename from src/error.rs rename to src/error/mod.rs index 8e350839..81087260 100644 --- a/src/error.rs +++ b/src/error/mod.rs @@ -8,6 +8,8 @@ pub use url::ParseError; +mod tls; + use mysql_common::{ named_params::MixedParamsError, params::MissingNamedParameterError, proto::codec::error::PacketCodecError, row::Row, value::Value, @@ -53,10 +55,9 @@ impl Error { pub enum IoError { #[error("Input/output error: {}", _0)] Io(#[source] io::Error), - - #[cfg(not(target_os = "wasi"))] + #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] #[error("TLS error: `{}'", _0)] - Tls(#[source] native_tls::Error), + Tls(#[source] tls::TlsError), } /// This type represents MySql server error. @@ -156,6 +157,12 @@ pub enum DriverError { #[error("LOCAL INFILE error: {}", _0)] LocalInfile(#[from] LocalInfileError), + + #[error("No private key found in the file specified")] + NoKeyFound, + + #[error("Client asked for SSL but server does not have this capability")] + NoClientSslFlagFromServer, } #[derive(Debug, Error)] @@ -220,10 +227,11 @@ impl From for Error { Error::Url(err) } } -#[cfg(not(target_os = "wasi"))] + +#[cfg(feature = "native-tls-tls")] impl From for IoError { fn from(err: native_tls::Error) -> Self { - IoError::Tls(err) + IoError::Tls(tls::TlsError::TlsError(err)) } } diff --git a/src/error/tls/mod.rs b/src/error/tls/mod.rs new file mode 100644 index 00000000..220ed850 --- /dev/null +++ b/src/error/tls/mod.rs @@ -0,0 +1,10 @@ +#![cfg(any(feature = "native-tls", feature = "rustls-tls"))] + +pub mod native_tls_error; +pub mod rustls_error; + +#[cfg(feature = "native-tls")] +pub use native_tls_error::TlsError; + +#[cfg(feature = "rustls")] +pub use rustls_error::TlsError; diff --git a/src/error/tls/native_tls_error.rs b/src/error/tls/native_tls_error.rs new file mode 100644 index 00000000..8ca8b6cb --- /dev/null +++ b/src/error/tls/native_tls_error.rs @@ -0,0 +1,45 @@ +#![cfg(feature = "native-tls")] + +use std::fmt::Display; + +#[derive(Debug)] +pub enum TlsError { + TlsError(native_tls::Error), + TlsHandshakeError(native_tls::HandshakeError), +} + +impl From for crate::Error { + fn from(err: TlsError) -> crate::Error { + crate::Error::Io(crate::error::IoError::Tls(err)) + } +} + +impl From for crate::Error { + fn from(err: native_tls::Error) -> crate::Error { + crate::Error::Io(crate::error::IoError::Tls(TlsError::TlsError(err))) + } +} + +impl From> for crate::Error { + fn from(err: native_tls::HandshakeError) -> crate::Error { + crate::Error::Io(crate::error::IoError::Tls(TlsError::TlsHandshakeError(err))) + } +} + +impl std::error::Error for TlsError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + TlsError::TlsError(e) => Some(e), + TlsError::TlsHandshakeError(e) => Some(e), + } + } +} + +impl Display for TlsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TlsError::TlsError(e) => e.fmt(f), + TlsError::TlsHandshakeError(e) => e.fmt(f), + } + } +} diff --git a/src/error/tls/rustls_error.rs b/src/error/tls/rustls_error.rs new file mode 100644 index 00000000..2ee67d39 --- /dev/null +++ b/src/error/tls/rustls_error.rs @@ -0,0 +1,72 @@ +#![cfg(feature = "rustls")] + +use std::fmt::Display; + +#[derive(Debug)] +pub enum TlsError { + Tls(rustls::Error), + Pki(webpki::Error), + InvalidDnsName(webpki::InvalidDnsNameError), +} + +impl From for crate::Error { + fn from(e: TlsError) -> Self { + crate::Error::Io(crate::error::IoError::Tls(e)) + } +} + +impl From for TlsError { + fn from(e: rustls::Error) -> Self { + TlsError::Tls(e) + } +} + +impl From for TlsError { + fn from(e: webpki::InvalidDnsNameError) -> Self { + TlsError::InvalidDnsName(e) + } +} + +impl From for TlsError { + fn from(e: webpki::Error) -> Self { + TlsError::Pki(e) + } +} + +impl From for crate::Error { + fn from(e: rustls::Error) -> Self { + crate::Error::Io(crate::error::IoError::Tls(e.into())) + } +} + +impl From for crate::Error { + fn from(e: webpki::Error) -> Self { + crate::Error::Io(crate::error::IoError::Tls(e.into())) + } +} + +impl From for crate::Error { + fn from(e: webpki::InvalidDnsNameError) -> Self { + crate::Error::Io(crate::error::IoError::Tls(e.into())) + } +} + +impl std::error::Error for TlsError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + TlsError::Tls(e) => Some(e), + TlsError::Pki(e) => Some(e), + TlsError::InvalidDnsName(e) => Some(e), + } + } +} + +impl Display for TlsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TlsError::Tls(e) => e.fmt(f), + TlsError::Pki(e) => e.fmt(f), + TlsError::InvalidDnsName(e) => e.fmt(f), + } + } +} diff --git a/src/io/mod.rs b/src/io/mod.rs index e1dde8e4..5dadb985 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -11,8 +11,6 @@ pub use self::{read_packet::ReadPacket, write_packet::WritePacket}; use bytes::BytesMut; use futures_core::{ready, stream}; use mysql_common::proto::codec::PacketCodec as PacketCodecInner; -#[cfg(not(target_os = "wasi"))] -use native_tls::{Certificate, Identity, TlsConnector}; use pin_project::pin_project; #[cfg(not(target_os = "wasi"))] use socket2::{Socket as Socket2Socket, TcpKeepalive}; @@ -22,18 +20,16 @@ use tokio::{ io::{AsyncRead, AsyncWrite, ErrorKind::Interrupted, ReadBuf}, net::TcpStream, }; -use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts}; +use tokio_util::codec::{Decoder, Encoder, Framed}; #[cfg(unix)] use std::path::Path; use std::{ fmt, - fs::File, future::Future, io::{ self, ErrorKind::{BrokenPipe, NotConnected, Other}, - Read, }, mem::replace, ops::{Deref, DerefMut}, @@ -42,15 +38,13 @@ use std::{ time::Duration, }; -use crate::{ - buffer_pool::PooledBuf, - error::IoError, - opts::{HostPortOrUrl, SslOpts, DEFAULT_PORT}, -}; +use crate::{buffer_pool::PooledBuf, error::IoError, opts::HostPortOrUrl}; #[cfg(unix)] use crate::io::socket::Socket; +mod tls; + macro_rules! with_interrupted { ($e:expr) => { loop { @@ -121,8 +115,10 @@ impl Encoder for PacketCodec { #[derive(Debug)] pub(crate) enum Endpoint { Plain(Option), - #[cfg(not(target_os = "wasi"))] + #[cfg(feature = "native-tls-tls")] Secure(#[pin] tokio_native_tls::TlsStream), + #[cfg(feature = "rustls-tls")] + Secure(#[pin] tokio_rustls::client::TlsStream), #[cfg(unix)] Socket(#[pin] Socket), } @@ -154,6 +150,14 @@ impl Future for CheckTcpStream<'_> { } impl Endpoint { + #[cfg(all(any(feature = "native-tls-tls", feature = "rustls-tls"), unix))] + fn is_socket(&self) -> bool { + match self { + Self::Socket(_) => true, + _ => false, + } + } + /// Checks, that connection is alive. async fn check(&mut self) -> std::result::Result<(), IoError> { //return Ok(()); @@ -162,11 +166,17 @@ impl Endpoint { CheckTcpStream(stream).await?; Ok(()) } - #[cfg(not(target_os = "wasi"))] + #[cfg(feature = "native-tls-tls")] Endpoint::Secure(tls_stream) => { CheckTcpStream(tls_stream.get_mut().get_mut().get_mut()).await?; Ok(()) } + #[cfg(feature = "rustls-tls")] + Endpoint::Secure(tls_stream) => { + let stream = tls_stream.get_mut().0; + CheckTcpStream(stream).await?; + Ok(()) + } #[cfg(unix)] Endpoint::Socket(socket) => { socket.write(&[]).await?; @@ -175,79 +185,42 @@ impl Endpoint { Endpoint::Plain(None) => unreachable!(), } } - #[cfg(not(target_os = "wasi"))] + + #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] pub fn is_secure(&self) -> bool { matches!(self, Endpoint::Secure(_)) } + #[cfg(all(not(feature = "native-tls"), not(feature = "rustls")))] + pub async fn _make_secure( + &mut self, + _domain: String, + _ssl_opts: crate::SslOpts, + ) -> crate::error::Result<()> { + panic!( + "Client had asked for TLS connection but TLS support is disabled. \ + Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]" + ) + } + pub fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> { match *self { Endpoint::Plain(Some(ref stream)) => stream.set_nodelay(val)?, Endpoint::Plain(None) => unreachable!(), - #[cfg(not(target_os = "wasi"))] + #[cfg(feature = "native-tls-tls")] Endpoint::Secure(ref stream) => { stream.get_ref().get_ref().get_ref().set_nodelay(val)? } + #[cfg(feature = "rustls-tls")] + Endpoint::Secure(ref stream) => { + let stream = stream.get_ref().0; + stream.set_nodelay(val)?; + } #[cfg(unix)] Endpoint::Socket(_) => (/* inapplicable */), } Ok(()) } - #[cfg(not(target_os = "wasi"))] - pub async fn make_secure( - &mut self, - domain: String, - ssl_opts: SslOpts, - ) -> std::result::Result<(), IoError> { - #[cfg(unix)] - if let Endpoint::Socket(_) = self { - // inapplicable - return Ok(()); - } - - let mut builder = TlsConnector::builder(); - if let Some(root_cert_path) = ssl_opts.root_cert_path() { - let mut root_cert_data = vec![]; - let mut root_cert_file = File::open(root_cert_path)?; - root_cert_file.read_to_end(&mut root_cert_data)?; - - let root_certs = Certificate::from_der(&*root_cert_data) - .map(|x| vec![x]) - .or_else(|_| { - pem::parse_many(&*root_cert_data) - .unwrap_or_default() - .iter() - .map(pem::encode) - .map(|s| Certificate::from_pem(s.as_bytes())) - .collect() - })?; - - for root_cert in root_certs { - builder.add_root_certificate(root_cert); - } - } - if let Some(pkcs12_path) = ssl_opts.pkcs12_path() { - let der = std::fs::read(pkcs12_path)?; - let identity = Identity::from_pkcs12(&*der, ssl_opts.password().unwrap_or(""))?; - builder.identity(identity); - } - builder.danger_accept_invalid_hostnames(ssl_opts.skip_domain_validation()); - builder.danger_accept_invalid_certs(ssl_opts.accept_invalid_certs()); - let tls_connector: tokio_native_tls::TlsConnector = builder.build()?.into(); - - *self = match self { - Endpoint::Plain(stream) => { - let stream = stream.take().unwrap(); - let tls_stream = tls_connector.connect(&*domain, stream).await?; - Endpoint::Secure(tls_stream) - } - Endpoint::Secure(_) => unreachable!(), - #[cfg(unix)] - Endpoint::Socket(_) => unreachable!(), - }; - - Ok(()) - } } impl From for Endpoint { @@ -262,13 +235,18 @@ impl From for Endpoint { Endpoint::Socket(socket) } } -#[cfg(not(target_os = "wasi"))] + +#[cfg(feature = "native-tls-tls")] impl From> for Endpoint { fn from(stream: tokio_native_tls::TlsStream) -> Self { Endpoint::Secure(stream) } } +/* TODO +#[cfg(feature = "rustls-tls")] +*/ + impl AsyncRead for Endpoint { fn poll_read( self: Pin<&mut Self>, @@ -280,7 +258,9 @@ impl AsyncRead for Endpoint { EndpointProj::Plain(ref mut stream) => { Pin::new(stream.as_mut().unwrap()).poll_read(cx, buf) } - #[cfg(not(target_os = "wasi"))] + #[cfg(feature = "native-tls-tls")] + EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf), + #[cfg(feature = "rustls-tls")] EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf), #[cfg(unix)] EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read(cx, buf), @@ -299,7 +279,9 @@ impl AsyncWrite for Endpoint { EndpointProj::Plain(ref mut stream) => { Pin::new(stream.as_mut().unwrap()).poll_write(cx, buf) } - #[cfg(not(target_os = "wasi"))] + #[cfg(feature = "native-tls-tls")] + EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf), + #[cfg(feature = "rustls-tls")] EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf), #[cfg(unix)] EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_write(cx, buf), @@ -315,7 +297,9 @@ impl AsyncWrite for Endpoint { EndpointProj::Plain(ref mut stream) => { Pin::new(stream.as_mut().unwrap()).poll_flush(cx) } - #[cfg(not(target_os = "wasi"))] + #[cfg(feature = "native-tls-tls")] + EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx), + #[cfg(feature = "rustls-tls")] EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx), #[cfg(unix)] EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_flush(cx), @@ -331,7 +315,9 @@ impl AsyncWrite for Endpoint { EndpointProj::Plain(ref mut stream) => { Pin::new(stream.as_mut().unwrap()).poll_shutdown(cx) } - #[cfg(not(target_os = "wasi"))] + #[cfg(feature = "native-tls-tls")] + EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx), + #[cfg(feature = "rustls-tls")] EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx), #[cfg(unix)] EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_shutdown(cx), @@ -368,7 +354,7 @@ impl Stream { pub(crate) async fn connect_tcp( addr: &HostPortOrUrl, - keepalive: Option, + _keepalive: Option, ) -> io::Result { let tcp_stream = match addr { HostPortOrUrl::HostPort(host, port) => { @@ -445,7 +431,8 @@ impl Stream { self.codec = Some(Box::new(codec)); Ok(()) } - #[cfg(not(target_os = "wasi"))] + + #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] pub(crate) fn is_secure(&self) -> bool { self.codec.as_ref().unwrap().get_ref().is_secure() } @@ -526,6 +513,9 @@ mod test { let endpoint = stream.codec.as_mut().unwrap().get_ref(); let stream = match endpoint { super::Endpoint::Plain(Some(stream)) => stream, + #[cfg(feature = "rustls-tls")] + super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0, + #[cfg(feature = "native-tls")] super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().get_ref().get_ref(), _ => unreachable!(), }; diff --git a/src/io/tls/mod.rs b/src/io/tls/mod.rs new file mode 100644 index 00000000..92f5e7c2 --- /dev/null +++ b/src/io/tls/mod.rs @@ -0,0 +1,4 @@ +#![cfg(any(feature = "native-tls", feature = "rustls"))] + +mod native_tls_io; +mod rustls_io; diff --git a/src/io/tls/native_tls_io.rs b/src/io/tls/native_tls_io.rs new file mode 100644 index 00000000..910387d7 --- /dev/null +++ b/src/io/tls/native_tls_io.rs @@ -0,0 +1,65 @@ +#![cfg(feature = "native-tls")] + +use std::{fs::File, io::Read}; + +use native_tls::{Certificate, Identity, TlsConnector}; + +use crate::io::Endpoint; +use crate::{Result, SslOpts}; + +impl Endpoint { + pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { + #[cfg(unix)] + if self.is_socket() { + // won't secure socket connection + return Ok(()); + } + + let mut builder = TlsConnector::builder(); + if let Some(root_cert_path) = ssl_opts.root_cert_path() { + let mut root_cert_data = vec![]; + let mut root_cert_file = File::open(root_cert_path)?; + root_cert_file.read_to_end(&mut root_cert_data)?; + + let root_certs = Certificate::from_der(&*root_cert_data) + .map(|x| vec![x]) + .or_else(|_| { + pem::parse_many(&*root_cert_data) + .unwrap_or_default() + .iter() + .map(pem::encode) + .map(|s| Certificate::from_pem(s.as_bytes())) + .collect() + })?; + + for root_cert in root_certs { + builder.add_root_certificate(root_cert); + } + } + + if let Some(client_identity) = ssl_opts.client_identity() { + let pkcs12_path = client_identity.pkcs12_path(); + let password = client_identity.password().unwrap_or(""); + + let der = std::fs::read(pkcs12_path)?; + let identity = Identity::from_pkcs12(&*der, password)?; + builder.identity(identity); + } + builder.danger_accept_invalid_hostnames(ssl_opts.skip_domain_validation()); + builder.danger_accept_invalid_certs(ssl_opts.accept_invalid_certs()); + let tls_connector: tokio_native_tls::TlsConnector = builder.build()?.into(); + + *self = match self { + Endpoint::Plain(ref mut stream) => { + let stream = stream.take().unwrap(); + let tls_stream = tls_connector.connect(&*domain, stream).await?; + Endpoint::Secure(tls_stream) + } + Endpoint::Secure(_) => unreachable!(), + #[cfg(unix)] + Endpoint::Socket(_) => unreachable!(), + }; + + Ok(()) + } +} diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs new file mode 100644 index 00000000..654581d3 --- /dev/null +++ b/src/io/tls/rustls_io.rs @@ -0,0 +1,148 @@ +#![cfg(feature = "rustls-tls")] + +use std::{convert::TryInto, sync::Arc}; + +use rustls::{ + client::{ServerCertVerifier, WebPkiVerifier}, + Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore, +}; + +use tokio::{fs::File, io::AsyncReadExt}; + +use rustls_pemfile::certs; +use tokio_rustls::TlsConnector; + +use crate::{io::Endpoint, Result, SslOpts}; + +impl Endpoint { + pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { + #[cfg(unix)] + if self.is_socket() { + // won't secure socket connection + return Ok(()); + } + + let mut root_store = RootCertStore::empty(); + root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + + if let Some(root_cert_path) = ssl_opts.root_cert_path() { + let mut root_cert_data = vec![]; + let mut root_cert_file = File::open(root_cert_path).await?; + root_cert_file.read_to_end(&mut root_cert_data).await?; + + let mut root_certs = Vec::new(); + for cert in certs(&mut &*root_cert_data)? { + root_certs.push(Certificate(cert)); + } + + if root_certs.is_empty() && !root_cert_data.is_empty() { + root_certs.push(Certificate(root_cert_data)); + } + + for cert in &root_certs { + root_store.add(cert)?; + } + } + + let config_builder = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store.clone()); + + let mut config = if let Some(identity) = ssl_opts.client_identity() { + let (cert_chain, priv_key) = identity.load()?; + config_builder.with_single_cert(cert_chain, priv_key)? + } else { + config_builder.with_no_client_auth() + }; + + let server_name = domain + .as_str() + .try_into() + .map_err(|_| webpki::InvalidDnsNameError)?; + let mut dangerous = config.dangerous(); + let web_pki_verifier = WebPkiVerifier::new(root_store, None); + let dangerous_verifier = DangerousVerifier::new( + ssl_opts.accept_invalid_certs(), + ssl_opts.skip_domain_validation(), + web_pki_verifier, + ); + dangerous.set_certificate_verifier(Arc::new(dangerous_verifier)); + + *self = match self { + Endpoint::Plain(ref mut stream) => { + let stream = stream.take().unwrap(); + + let client_config = Arc::new(config); + let tls_connector = TlsConnector::from(client_config); + let connection = tls_connector.connect(server_name, stream).await?; + + Endpoint::Secure(connection) + } + Endpoint::Secure(_) => unreachable!(), + #[cfg(unix)] + Endpoint::Socket(_) => unreachable!(), + }; + + Ok(()) + } +} + +struct DangerousVerifier { + accept_invalid_certs: bool, + skip_domain_validation: bool, + verifier: WebPkiVerifier, +} + +impl DangerousVerifier { + fn new( + accept_invalid_certs: bool, + skip_domain_validation: bool, + verifier: WebPkiVerifier, + ) -> Self { + Self { + accept_invalid_certs, + skip_domain_validation, + verifier, + } + } +} + +impl ServerCertVerifier for DangerousVerifier { + fn verify_server_cert( + &self, + end_entity: &Certificate, + intermediates: &[Certificate], + server_name: &rustls::ServerName, + scts: &mut dyn Iterator, + ocsp_response: &[u8], + now: std::time::SystemTime, + ) -> std::result::Result { + if self.accept_invalid_certs { + Ok(rustls::client::ServerCertVerified::assertion()) + } else { + match self.verifier.verify_server_cert( + end_entity, + intermediates, + server_name, + scts, + ocsp_response, + now, + ) { + Ok(assertion) => Ok(assertion), + Err(ref e) + if e.to_string().contains("CertNotValidForName") + && self.skip_domain_validation => + { + Ok(rustls::client::ServerCertVerified::assertion()) + } + Err(e) => Err(e), + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 942ff3c2..72393cb1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,85 @@ //! mysql_async = "" //! ``` //! +//! # Crate Features +//! +//! Default feature set is wide – it includes all default [`mysql_common` features][myslqcommonfeatures] +//! as well as `native-tls`-based TLS support. +//! +//! ## List Of Features +//! +//! * `minimal` – enables only necessary features (at the moment the only necessary feature +//! is `flate2` backend). Enables: +//! +//! - `flate2/zlib" +//! +//! **Example:** +//! +//! ```toml +//! [dependencies] +//! mysql_async = { version = "*", default-features = false, features = ["minimal"]} +//! ``` +//! +//! **Note:* it is possible to use another `flate2` backend by directly choosing it: +//! +//! ```toml +//! [dependencies] +//! mysql_async = { version = "*", default-features = false } +//! flate2 = { version = "*", default-features = false, features = ["rust_backend"] } +//! ``` +//! +//! * `default` – enables the following set of crate's and dependencies' features: +//! +//! - `native-tls-tls` +//! - `flate2/zlib" +//! - `mysql_common/bigdecimal03` +//! - `mysql_common/rust_decimal` +//! - `mysql_common/time03` +//! - `mysql_common/uuid` +//! - `mysql_common/frunk` +//! +//! * `default-rustls` – same as default but with `rustls-tls` instead of `native-tls-tls`. +//! +//! **Example:** +//! +//! ```toml +//! [dependencies] +//! mysql_async = { version = "*", default-features = false, features = ["default-rustls"] } +//! ``` +//! +//! * `native-tls-tls` – enables `native-tls`-based TLS support _(conflicts with `rustls-tls`)_ +//! +//! **Example:** +//! +//! ```toml +//! [dependencies] +//! mysql_async = { version = "*", default-features = false, features = ["native-tls-tls"] } +//! +//! * `rustls-tls` – enables `native-tls`-based TLS support _(conflicts with `native-tls-tls`)_ +//! +//! **Example:** +//! +//! ```toml +//! [dependencies] +//! mysql_async = { version = "*", default-features = false, features = ["rustls-tls"] } +//! +//! [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features +//! +//! # TLS/SSL Support +//! +//! SSL support comes in two flavors: +//! +//! 1. Based on native-tls – this is the default option, that usually works without pitfalls +//! (see the `native-tls-tls` crate feature). +//! +//! 2. Based on rustls – TLS backend written in Rust (see the `rustls-tls` crate feature). +//! +//! Please also note a few things about rustls: +//! - it will fail if you'll try to connect to the server by its IP address, +//! hostname is required; +//! - it, most likely, won't work on windows, at least with default server certs, +//! generated by the MySql installer. +//! //! # Example //! //! ```rust @@ -361,6 +440,10 @@ pub use self::query::QueryWithParams; #[doc(inline)] pub use self::queryable::transaction::IsolationLevel; +#[doc(inline)] +#[cfg(any(feature = "rustls", feature = "native-tls"))] +pub use self::opts::ClientIdentity; + #[doc(inline)] pub use self::opts::{ Opts, OptsBuilder, PoolConstraints, PoolOpts, SslOpts, DEFAULT_INACTIVE_CONNECTION_TTL, @@ -509,7 +592,7 @@ pub mod test_misc { } url } else { - "mysql://root:password@127.0.0.1:3307/mysql".into() + "mysql://root:password@localhost:3307/mysql".into() } }; } diff --git a/src/opts.rs b/src/opts/mod.rs similarity index 97% rename from src/opts.rs rename to src/opts/mod.rs index 2e98c41f..b8386ee4 100644 --- a/src/opts.rs +++ b/src/opts/mod.rs @@ -6,6 +6,15 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. +mod native_tls_opts; +mod rustls_opts; + +#[cfg(feature = "native-tls")] +pub use native_tls_opts::ClientIdentity; + +#[cfg(feature = "rustls-tls")] +pub use rustls_opts::ClientIdentity; + use percent_encoding::percent_decode; use url::{Host, Url}; @@ -109,29 +118,36 @@ impl HostPortOrUrl { /// ``` /// # use mysql_async::SslOpts; /// # use std::path::Path; +/// # #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] +/// # use mysql_async::ClientIdentity; +/// // With native-tls +/// # #[cfg(feature = "native-tls-tls")] +/// let ssl_opts = SslOpts::default() +/// .with_client_identity(Some(ClientIdentity::new(Path::new("/path")) +/// .with_password("******") +/// )); +/// +/// // With rustls +/// # #[cfg(feature = "rustls-tls")] /// let ssl_opts = SslOpts::default() -/// .with_pkcs12_path(Some(Path::new("/path"))) -/// .with_password(Some("******")); +/// .with_client_identity(Some(ClientIdentity::new( +/// Path::new("/path/to/chain"), +/// Path::new("/path/to/priv_key") +/// ))); /// ``` #[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] pub struct SslOpts { - pkcs12_path: Option>, - password: Option>, + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + client_identity: Option, root_cert_path: Option>, skip_domain_validation: bool, accept_invalid_certs: bool, } impl SslOpts { - /// Sets path to the pkcs12 archive (in `der` format). - pub fn with_pkcs12_path>>(mut self, pkcs12_path: Option) -> Self { - self.pkcs12_path = pkcs12_path.map(Into::into); - self - } - - /// Sets the password for a pkcs12 archive (defaults to `None`). - pub fn with_password>>(mut self, password: Option) -> Self { - self.password = password.map(Into::into); + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + pub fn with_client_identity(mut self, identity: Option) -> Self { + self.client_identity = identity; self } @@ -160,12 +176,9 @@ impl SslOpts { self } - pub fn pkcs12_path(&self) -> Option<&Path> { - self.pkcs12_path.as_ref().map(|x| x.as_ref()) - } - - pub fn password(&self) -> Option<&str> { - self.password.as_ref().map(AsRef::as_ref) + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + pub fn client_identity(&self) -> Option<&ClientIdentity> { + self.client_identity.as_ref() } pub fn root_cert_path(&self) -> Option<&Path> { diff --git a/src/opts/native_tls_opts.rs b/src/opts/native_tls_opts.rs new file mode 100644 index 00000000..49eb4c46 --- /dev/null +++ b/src/opts/native_tls_opts.rs @@ -0,0 +1,41 @@ +#![cfg(feature = "native-tls")] + +use std::{borrow::Cow, path::Path}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ClientIdentity { + pkcs12_path: Cow<'static, Path>, + password: Option>, +} + +impl ClientIdentity { + /// Creates new identity with the given path to the pkcs12 archive. + pub fn new(pkcs12_path: T) -> Self + where + T: Into>, + { + Self { + pkcs12_path: pkcs12_path.into(), + password: None, + } + } + + /// Sets the archive password. + pub fn with_password(mut self, pass: T) -> Self + where + T: Into>, + { + self.password = Some(pass.into()); + self + } + + /// Returns the pkcs12 archive path. + pub fn pkcs12_path(&self) -> &Path { + self.pkcs12_path.as_ref() + } + + /// Returns the archive password. + pub fn password(&self) -> Option<&str> { + self.password.as_ref().map(AsRef::as_ref) + } +} diff --git a/src/opts/rustls_opts.rs b/src/opts/rustls_opts.rs new file mode 100644 index 00000000..143dc62d --- /dev/null +++ b/src/opts/rustls_opts.rs @@ -0,0 +1,87 @@ +#![cfg(feature = "rustls-tls")] + +use rustls::{Certificate, PrivateKey}; +use rustls_pemfile::{certs, rsa_private_keys}; + +use std::{borrow::Cow, path::Path}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ClientIdentity { + cert_chain_path: Cow<'static, Path>, + priv_key_path: Cow<'static, Path>, +} + +impl ClientIdentity { + /// Creates new identity. + /// + /// `cert_chain_path` - path to a certificate chain (in PEM or DER) + /// `priv_key_path` - path to a private key (in DER or PEM) (it'll take the first one) + pub fn new(cert_chain_path: T, priv_key_path: U) -> Self + where + T: Into>, + U: Into>, + { + Self { + cert_chain_path: cert_chain_path.into(), + priv_key_path: priv_key_path.into(), + } + } + + /// Sets the certificate chain path (in DER or PEM). + pub fn with_cert_chain_path(mut self, cert_chain_path: T) -> Self + where + T: Into>, + { + self.cert_chain_path = cert_chain_path.into(); + self + } + + /// Sets the private key path (in DER or PEM) (it'll take the first one). + pub fn with_priv_key_path(mut self, priv_key_path: T) -> Self + where + T: Into>, + { + self.priv_key_path = priv_key_path.into(); + self + } + + /// Returns the certificate chain path. + pub fn cert_chain_path(&self) -> &Path { + self.cert_chain_path.as_ref() + } + + /// Returns the private key path. + pub fn priv_key_path(&self) -> &Path { + self.priv_key_path.as_ref() + } + + pub(crate) fn load(&self) -> crate::Result<(Vec, PrivateKey)> { + let cert_data = std::fs::read(self.cert_chain_path.as_ref())?; + let key_data = std::fs::read(self.priv_key_path.as_ref())?; + + let mut cert_chain = Vec::new(); + if std::str::from_utf8(&cert_data).is_err() { + cert_chain.push(Certificate(cert_data)); + } else { + for cert in certs(&mut &*cert_data)? { + cert_chain.push(Certificate(cert)); + } + } + + let priv_key; + if std::str::from_utf8(&key_data).is_err() { + priv_key = Some(PrivateKey(key_data)); + } else { + priv_key = rsa_private_keys(&mut &*key_data)? + .into_iter() + .take(1) + .map(PrivateKey) + .next(); + } + + Ok(( + cert_chain, + priv_key.ok_or_else(|| crate::Error::from(crate::DriverError::NoKeyFound))?, + )) + } +} diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index b6bcc2b6..52bb4252 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -8,8 +8,8 @@ use futures_util::FutureExt; use mysql_common::{ + constants::MAX_PAYLOAD_LEN, io::ParseBuf, - packets::{OkPacketDeserializer, ResultSetTerminator}, proto::{Binary, Text}, row::RowDeserializer, value::ServerSide, @@ -42,10 +42,11 @@ pub trait Protocol: fmt::Debug + Send + Sync + 'static { fn result_set_meta(columns: Arc<[Column]>) -> ResultSetMeta; fn read_result_set_row(packet: &[u8], columns: Arc<[Column]>) -> Result; fn is_last_result_set_packet(capabilities: CapabilityFlags, packet: &[u8]) -> bool { - packet.len() < 8 - && ParseBuf(packet) - .parse::>(capabilities) - .is_ok() + if capabilities.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) { + packet[0] == 0xFE && packet.len() < MAX_PAYLOAD_LEN + } else { + packet[0] == 0xFE && packet.len() < 8 + } } } diff --git a/test/ca-cert.der b/test/ca-cert.der index 57f97da2..4c5e1771 100644 Binary files a/test/ca-cert.der and b/test/ca-cert.der differ diff --git a/test/client.p12 b/test/client.p12 index 1ff7eb9f..d7f1b666 100644 Binary files a/test/client.p12 and b/test/client.p12 differ diff --git a/test/key.pem b/test/key.pem new file mode 100644 index 00000000..44dfa86d --- /dev/null +++ b/test/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAoHCE8pwgBvudC43lD3P+QbqlH66Lrjb1MJ8rkS+4JIFSaKIv +V32HIeXGhk3oaQ6CGo+E8nSiToP47s9GvNk86AGFLpvMbQVSliPxlEGrmiVcPyys +mC9FmHaEW68XQKSMxtwf4/NDC1gsIwnT4jfmzF8VLaTQCLD9KPr+o/waRl3cHL3P +1qMngutYsGmslpY5X9RW9C0AwkP/+dSXmucFW7BauK+f9OlXOKloZKw13BY0ajT3 +rz6DJ+LK/qYAvAM0Fjl5wjfVRGqoOTNzzixmbKLqdAdt8vzXmqAQmMy+g6W4C2Zr ++goQ5q5mPrvPScS4uuwiLb7pCUqPga99ryiQiwIDAQABAoIBAGJceXWP+CavzdlO +lfdCYsgDWMayqRoWwX2cqAYr3lYrHs3dWO7nm5hRmcOvMeRuq58DDDvk+7jtOgmW +9ERFXwzSGce4Zr0T/UzlHm+JT16CtypYBjyLBrzxNDZNgxDzkQc93yNOeXUUCoM0 +vD09jncPeBlyqMQbVinwr3rzzVwDqIiojTgxoq9MyR2a/bKQky94bQvc0qOpBn+z +RtUvuW2s4ZGGf0BWIGWaSTop0oc16ulSCyh71WUUVILp2JYgA3Jg8b7jkdkZmype +1KpIfdzan9YUtXkka7S4kHrEs7W8K+yHRkDCeOksQux7lbkH1btvxiuEiR0CUig0 +b68iBnkCgYEA0nZe8I9W8gA/Ruo6xcojowpkChRFb7JPNcxf4RyOXjvsz08v1k24 +qFzcouJzINh0tKcRYCm0HinNg1CHdpSGpSBCeVirAt3QPyojlYiqeVY3GYiQ0UW5 +hbJDVB7qX39pOG9egyGQpIriNHI7iKkbnYW81y93MuZVO0Q6sIZFiscCgYEAwyda +vOSpX6oj+MheMfdExX7F4pvvMk1v1BsojrT/U7gAkBEruud2nKAH7o6sLBdsmioI +T/+S4dEfDNM+YZSvBFi6Drrn1XlVGyReRopJZ1quoseDDVAbKvU2OcaSASloiLQZ +4omg/6nL2Y/NlHBKgqQIW/pRpZ/jPuesUsDhaB0CgYBOZ7jAx7WtXDg2lAYnL0IN +eE6CjsC7duMZeLTzaS8EnjB/ntGEddnoJwgvSkt3ngwETQUlHQQ0BIDCfdqpa3Wp +yJXbHRRAciAll+4/w/U2VM8cHQtOWzpdO2bnzMiloRKy6pJ8KaH4GqFgxnm1VMKr +8WnDhLRUawivlqCCqNL5ewKBgCAvBU/Rhf042eXVZXNoC/dmCMxuWuw4yRB5yh5+ +yvzLg4w+yK9yLKV33tcAwHQlCMwD0ose4uJK0owS6l69Xn+hAk4blNAnyllHjiSj ++acJ1XMS5BH1/AUBm4e7r6hxY8Pnr70kZWDEZ9HhXU31ltQkqRxCE+T0kU12d3zO +Ql4hAoGAevwdcteN8nMOzpHXo7cY2H9TGimKlsowg2riWYWJHndnmSoAgNYKF+l2 +n3UAbaowCNRtlQGNFBJLZmqvSbu1ruP2S20ZfWBVm3WFENSpg4gcwT+Y0tB8Xp2r +jQfxFzIY5GLzhIPq7eAg0IXWoZ+AkDbfMN7weKhkdO8FhbuFfww= +-----END RSA PRIVATE KEY-----