Skip to content

Commit

Permalink
allow multiple tls features
Browse files Browse the repository at this point in the history
fix imports

fix vendored-openssl
  • Loading branch information
esheppa committed Sep 7, 2022
1 parent 995170e commit 53c1086
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 18 deletions.
48 changes: 48 additions & 0 deletions src/client/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,48 @@ pub struct Config {
pub(crate) encryption: EncryptionLevel,
pub(crate) trust: TrustConfig,
pub(crate) auth: AuthMethod,
pub(crate) tls_choice: TlsChoice,
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum TlsChoice {
#[cfg(not(any(
feature = "rustls",
feature = "native-tls",
feature = "vendored-openssl"
)))]
None,
#[cfg(feature = "rustls")]
Rustls,
#[cfg(feature = "native-tls")]
NativeTls,
#[cfg(feature = "vendored-openssl")]
Openssl,
}

impl Default for TlsChoice {
#[allow(unreachable_code, clippy::needless_return)]
fn default() -> TlsChoice {
#[cfg(feature = "rustls")]
{
return TlsChoice::Rustls;
}
#[cfg(feature = "native-tls")]
{
return TlsChoice::NativeTls;
}
#[cfg(feature = "vendored-openssl")]
{
return TlsChoice::Openssl
}

#[cfg(not(any(
feature = "rustls",
feature = "native-tls",
feature = "vendored-openssl"
)))]
TlsChoice::None
}
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -62,6 +104,7 @@ impl Default for Config {
encryption: EncryptionLevel::NotSupported,
trust: TrustConfig::Default,
auth: AuthMethod::None,
tls_choice: TlsChoice::default(),
}
}
}
Expand Down Expand Up @@ -120,6 +163,11 @@ impl Config {
self.encryption = encryption;
}

/// Set the choice of Tls
pub fn tls_choice(&mut self, tls_choice: TlsChoice) {
self.tls_choice = tls_choice;
}

/// If set, the server certificate will not be validated and it is accepted
/// as-is.
///
Expand Down
23 changes: 19 additions & 4 deletions src/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
feature = "native-tls",
feature = "vendored-openssl"
))]
use crate::client::{tls::TlsPreloginWrapper, tls_stream::create_tls_stream};
use crate::client::{config::TlsChoice, tls::TlsPreloginWrapper, tls_stream};
use crate::{
client::{tls::MaybeTlsStream, AuthMethod, Config},
tds::{
Expand Down Expand Up @@ -442,10 +442,25 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
let Self {
transport, context, ..
} = self;
let mut stream = match transport.into_inner() {
MaybeTlsStream::Raw(tcp) => {
create_tls_stream(config, TlsPreloginWrapper::new(tcp)).await?

let mut stream = match (transport.into_inner(), config.tls_choice) {
#[cfg(feature = "rustls")]
(MaybeTlsStream::Raw(tcp), TlsChoice::Rustls) => {
tls_stream::create_tls_stream_rustls(config, TlsPreloginWrapper::new(tcp))
.await?
}
#[cfg(feature = "vendored-openssl")]
(MaybeTlsStream::Raw(tcp), TlsChoice::Openssl) => {
tls_stream::create_tls_stream_openssl(config, TlsPreloginWrapper::new(tcp))
.await?
}
#[cfg(feature = "native-tls")]
(MaybeTlsStream::Raw(tcp), TlsChoice::NativeTls) => {
tls_stream::create_tls_stream_native_tls(config, TlsPreloginWrapper::new(tcp))
.await?
}
// this should still be fine as the relevant TlsChoices are only
// enabled when the equivalent tls crate is enabled
_ => unreachable!(),
};

Expand Down
121 changes: 108 additions & 13 deletions src/client/tls_stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use crate::Config;
use futures::{AsyncRead, AsyncWrite};

use std::{
io,
pin::Pin,
task::{Context, Poll},
};
#[cfg(feature = "native-tls")]
mod native_tls_stream;

Expand All @@ -10,35 +14,126 @@ mod rustls_tls_stream;
#[cfg(feature = "vendored-openssl")]
mod opentls_tls_stream;

#[cfg(feature = "native-tls")]
pub(crate) use native_tls_stream::TlsStream;
// #[cfg(feature = "native-tls")]
// pub(crate) use native_tls_stream::TlsStream as NativeTlsStream;

#[cfg(feature = "rustls")]
pub(crate) use rustls_tls_stream::TlsStream;
// #[cfg(feature = "rustls")]
// pub(crate) use rustls_tls_stream::TlsStream as RustlsTlsStream;

#[cfg(feature = "vendored-openssl")]
pub(crate) use opentls_tls_stream::TlsStream;
// #[cfg(feature = "vendored-openssl")]
// pub(crate) use opentls_tls_stream::TlsStream as OptenSslTlsStream;

pub(crate) enum TlsStream<S: AsyncRead + AsyncWrite + Unpin + Send> {
#[cfg(feature = "vendored-openssl")]
Openssl(opentls_tls_stream::TlsStream<S>),
#[cfg(feature = "rustls")]
Rustls(rustls_tls_stream::TlsStream<S>),
#[cfg(feature = "native-tls")]
NativeTls(native_tls_stream::TlsStream<S>),
}

impl<S> TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
pub(crate) fn get_mut(&mut self) -> &mut S {
match self {
#[cfg(feature = "vendored-openssl")]
TlsStream::Openssl(s) => s.get_mut(),
#[cfg(feature = "rustls")]
TlsStream::Rustls(s) => s.get_mut(),
#[cfg(feature = "native-tls")]
TlsStream::NativeTls(s) => s.get_mut(),
}
}
}

impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsStream<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let inner = Pin::get_mut(self);
match inner {
#[cfg(feature = "vendored-openssl")]
TlsStream::Openssl(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "rustls")]
TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_read(cx, buf),
#[cfg(feature = "native-tls")]
TlsStream::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
}
}
}

impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsStream<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let inner = Pin::get_mut(self);
match inner {
#[cfg(feature = "vendored-openssl")]
TlsStream::Openssl(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "rustls")]
TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_write(cx, buf),
#[cfg(feature = "native-tls")]
TlsStream::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let inner = Pin::get_mut(self);
match inner {
#[cfg(feature = "vendored-openssl")]
TlsStream::Openssl(s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "rustls")]
TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_flush(cx),
#[cfg(feature = "native-tls")]
TlsStream::NativeTls(s) => Pin::new(s).poll_flush(cx),
}
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let inner = Pin::get_mut(self);
match inner {
#[cfg(feature = "vendored-openssl")]
TlsStream::Openssl(s) => Pin::new(s).poll_close(cx),
#[cfg(feature = "rustls")]
TlsStream::Rustls(s) => Pin::new(&mut s.0).poll_close(cx),
#[cfg(feature = "native-tls")]
TlsStream::NativeTls(s) => Pin::new(s).poll_close(cx),
}
}
}

#[cfg(feature = "rustls")]
pub(crate) async fn create_tls_stream<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn create_tls_stream_rustls<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &Config,
stream: S,
) -> crate::Result<TlsStream<S>> {
TlsStream::new(config, stream).await
rustls_tls_stream::TlsStream::new(config, stream)
.await
.map(TlsStream::Rustls)
}

#[cfg(feature = "native-tls")]
pub(crate) async fn create_tls_stream<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn create_tls_stream_native_tls<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &Config,
stream: S,
) -> crate::Result<TlsStream<S>> {
native_tls_stream::create_tls_stream(config, stream).await
native_tls_stream::create_tls_stream(config, stream)
.await
.map(TlsStream::NativeTls)
}

#[cfg(feature = "vendored-openssl")]
pub(crate) async fn create_tls_stream<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn create_tls_stream_openssl<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &Config,
stream: S,
) -> crate::Result<TlsStream<S>> {
opentls_tls_stream::create_tls_stream(config, stream).await
opentls_tls_stream::create_tls_stream(config, stream)
.await
.map(TlsStream::Openssl)
}
2 changes: 1 addition & 1 deletion src/client/tls_stream/rustls_tls_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl From<tokio_rustls::webpki::Error> for Error {
}

pub(crate) struct TlsStream<S: AsyncRead + AsyncWrite + Unpin + Send>(
Compat<tokio_rustls::client::TlsStream<Compat<S>>>,
pub(super) Compat<tokio_rustls::client::TlsStream<Compat<S>>>,
);

struct NoCertVerifier;
Expand Down

0 comments on commit 53c1086

Please sign in to comment.