Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
aviramha committed May 26, 2024
1 parent 4f72433 commit 5f85b4a
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 129 deletions.
31 changes: 12 additions & 19 deletions .github/workflows/checks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,6 @@ jobs:
test:
name: Test
runs-on: ubuntu-22.04
strategy:
matrix:
tls:
- no-tls
- native-tls
- rustls-native-roots
- rustls-webpki-roots
- openssl
include:
- tls: no-tls
features: native-tls
- tls: rustls-native-roots
features: rustls
- tls: rustls-webpki-roots
features: rustls-webpki
- tls: openssl
features: openssl-tls
steps:
- uses: actions/checkout@v4

Expand All @@ -45,5 +28,15 @@ jobs:
reporter: 'github-pr-check'
github_token: ${{ secrets.GITHUB_TOKEN }}

- name: Run tests
run: cargo test --no-default-features --features '${{ matrix.features }}'
- name: Run tests default
run: cargo test
- name: Run tests rustls-tls-manual-roots
run: cargo test --no-default-features rustls-tls-manual-roots
- name: Run tests rustls-tls-webpki-roots
run: cargo test --no-default-features rustls-tls-webpki-roots
- name: Run tests native-tls-vendored
run: cargo test --no-default-features native-tls-vendored
- name: Run tests native-tls
run: cargo test --no-default-features native-tls
- name: Run tests all features
run: cargo test --all-features
56 changes: 30 additions & 26 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
[package]
name = "hyper-proxy2"
version = "0.1.0"
authors = ["Natsuki Ikeguchi <[email protected]>"]
authors = ["MetalBear Tech LTD <[email protected]>"]
description = "A proxy connector for Hyper-based applications"
documentation = "https://docs.rs/hyper-proxy2"
repository = "https://github.com/siketyan/hyper-proxy2"
documentation = "https://docs.rs/hyper-http-proxy"
repository = "https://github.com/metalbear-co/hyper-http-proxy"
readme = "README.md"
keywords = ["hyper", "proxy", "tokio", "ssl"]
categories = ["web-programming::http-client", "asynchronous", "authentication"]
Expand All @@ -15,39 +15,43 @@ rust-version = "1.70.0"
[dependencies]
tokio = { version = "1.35", features = ["io-std", "io-util"] }
hyper = { version = "1.0", features = ["client"] }
hyper-util = { version = "0.1.2", features = ["client", "client-legacy", "tokio"] }
hyper-util = { version = "0.1", features = ["client", "client-legacy", "tokio"] }

tower-service = "0.3.2"
http = "1.0"
futures-util = { version = "0.3.30", default-features = false }
tower-service = "0.3"
http = "1"
futures-util = { version = "0.3", default-features = false }
bytes = "1.5"
pin-project-lite = "0.2.13"
hyper-tls = { version = "0.6.0", optional = true }
tokio-native-tls = { version = "0.3.1", optional = true }
native-tls = { version = "0.2.11", optional = true }
openssl = { version = "0.10.62", optional = true }
tokio-openssl = { version = "0.6.4", optional = true }
tokio-rustls = { version = "0.26", optional = true }
pin-project-lite = "0.2"
hyper-tls = { version = "0.6", optional = true }
tokio-native-tls = { version = "0.3", optional = true }
native-tls = { version = "0.2", optional = true }
tokio-rustls = { version = "0.26", optional = true, default-features = false}
hyper-rustls = { version = "0.27.0", optional = true }

webpki = { version = "0.22", optional = true }
rustls-webpki = { version = "0.102.4", optional = true }
rustls-native-certs = { version = "0.7.0", optional = true }
webpki-roots = { version = "0.26.0", optional = true }
headers = "0.4"

[dev-dependencies]
tokio = { version = "1.35", features = ["full"] }
hyper = { version = "1.0", features = ["client", "http1"] }
hyper-util = { version = "0.1.2", features = ["client", "client-legacy", "http1", "tokio"] }
http-body-util = "0.1.0"
futures = "0.3.30"
hyper-util = { version = "0.1", features = ["client", "client-legacy", "http1", "tokio"] }
http-body-util = "0.1"
futures = "0.3"

[features]
openssl-tls = ["openssl", "tokio-openssl"]
tls = ["tokio-native-tls", "hyper-tls", "native-tls"]
# note that `rustls-base` is not a valid feature on its own - it will configure rustls without root
# certificates!
rustls-base = ["tokio-rustls", "hyper-rustls", "webpki"]
rustls = ["rustls-base", "rustls-native-certs", "hyper-rustls/native-tokio"]
rustls-webpki = ["rustls-base", "webpki-roots", "hyper-rustls/webpki-tokio"]
default = ["tls"]
default = ["default-tls"]
default-tls = ["rustls-tls-native-roots"]
native-tls = ["dep:native-tls", "tokio-native-tls", "hyper-tls", "__tls"]
native-tls-vendored = ["native-tls", "tokio-native-tls?/vendored"]
rustls-tls = ["rustls-tls-webpki-roots"]
rustls-tls-manual-roots = ["__rustls"]
rustls-tls-webpki-roots = ["dep:webpki-roots", "__rustls"]
rustls-tls-native-roots = ["dep:rustls-native-certs", "__rustls"]

