Skip to content

Commit

Permalink
Refactor AWS SigV4 to middleware approach
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Farr <[email protected]>
  • Loading branch information
Xtansia committed Dec 6, 2023
1 parent 7548c9f commit c8bc677
Show file tree
Hide file tree
Showing 25 changed files with 1,172 additions and 814 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
version: 2.8.0
secured: true

- name: Run Tests (${{ matrix.test-args }})
- name: Run Tests
working-directory: client
run: cargo make test ${{ matrix.test-args }}
env:
Expand Down
10 changes: 3 additions & 7 deletions Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,14 @@ OPENSEARCH_URL = { value = "${OPENSEARCH_PROTOCOL}://localhost:9200", condition
category = "OpenSearch"
description = "Generates SSL certificates used for integration tests"
command = "bash"
args =["./.ci/generate-certs.sh"]
args = ["./.ci/generate-certs.sh"]

[tasks.run-opensearch]
category = "OpenSearch"
private = true
condition = { env_set = [ "STACK_VERSION"], env_false = ["CARGO_MAKE_CI"] }

[tasks.run-opensearch.linux]
command = "./.ci/run-opensearch.sh"

[tasks.run-opensearch.mac]
command = "./.ci/run-opensearch.sh"
command = "bash"
args = ["./.ci/run-opensearch.sh"]

[tasks.run-opensearch.windows]
script_runner = "cmd"
Expand Down
19 changes: 10 additions & 9 deletions opensearch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,38 @@ experimental-apis = ["beta-apis"]
native-tls = ["reqwest/native-tls"]
rustls-tls = ["reqwest/rustls-tls"]

# AWS SigV4 Auth support
aws-auth = ["aws-credential-types", "aws-sigv4", "aws-smithy-runtime-api", "aws-types"]
aws-auth = ["dep:aws-config", "dep:aws-credential-types", "dep:aws-sigv4", "dep:aws-smithy-runtime-api", "dep:aws-types"]

[dependencies]
async-trait = "0.1"
aws-config = { version = "1", optional = true }
aws-credential-types = { version = "1", optional = true }
aws-sigv4 = { version = "1", optional = true }
aws-smithy-runtime-api = { version = "1", features = ["client"], optional = true }
aws-types = { version = "1", optional = true }
base64 = "0.21"
bytes = "1.0"
dyn-clone = "1"
futures-util = "0.3"
lazy_static = "1.4"
percent-encoding = "2.1.0"
reqwest = { version = "0.11", default-features = false, features = ["gzip", "json"] }
url = "2.1"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
serde_with = "3"
thiserror = "1"
void = "1.0.2"
aws-credential-types = { version = "1", optional = true }
aws-sigv4 = { version = "1", optional = true }
aws-smithy-runtime-api = { version = "1", optional = true, features = ["client"]}
aws-types = { version = "1", optional = true }

[dev-dependencies]
anyhow = "1.0"
aws-config = "1"
aws-smithy-async = "1"
chrono = { version = "0.4", features = ["serde"] }
clap = "2"
futures = "0.3.1"
futures = "0.3"
http-body-util = "0.1.0"
hyper = { version = "1", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
regex="1.4"
sysinfo = "0.29.0"
test-case = "3"
textwrap = "0.16"
Expand Down
17 changes: 10 additions & 7 deletions opensearch/examples/aws_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,26 @@
* GitHub history for details.
*/

#[tokio::main]
#[cfg(feature = "aws-auth")]
#[tokio::main]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
use std::convert::TryInto;

use aws_config::BehaviorVersion;
use opensearch::{
cat::CatIndicesParts,
http::transport::{SingleNodeConnectionPool, TransportBuilder},
http::{
transport::{SingleNodeConnectionPool, TransportBuilder},
Url,
},
OpenSearch,
};
use url::Url;

let aws_config = aws_config::load_from_env().await;
let aws_config = aws_config::load_defaults(BehaviorVersion::latest()).await;

let host = ""; // e.g. https://search-mydomain.us-west-1.es.amazonaws.com
let transport = TransportBuilder::new(SingleNodeConnectionPool::new(Url::parse(host).unwrap()))
.auth(aws_config.try_into()?)
.aws_sigv4(aws_config.try_into()?)
.build()?;
let client = OpenSearch::new(transport);

Expand All @@ -42,6 +45,6 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
}

#[cfg(not(feature = "aws-auth"))]
pub fn main() {
panic!("Requires the `aws-auth` feature to be enabled")
fn main() {
panic!("This example requires the `aws-auth` feature to be enabled")
}
76 changes: 47 additions & 29 deletions opensearch/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@

//! Authentication components
use crate::{
http::middleware::{ClientInitializer, RequestInitializer},
BoxError,
};
use reqwest::Identity;

/// Credentials for authentication
#[derive(Debug, Clone)]
pub enum Credentials {
Expand All @@ -45,16 +51,6 @@ pub enum Credentials {
Certificate(ClientCertificate),
/// An id and api_key to use for API key authentication
ApiKey(String, String),
/// AWS credentials used for AWS SigV4 request signing.
///
/// # Optional
///
/// This requires the `aws-auth` feature to be enabled.
#[cfg(feature = "aws-auth")]
AwsSigV4(
aws_credential_types::provider::SharedCredentialsProvider,
aws_types::region::Region,
),
}

#[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
Expand Down Expand Up @@ -90,28 +86,50 @@ impl From<ClientCertificate> for Credentials {
}
}

#[cfg(any(feature = "aws-auth"))]
impl std::convert::TryFrom<&aws_types::SdkConfig> for Credentials {
type Error = super::Error;

fn try_from(value: &aws_types::SdkConfig) -> Result<Self, Self::Error> {
let credentials = value
.credentials_provider()
.ok_or_else(|| super::error::lib("SdkConfig does not have a credentials_provider"))?
.clone();
let region = value
.region()
.ok_or_else(|| super::error::lib("SdkConfig does not have a region"))?
.clone();
Ok(Credentials::AwsSigV4(credentials, region))
impl ClientInitializer for Credentials {
fn init(
&self,
client: reqwest::ClientBuilder,
) -> Result<reqwest::ClientBuilder, BoxError<'static>> {
match &self {
#[cfg(feature = "native-tls")]
Credentials::Certificate(ClientCertificate::Pkcs12(b, p)) => {
Ok(client.identity(Identity::from_pkcs12_der(b, p.as_deref().unwrap_or(""))?))
}
#[cfg(feature = "rustls-tls")]
Credentials::Certificate(ClientCertificate::Pem(b)) => {
Ok(client.identity(Identity::from_pem(b)?))
}
_ => Ok(client),
}
}
}

#[cfg(any(feature = "aws-auth"))]
impl std::convert::TryFrom<aws_types::SdkConfig> for Credentials {
type Error = super::Error;
impl RequestInitializer for Credentials {
fn init(
&self,
request: reqwest::RequestBuilder,
) -> Result<reqwest::RequestBuilder, BoxError<'static>> {
Ok(match &self {
Credentials::Basic(u, p) => request.basic_auth(u, Some(p)),
Credentials::Bearer(t) => request.bearer_auth(t),
Credentials::ApiKey(id, key) => {
use base64::{prelude::BASE64_STANDARD, write::EncoderWriter as Base64Encoder};
use reqwest::header::{HeaderValue, AUTHORIZATION};
use std::io::Write;

fn try_from(value: aws_types::SdkConfig) -> Result<Self, Self::Error> {
Credentials::try_from(&value)
let mut header_value = b"ApiKey ".to_vec();
{
let mut encoder = Base64Encoder::new(&mut header_value, &BASE64_STANDARD);
write!(encoder, "{}:", id).unwrap();
write!(encoder, "{}", key).unwrap();
}
request.header(
AUTHORIZATION,
HeaderValue::from_bytes(&header_value).unwrap(),
)
}
_ => request,
})
}
}
18 changes: 18 additions & 0 deletions opensearch/src/aws/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

pub use aws_config;
pub use aws_credential_types;
pub use aws_types;

mod sigv4;

pub use sigv4::*;
Loading

0 comments on commit c8bc677

Please sign in to comment.