Skip to content
This repository has been archived by the owner on Feb 15, 2022. It is now read-only.

Add support of STS assume role. #736

Merged
merged 3 commits into from
Jul 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions data-prepper-plugins/elasticsearch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ Default is null.

- `aws_region`: A String represents the region of Amazon Elasticsearch Service domain, e.g. us-west-2. Only applies to Amazon Elasticsearch Service. Defaults to `us-east-1`.

- `aws_sts_role`: A IAM role which the sink plugin will assume to sign request to Amazon Elasticsearch. If not provided the plugin will use the default credentials.

- `insecure`: A boolean flag to turn off SSL certificate verification. If set to true, CA certificate verification will be turned off and insecure HTTP requests will be sent. Default to `false`.

- `username`(optional): A String of username used in the [internal users](https://opendistro.github.io/for-elasticsearch-docs/docs/security/access-control/users-roles) of ODFE cluster. Default is null.
Expand Down
3 changes: 3 additions & 0 deletions data-prepper-plugins/elasticsearch/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ dependencies {
implementation 'software.amazon.awssdk:auth:2.16.95'
implementation 'software.amazon.awssdk:http-client-spi:2.16.95'
implementation 'software.amazon.awssdk:sdk-core:2.16.95'
implementation 'software.amazon.awssdk:aws-core:2.16.95'
implementation 'software.amazon.awssdk:regions:2.16.95'
implementation 'software.amazon.awssdk:utils:2.16.95'
implementation 'software.amazon.awssdk:sts:2.16.95'
implementation 'software.amazon.awssdk:url-connection-client:2.16.95'
implementation "io.micrometer:micrometer-core:1.7.1"
testImplementation("junit:junit:4.13.2") {
exclude group:'org.hamcrest' // workaround for jarHell
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;

import javax.net.ssl.SSLContext;
import java.io.InputStream;
Expand All @@ -31,6 +34,7 @@
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.util.List;
import java.util.UUID;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
Expand All @@ -50,6 +54,7 @@ public class ConnectionConfiguration {
public static final String INSECURE = "insecure";
public static final String AWS_SIGV4 = "aws_sigv4";
public static final String AWS_REGION = "aws_region";
public static final String AWS_STS_ROLE = "aws_sts_role";

private final List<String> hosts;
private final String username;
Expand All @@ -60,6 +65,8 @@ public class ConnectionConfiguration {
private final boolean insecure;
private final boolean awsSigv4;
private final String awsRegion;
private final String awsStsRole;
private final String pipelineName;

public List<String> getHosts() {
return hosts;
Expand All @@ -81,6 +88,14 @@ public String getAwsRegion() {
return awsRegion;
}

public String getAwsStsRole() {
return awsStsRole;
}

public Path getCertPath() {
return certPath;
}

public Integer getSocketTimeout() {
return socketTimeout;
}
Expand All @@ -99,13 +114,16 @@ private ConnectionConfiguration(final Builder builder) {
this.insecure = builder.insecure;
this.awsSigv4 = builder.awsSigv4;
this.awsRegion = builder.awsRegion;
this.awsStsRole = builder.awsStsRole;
this.pipelineName = builder.pipelineName;
}

public static ConnectionConfiguration readConnectionConfiguration(final PluginSetting pluginSetting){
@SuppressWarnings("unchecked")
final List<String> hosts = (List<String>) pluginSetting.getAttributeFromSettings(HOSTS);
ConnectionConfiguration.Builder builder = new ConnectionConfiguration.Builder(hosts);
final String username = (String) pluginSetting.getAttributeFromSettings(USERNAME);
builder.withPipelineName(pluginSetting.getPipelineName());
if (username != null) {
builder = builder.withUsername(username);
}
Expand All @@ -123,7 +141,10 @@ public static ConnectionConfiguration readConnectionConfiguration(final PluginSe
}

builder.withAwsSigv4(pluginSetting.getBooleanOrDefault(AWS_SIGV4, false));
builder.withAwsRegion(pluginSetting.getStringOrDefault(AWS_REGION, DEFAULT_AWS_REGION));
if (builder.awsSigv4) {
builder.withAwsRegion(pluginSetting.getStringOrDefault(AWS_REGION, DEFAULT_AWS_REGION));
builder.withAWSStsRole(pluginSetting.getStringOrDefault(AWS_STS_ROLE, null));
}

final String certPath = pluginSetting.getStringOrDefault(CERT_PATH, null);
final boolean insecure = pluginSetting.getBooleanOrDefault(INSECURE, false);
Expand All @@ -136,8 +157,8 @@ public static ConnectionConfiguration readConnectionConfiguration(final PluginSe
return builder.build();
}

public Path getCertPath() {
return certPath;
public String getPipelineName() {
return pipelineName;
}

public RestHighLevelClient createClient() {
Expand Down Expand Up @@ -175,7 +196,19 @@ private void attachSigV4(final RestClientBuilder restClientBuilder) {
//if not follow regular credentials process
LOG.info("{} is set, will sign requests using AWSRequestSigningApacheInterceptor", AWS_SIGV4);
final Aws4Signer aws4Signer = Aws4Signer.create();
final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create();
AwsCredentialsProvider credentialsProvider;
if (awsStsRole != null && !awsStsRole.isEmpty()) {
credentialsProvider = StsAssumeRoleCredentialsProvider.builder()
.stsClient(StsClient.create())
.refreshRequest(AssumeRoleRequest.builder()
.roleSessionName(pipelineName + " Elasticsearch-Sink " + UUID.randomUUID()
.toString())
.roleArn(awsStsRole)
.build())
.build();
} else {
credentialsProvider = DefaultCredentialsProvider.create();
}
final HttpRequestInterceptor httpRequestInterceptor = new AwsRequestSigningApacheInterceptor(SERVICE_NAME, aws4Signer,
credentialsProvider, awsRegion);
restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> {
Expand Down Expand Up @@ -248,6 +281,8 @@ public static class Builder {
private boolean insecure;
private boolean awsSigv4;
private String awsRegion;
private String awsStsRole;
private String pipelineName;


public Builder(final List<String> hosts) {
Expand Down Expand Up @@ -302,6 +337,16 @@ public Builder withAwsRegion(final String awsRegion) {
return this;
}

public Builder withAWSStsRole(final String awsStsRole) {
this.awsStsRole = awsStsRole;
return this;
}

public Builder withPipelineName(final String pipelineName) {
this.pipelineName = pipelineName;
return this;
}

public ConnectionConfiguration build() {
return new ConnectionConfiguration(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ public class ConnectionConfigurationTests {
private final List<String> TEST_HOSTS = Collections.singletonList("http://localhost:9200");
private final String TEST_USERNAME = "admin";
private final String TEST_PASSWORD = "admin";
private final String TEST_PIPELINE_NAME = "Test-Pipeline";
private final Integer TEST_CONNECT_TIMEOUT = 5;
private final Integer TEST_SOCKET_TIMEOUT = 10;
private final String TEST_CERT_PATH = Objects.requireNonNull(getClass().getClassLoader().getResource("test-ca.pem")).getFile();

@Test
public void testReadConnectionConfigurationDefault() {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, null, null, null, null, false, null, null, false);
TEST_HOSTS, null, null, null, null, false, null, null, null, false);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
assertEquals(TEST_HOSTS, connectionConfiguration.getHosts());
Expand All @@ -38,12 +39,13 @@ public void testReadConnectionConfigurationDefault() {
assertNull(connectionConfiguration.getCertPath());
assertNull(connectionConfiguration.getConnectTimeout());
assertNull(connectionConfiguration.getSocketTimeout());
assertEquals(TEST_PIPELINE_NAME, connectionConfiguration.getPipelineName());
}

@Test
public void testCreateClientDefault() throws IOException {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, null, null, null, null, false, null, null, false);
TEST_HOSTS, null, null, null, null, false, null, null, null, false);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
final RestHighLevelClient client = connectionConfiguration.createClient();
Expand All @@ -54,7 +56,7 @@ public void testCreateClientDefault() throws IOException {
@Test
public void testReadConnectionConfigurationNoCert() {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, TEST_USERNAME, TEST_PASSWORD, TEST_CONNECT_TIMEOUT, TEST_SOCKET_TIMEOUT, false, null, null, false);
TEST_HOSTS, TEST_USERNAME, TEST_PASSWORD, TEST_CONNECT_TIMEOUT, TEST_SOCKET_TIMEOUT, false, null, null, null, false);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
assertEquals(TEST_HOSTS, connectionConfiguration.getHosts());
Expand All @@ -63,12 +65,13 @@ public void testReadConnectionConfigurationNoCert() {
assertEquals(TEST_CONNECT_TIMEOUT, connectionConfiguration.getConnectTimeout());
assertEquals(TEST_SOCKET_TIMEOUT, connectionConfiguration.getSocketTimeout());
assertFalse(connectionConfiguration.isAwsSigv4());
assertEquals(TEST_PIPELINE_NAME, connectionConfiguration.getPipelineName());
}

@Test
public void testCreateClientNoCert() throws IOException {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, TEST_USERNAME, TEST_PASSWORD, TEST_CONNECT_TIMEOUT, TEST_SOCKET_TIMEOUT, false, null, null, false);
TEST_HOSTS, TEST_USERNAME, TEST_PASSWORD, TEST_CONNECT_TIMEOUT, TEST_SOCKET_TIMEOUT, false, null, null, null, false);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
final RestHighLevelClient client = connectionConfiguration.createClient();
Expand All @@ -79,7 +82,7 @@ public void testCreateClientNoCert() throws IOException {
@Test
public void testCreateClientInsecure() throws IOException {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, TEST_USERNAME, TEST_PASSWORD, TEST_CONNECT_TIMEOUT, TEST_SOCKET_TIMEOUT, false, null, null, true);
TEST_HOSTS, TEST_USERNAME, TEST_PASSWORD, TEST_CONNECT_TIMEOUT, TEST_SOCKET_TIMEOUT, false, null, null, null, true);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
final RestHighLevelClient client = connectionConfiguration.createClient();
Expand All @@ -90,7 +93,7 @@ public void testCreateClientInsecure() throws IOException {
@Test
public void testCreateClientWithCertPath() throws IOException {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, TEST_USERNAME, TEST_PASSWORD, TEST_CONNECT_TIMEOUT, TEST_SOCKET_TIMEOUT, false, null, TEST_CERT_PATH, false);
TEST_HOSTS, TEST_USERNAME, TEST_PASSWORD, TEST_CONNECT_TIMEOUT, TEST_SOCKET_TIMEOUT, false, null, null, TEST_CERT_PATH, false);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
final RestHighLevelClient client = connectionConfiguration.createClient();
Expand All @@ -101,7 +104,7 @@ public void testCreateClientWithCertPath() throws IOException {
@Test
public void testCreateClientWithAWSSigV4AndRegion() throws IOException {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, null, null, null, null, true, "us-west-2", null, false);
TEST_HOSTS, null, null, null, null, true, "us-west-2", null, null, false);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
assertEquals("us-west-2", connectionConfiguration.getAwsRegion());
Expand All @@ -111,37 +114,52 @@ public void testCreateClientWithAWSSigV4AndRegion() throws IOException {
@Test
public void testCreateClientWithAWSSigV4DefaultRegion() throws IOException {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, null, null, null, null, true, null, null, false);
TEST_HOSTS, null, null, null, null, true, null, null, null, false);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
assertEquals("us-east-1", connectionConfiguration.getAwsRegion());
assertTrue(connectionConfiguration.isAwsSigv4());;
assertTrue(connectionConfiguration.isAwsSigv4());
assertEquals(TEST_PIPELINE_NAME, connectionConfiguration.getPipelineName());
}

@Test
public void testCreateClientWithAWSSigV4AndInsecure() throws IOException {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, null, null, null, null, true, null, null, true);
TEST_HOSTS, null, null, null, null, true, null, null, null, true);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
assertEquals("us-east-1", connectionConfiguration.getAwsRegion());
assertTrue(connectionConfiguration.isAwsSigv4());
assertEquals(TEST_PIPELINE_NAME, connectionConfiguration.getPipelineName());
}

@Test
public void testCreateClientWithAWSSigV4AndCertPath() throws IOException {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, null, null, null, null, true, null, TEST_CERT_PATH, false);
TEST_HOSTS, null, null, null, null, true, null, null, TEST_CERT_PATH, false);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
assertEquals("us-east-1", connectionConfiguration.getAwsRegion());
assertTrue(connectionConfiguration.isAwsSigv4());;
assertTrue(connectionConfiguration.isAwsSigv4());
assertEquals(TEST_PIPELINE_NAME, connectionConfiguration.getPipelineName());
}

@Test
public void testCreateClientWithAWSSigV4AndSTSRole() throws IOException {
final PluginSetting pluginSetting = generatePluginSetting(
TEST_HOSTS, null, null, null, null, true, null, "some-iam-role", TEST_CERT_PATH, false);
final ConnectionConfiguration connectionConfiguration =
ConnectionConfiguration.readConnectionConfiguration(pluginSetting);
assertEquals("us-east-1", connectionConfiguration.getAwsRegion());
assertTrue(connectionConfiguration.isAwsSigv4());
assertEquals("some-iam-role", connectionConfiguration.getAwsStsRole());
assertEquals(TEST_PIPELINE_NAME, connectionConfiguration.getPipelineName());
}

private PluginSetting generatePluginSetting(
final List<String> hosts, final String username, final String password,
final Integer connectTimeout, final Integer socketTimeout, final boolean awsSigv4, final String awsRegion,
final String certPath, final boolean insecure) {
final String awsStsRole, final String certPath, final boolean insecure) {
final Map<String, Object> metadata = new HashMap<>();
metadata.put("hosts", hosts);
metadata.put("username", username);
Expand All @@ -152,9 +170,11 @@ private PluginSetting generatePluginSetting(
if (awsRegion != null) {
metadata.put("aws_region", awsRegion);
}
metadata.put("aws_sts_role", awsStsRole);
metadata.put("cert", certPath);
metadata.put("insecure", insecure);

return new PluginSetting("elasticsearch", metadata);
final PluginSetting pluginSetting = new PluginSetting("elasticsearch", metadata);
pluginSetting.setPipelineName(TEST_PIPELINE_NAME);
return pluginSetting;
}
}