Skip to content

Commit

Permalink
Reuse hyper client for with-tokio feature
Browse files Browse the repository at this point in the history
  • Loading branch information
durch committed Oct 16, 2023
1 parent 43b7414 commit 7d7c57b
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 52 deletions.
3 changes: 1 addition & 2 deletions s3/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rust-s3"
version = "0.34.0-rc1"
version = "0.34.0-rc2"
authors = ["Drazen Urch"]
description = "Rust library for working with AWS S3 and compatible object storage APIs"
repository = "https://github.com/durch/rust-s3"
Expand Down Expand Up @@ -62,7 +62,6 @@ hyper = { version = "0.14", default-features = false, features = [
"stream",
], optional = true }
hyper-tls = { version = "0.5.0", default-features = false, optional = true }
hyper-native-tls = { version = "0.3.0", default-features = false, optional = true }
log = "0.4"
maybe-async = { version = "0.2" }
md5 = "0.7"
Expand Down
61 changes: 49 additions & 12 deletions s3/src/bucket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use crate::bucket_ops::{BucketConfiguration, CreateBucketResponse};
use crate::command::{Command, Multipart};
use crate::creds::Credentials;
use crate::region::Region;
#[cfg(feature = "with-tokio")]
use crate::request::tokio_backend::client;
use crate::request::ResponseData;
#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
use crate::request::ResponseDataStream;
Expand Down Expand Up @@ -94,6 +96,8 @@ pub struct Bucket {
pub request_timeout: Option<Duration>,
path_style: bool,
listobjects_v2: bool,
#[cfg(feature = "with-tokio")]
http_client: Arc<hyper::Client<hyper_tls::HttpsConnector<hyper::client::HttpConnector>>>,
}

impl Bucket {
Expand All @@ -104,6 +108,13 @@ impl Bucket {
.map_err(|_| S3Error::WLCredentials)?
.refresh()?)
}

#[cfg(feature = "with-tokio")]
pub fn http_client(
&self,
) -> Arc<hyper::Client<hyper_tls::HttpsConnector<hyper::client::HttpConnector>>> {
Arc::clone(&self.http_client)
}
}

fn validate_expiry(expiry_secs: u32) -> Result<(), S3Error> {
Expand Down Expand Up @@ -451,6 +462,7 @@ impl Bucket {
let request = RequestImpl::new(&bucket, "", command)?;
let response_data = request.response_data(false).await?;
let response_text = response_data.to_string()?;

Ok(CreateBucketResponse {
bucket,
response_text,
Expand Down Expand Up @@ -520,6 +532,8 @@ impl Bucket {
request_timeout: DEFAULT_REQUEST_TIMEOUT,
path_style: false,
listobjects_v2: true,
#[cfg(feature = "with-tokio")]
http_client: Arc::new(client(DEFAULT_REQUEST_TIMEOUT)?),
})
}

Expand All @@ -544,6 +558,8 @@ impl Bucket {
request_timeout: DEFAULT_REQUEST_TIMEOUT,
path_style: false,
listobjects_v2: true,
#[cfg(feature = "with-tokio")]
http_client: Arc::new(client(DEFAULT_REQUEST_TIMEOUT)?),
})
}

Expand All @@ -557,11 +573,13 @@ impl Bucket {
request_timeout: self.request_timeout,
path_style: true,
listobjects_v2: self.listobjects_v2,
#[cfg(feature = "with-tokio")]
http_client: self.http_client.clone(),
}
}