__tls = []

# Enables common rustls code.
# Equivalent to rustls-tls-manual-roots but shorter :)
__rustls = ["dep:hyper-rustls", "dep:tokio-rustls", "__tls"]
1 change: 1 addition & 0 deletions LICENSE-MIT.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ The MIT License (MIT)

Copyright (c) 2017 Johann Tuffe
Copyright (c) 2024 Natsuki Ikeguchi
Copyright (c) 2024 MetalBear Tech LTD

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
82 changes: 18 additions & 64 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,21 @@ use tower_service::Service;

pub use stream::ProxyStream;

#[cfg(feature = "tls")]
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
use native_tls::TlsConnector as NativeTlsConnector;

#[cfg(feature = "tls")]
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
use tokio_native_tls::TlsConnector;

#[cfg(feature = "rustls-base")]
#[cfg(feature = "__rustls")]
use hyper_rustls::ConfigBuilderExt;

#[cfg(feature = "rustls-base")]
#[cfg(feature = "__rustls")]
use tokio_rustls::TlsConnector;

#[cfg(feature = "rustls-base")]
#[cfg(feature = "__rustls")]
use tokio_rustls::rustls::pki_types::ServerName;

#[cfg(feature = "openssl-tls")]
use openssl::ssl::{SslConnector as OpenSslConnector, SslMethod};

#[cfg(feature = "openssl-tls")]
use tokio_openssl::SslStream;

type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// The Intercept enum to filter connections
Expand Down Expand Up @@ -253,16 +247,13 @@ pub struct ProxyConnector<C> {
proxies: Vec<Proxy>,
connector: C,

#[cfg(feature = "tls")]
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
tls: Option<NativeTlsConnector>,

#[cfg(feature = "rustls-base")]
#[cfg(feature = "__rustls")]
tls: Option<TlsConnector>,

#[cfg(feature = "openssl-tls")]
tls: Option<OpenSslConnector>,

#[cfg(not(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls")))]
#[cfg(not(feature = "__tls"))]
tls: Option<()>,
}

Expand All @@ -284,7 +275,7 @@ impl<C: fmt::Debug> fmt::Debug for ProxyConnector<C> {

impl<C> ProxyConnector<C> {
/// Create a new secured Proxies
#[cfg(feature = "tls")]
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
pub fn new(connector: C) -> Result<Self, io::Error> {
let tls = NativeTlsConnector::builder()
.build()
Expand All @@ -298,14 +289,14 @@ impl<C> ProxyConnector<C> {
}

/// Create a new secured Proxies
#[cfg(feature = "rustls-base")]
#[cfg(feature = "__rustls")]
pub fn new(connector: C) -> Result<Self, io::Error> {
let config = tokio_rustls::rustls::ClientConfig::builder();

#[cfg(feature = "rustls")]
#[cfg(feature = "rustls-tls-native-roots")]
let config = config.with_native_roots()?;

#[cfg(feature = "rustls-webpki")]
#[cfg(feature = "rustls-tls-webpki-roots")]
let config = config.with_webpki_roots();

let cfg = Arc::new(config.with_no_client_auth());
Expand All @@ -318,20 +309,6 @@ impl<C> ProxyConnector<C> {
})
}

#[allow(missing_docs)]
#[cfg(feature = "openssl-tls")]
pub fn new(connector: C) -> Result<Self, io::Error> {
let builder = OpenSslConnector::builder(SslMethod::tls())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let tls = builder.build();

Ok(ProxyConnector {
proxies: Vec::new(),
connector: connector,
tls: Some(tls),
})
}

/// Create a new unsecured Proxy
pub fn unsecured(connector: C) -> Self {
ProxyConnector {
Expand All @@ -342,7 +319,7 @@ impl<C> ProxyConnector<C> {
}

/// Create a proxy connector and attach a particular proxy
#[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
#[cfg(feature = "__tls")]
pub fn from_proxy(connector: C, proxy: Proxy) -> Result<Self, io::Error> {
let mut c = ProxyConnector::new(connector)?;
c.proxies.push(proxy);
Expand All @@ -366,23 +343,17 @@ impl<C> ProxyConnector<C> {
}

/// Set or unset tls when tunneling
#[cfg(any(feature = "tls"))]
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
pub fn set_tls(&mut self, tls: Option<NativeTlsConnector>) {
self.tls = tls;
}

/// Set or unset tls when tunneling
#[cfg(any(feature = "rustls-base"))]
#[cfg(any(feature = "__rustls"))]
pub fn set_tls(&mut self, tls: Option<TlsConnector>) {
self.tls = tls;
}

/// Set or unset tls when tunneling
#[cfg(any(feature = "openssl-tls"))]
pub fn set_tls(&mut self, tls: Option<OpenSslConnector>) {
self.tls = tls;
}

/// Get the current proxies
pub fn proxies(&self) -> &[Proxy] {
&self.proxies
Expand Down Expand Up @@ -471,7 +442,7 @@ where
let tunnel_stream = mtry!(tunnel.with_stream(proxy_stream).await);

break match tls {
#[cfg(feature = "tls")]
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
Some(tls) => {
use hyper_util::rt::TokioIo;
let tls = TlsConnector::from(tls);
Expand All @@ -483,7 +454,7 @@ where
Ok(ProxyStream::Secured(TokioIo::new(secure_stream)))
}

#[cfg(feature = "rustls-base")]
#[cfg(feature = "__rustls")]
Some(tls) => {
use hyper_util::rt::TokioIo;
let server_name =
Expand All @@ -497,24 +468,7 @@ where
Ok(ProxyStream::Secured(TokioIo::new(secure_stream)))
}

#[cfg(feature = "openssl-tls")]
Some(tls) => {
use hyper_util::rt::TokioIo;
let config = tls.configure().map_err(io_err)?;
let ssl = config.into_ssl(&host).map_err(io_err)?;

let mut stream =
mtry!(SslStream::new(ssl, TokioIo::new(tunnel_stream)));
mtry!(Pin::new(&mut stream).connect().await.map_err(io_err));

Ok(ProxyStream::Secured(TokioIo::new(stream)))
}

#[cfg(not(any(
feature = "tls",
feature = "rustls-base",
feature = "openssl-tls"
)))]
#[cfg(not(feature = "__tls",))]
Some(_) => panic!("hyper-proxy was not built with TLS support"),

None => Ok(ProxyStream::Regular(tunnel_stream)),
Expand Down
31 changes: 11 additions & 20 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,26 @@ use std::task::{Context, Poll};
use hyper::rt::{Read, ReadBufCursor, Write};
use hyper_util::client::legacy::connect::{Connected, Connection};

#[cfg(any(feature = "rustls-base", feature = "tls", feature = "openssl-tls"))]
#[cfg(feature = "__tls")]
use hyper_util::rt::TokioIo;

#[cfg(feature = "rustls-base")]
#[cfg(feature = "__rustls")]
use tokio_rustls::client::TlsStream as RustlsStream;

#[cfg(feature = "tls")]
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
use tokio_native_tls::TlsStream as TokioNativeTlsStream;

#[cfg(feature = "openssl-tls")]
use tokio_openssl::SslStream as OpenSslStream;

#[cfg(feature = "rustls-base")]
#[cfg(feature = "__rustls")]
pub type TlsStream<R> = TokioIo<RustlsStream<TokioIo<R>>>;

#[cfg(feature = "tls")]
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
pub type TlsStream<R> = TokioIo<TokioNativeTlsStream<TokioIo<R>>>;

#[cfg(feature = "openssl-tls")]
pub type TlsStream<R> = TokioIo<OpenSslStream<TokioIo<R>>>;

/// A Proxy Stream wrapper
pub enum ProxyStream<R> {
NoProxy(R),
Regular(R),
#[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
#[cfg(feature = "__tls")]
Secured(TlsStream<R>),
}

Expand All @@ -39,7 +33,7 @@ macro_rules! match_fn_pinned {
match $self.get_mut() {
ProxyStream::NoProxy(s) => Pin::new(s).$fn($ctx, $buf),
ProxyStream::Regular(s) => Pin::new(s).$fn($ctx, $buf),
#[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
#[cfg(feature = "__tls")]
ProxyStream::Secured(s) => Pin::new(s).$fn($ctx, $buf),
}
};
Expand All @@ -48,7 +42,7 @@ macro_rules! match_fn_pinned {
match $self.get_mut() {
ProxyStream::NoProxy(s) => Pin::new(s).$fn($ctx),
ProxyStream::Regular(s) => Pin::new(s).$fn($ctx),
#[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
#[cfg(feature = "__tls")]
ProxyStream::Secured(s) => Pin::new(s).$fn($ctx),
}
};
Expand Down Expand Up @@ -85,7 +79,7 @@ impl<R: Read + Write + Unpin> Write for ProxyStream<R> {
match self {
ProxyStream::NoProxy(s) => s.is_write_vectored(),
ProxyStream::Regular(s) => s.is_write_vectored(),
#[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
#[cfg(feature = "__tls")]
ProxyStream::Secured(s) => s.is_write_vectored(),
}
}
Expand All @@ -105,7 +99,7 @@ impl<R: Read + Write + Connection + Unpin> Connection for ProxyStream<R> {
ProxyStream::NoProxy(s) => s.connected(),

ProxyStream::Regular(s) => s.connected().proxy(true),
#[cfg(feature = "tls")]
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
ProxyStream::Secured(s) => s
.inner()
.get_ref()
Expand All @@ -115,11 +109,8 @@ impl<R: Read + Write + Connection + Unpin> Connection for ProxyStream<R> {
.connected()
.proxy(true),

#[cfg(feature = "rustls-base")]
#[cfg(feature = "__rustls")]
ProxyStream::Secured(s) => s.inner().get_ref().0.inner().connected().proxy(true),

#[cfg(feature = "openssl-tls")]
ProxyStream::Secured(s) => s.inner().get_ref().inner().connected().proxy(true),
}
}
}

0 comments on commit 5f85b4a

Please sign in to comment.