pub fn with_extra_headers(&self, extra_headers: HeaderMap) -> Bucket {
Bucket {
pub fn with_extra_headers(&self, extra_headers: HeaderMap) -> Result<Bucket, S3Error> {
Ok(Bucket {
name: self.name.clone(),
region: self.region.clone(),
credentials: self.credentials.clone(),
Expand All @@ -570,11 +588,16 @@ impl Bucket {
request_timeout: self.request_timeout,
path_style: self.path_style,
listobjects_v2: self.listobjects_v2,
}
#[cfg(feature = "with-tokio")]
http_client: self.http_client.clone(),
})
}

pub fn with_extra_query(&self, extra_query: HashMap<String, String>) -> Bucket {
Bucket {
pub fn with_extra_query(
&self,
extra_query: HashMap<String, String>,
) -> Result<Bucket, S3Error> {
Ok(Bucket {
name: self.name.clone(),
region: self.region.clone(),
credentials: self.credentials.clone(),
Expand All @@ -583,11 +606,13 @@ impl Bucket {
request_timeout: self.request_timeout,
path_style: self.path_style,
listobjects_v2: self.listobjects_v2,
}
#[cfg(feature = "with-tokio")]
http_client: self.http_client.clone(),
})
}

pub fn with_request_timeout(&self, request_timeout: Duration) -> Bucket {
Bucket {
pub fn with_request_timeout(&self, request_timeout: Duration) -> Result<Bucket, S3Error> {
Ok(Bucket {
name: self.name.clone(),
region: self.region.clone(),
credentials: self.credentials.clone(),
Expand All @@ -596,7 +621,9 @@ impl Bucket {
request_timeout: Some(request_timeout),
path_style: self.path_style,
listobjects_v2: self.listobjects_v2,
}
#[cfg(feature = "with-tokio")]
http_client: Arc::new(client(Some(request_timeout))?),
})
}

pub fn with_listobjects_v1(&self) -> Bucket {
Expand All @@ -609,6 +636,8 @@ impl Bucket {
request_timeout: self.request_timeout,
path_style: self.path_style,
listobjects_v2: false,
#[cfg(feature = "with-tokio")]
http_client: self.http_client.clone(),
}
}

Expand Down Expand Up @@ -2371,7 +2400,14 @@ mod test {
}

fn test_minio_credentials() -> Credentials {
Credentials::new(Some("test"), Some("test1234"), None, None, None).unwrap()
Credentials::new(
Some(&env::var("MINIO_ACCESS_KEY_ID").unwrap()),
Some(&env::var("MINIO_SECRET_ACCESS_KEY").unwrap()),
None,
None,
None,
)
.unwrap()
}

fn test_digital_ocean_credentials() -> Credentials {
Expand Down Expand Up @@ -2432,7 +2468,7 @@ mod test {
Bucket::new(
"rust-s3",
Region::Custom {
region: "eu-central-1".to_owned(),
region: "us-east-1".to_owned(),
endpoint: "http://localhost:9000".to_owned(),
},
test_minio_credentials(),
Expand Down Expand Up @@ -3143,7 +3179,8 @@ mod test {
test_aws_credentials(),
)
.unwrap()
.with_request_timeout(Duration::from_secs(10));
.with_request_timeout(Duration::from_secs(10))
.unwrap();

assert_eq!(bucket.request_timeout(), Some(Duration::from_secs(10)));
}
Expand Down
19 changes: 11 additions & 8 deletions s3/src/bucket_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl fmt::Display for BucketAcl {

#[derive(Clone, Debug)]
pub struct BucketConfiguration {
acl: CannedBucketAcl,
acl: Option<CannedBucketAcl>,
object_lock_enabled: bool,
grant_full_control: Option<Vec<BucketAcl>>,
grant_read: Option<Vec<BucketAcl>>,
Expand All @@ -75,7 +75,7 @@ fn acl_list(acl: &[BucketAcl]) -> String {
impl BucketConfiguration {
#[allow(clippy::too_many_arguments)]
pub fn new(
acl: CannedBucketAcl,
acl: Option<CannedBucketAcl>,
object_lock_enabled: bool,
grant_full_control: Option<Vec<BucketAcl>>,
grant_read: Option<Vec<BucketAcl>>,
Expand All @@ -98,7 +98,7 @@ impl BucketConfiguration {

pub fn public() -> Self {
BucketConfiguration {
acl: CannedBucketAcl::PublicReadWrite,
acl: None,
object_lock_enabled: false,
grant_full_control: None,
grant_read: None,
Expand All @@ -111,7 +111,7 @@ impl BucketConfiguration {

pub fn private() -> Self {
BucketConfiguration {
acl: CannedBucketAcl::Private,
acl: Some(CannedBucketAcl::Private),
object_lock_enabled: false,
grant_full_control: None,
grant_read: None,
Expand Down Expand Up @@ -145,10 +145,13 @@ impl BucketConfiguration {
}

pub fn add_headers(&self, headers: &mut HeaderMap) -> Result<(), S3Error> {
headers.insert(
HeaderName::from_static("x-amz-acl"),
self.acl.to_string().parse()?,
);
if let Some(ref acl) = self.acl {
headers.insert(
HeaderName::from_static("x-amz-acl"),
acl.to_string().parse()?,
);
}

if self.object_lock_enabled {
headers.insert(
HeaderName::from_static("x-amz-bucket-object-lock-enabled"),
Expand Down
67 changes: 37 additions & 30 deletions s3/src/request/tokio_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,42 @@ use crate::utils::now_utc;

use tokio_stream::StreamExt;

pub fn client(
request_timeout: Option<std::time::Duration>,
) -> Result<Client<HttpsConnector<HttpConnector>>, S3Error> {
#[cfg(any(feature = "use-tokio-native-tls", feature = "tokio-rustls-tls"))]
let mut tls_connector_builder = native_tls::TlsConnector::builder();

#[cfg(not(any(feature = "use-tokio-native-tls", feature = "tokio-rustls-tls")))]
let tls_connector_builder = native_tls::TlsConnector::builder();

if cfg!(feature = "no-verify-ssl") {
cfg_if::cfg_if! {
if #[cfg(feature = "use-tokio-native-tls")]
{
tls_connector_builder.danger_accept_invalid_hostnames(true);
}

}

cfg_if::cfg_if! {
if #[cfg(any(feature = "use-tokio-native-tls", feature = "tokio-rustls-tls"))]
{
tls_connector_builder.danger_accept_invalid_certs(true);
}

}
}
let tls_connector = tokio_native_tls::TlsConnector::from(tls_connector_builder.build()?);

let mut http_connector = HttpConnector::new();
http_connector.set_connect_timeout(request_timeout);
http_connector.enforce_http(false);
let https_connector = HttpsConnector::from((http_connector, tls_connector));

Ok(Client::builder().build::<_, hyper::Body>(https_connector))
}

// Temporary structure for making a request
pub struct HyperRequest<'a> {
pub bucket: &'a Bucket,
Expand All @@ -40,36 +76,7 @@ impl<'a> Request for HyperRequest<'a> {
Err(e) => return Err(e),
};

#[cfg(any(feature = "use-tokio-native-tls", feature = "tokio-rustls-tls"))]
let mut tls_connector_builder = native_tls::TlsConnector::builder();

#[cfg(not(any(feature = "use-tokio-native-tls", feature = "tokio-rustls-tls")))]
let tls_connector_builder = native_tls::TlsConnector::builder();

if cfg!(feature = "no-verify-ssl") {
cfg_if::cfg_if! {
if #[cfg(feature = "use-tokio-native-tls")]
{
tls_connector_builder.danger_accept_invalid_hostnames(true);
}

}

cfg_if::cfg_if! {
if #[cfg(any(feature = "use-tokio-native-tls", feature = "tokio-rustls-tls"))]
{
tls_connector_builder.danger_accept_invalid_certs(true);
}

}
}
let tls_connector = tokio_native_tls::TlsConnector::from(tls_connector_builder.build()?);

let mut http_connector = HttpConnector::new();
http_connector.set_connect_timeout(self.bucket.request_timeout);
let https_connector = HttpsConnector::from((http_connector, tls_connector));

let client = Client::builder().build::<_, hyper::Body>(https_connector);
let client = self.bucket.http_client();

let method = match self.command.http_verb() {
HttpMethod::Delete => http::Method::DELETE,
Expand Down

0 comments on commit 7d7c57b

Please sign in to comment.