From baf26e3d9c7100833c63b9bcc19cd1a46d2aa108 Mon Sep 17 00:00:00 2001
From: avinas-kumar
Date: Wed, 17 Jul 2024 19:33:10 +0530
Subject: [PATCH] Added support for '.' in Bucket Name (#84)

* Added support for '.' in Bucket Name

* Added support for '.' in Bucket Name

* Modified comments

---------

Co-authored-by: Avinash Kumar
---
 .../com/linkedin/cdi/source/S3SourceV2.java   | 39 ++++++++++++++++-
 .../linkedin/cdi/source/S3SourceV2Test.java   | 42 +++++++++++++++++++
 docs/parameters/ms.source.s3.parameters.md    |  1 +
 3 files changed, 80 insertions(+), 2 deletions(-)

diff --git a/cdi-core/src/main/java/com/linkedin/cdi/source/S3SourceV2.java b/cdi-core/src/main/java/com/linkedin/cdi/source/S3SourceV2.java
index 7b3796a..f55d0e8 100644
--- a/cdi-core/src/main/java/com/linkedin/cdi/source/S3SourceV2.java
+++ b/cdi-core/src/main/java/com/linkedin/cdi/source/S3SourceV2.java
@@ -4,6 +4,7 @@
 
 package com.linkedin.cdi.source;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Joiner;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
@@ -14,6 +15,7 @@
 import java.net.URL;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Locale;
 import java.util.stream.Collectors;
 import org.apache.avro.Schema;
 import org.apache.avro.generic.GenericRecord;
@@ -33,6 +35,9 @@ public class S3SourceV2 extends MultistageSource {
   private static final String KEY_CONNECTION_TIMEOUT = "connection_timeout";
   private static final HashSet S3_REGIONS_SET =
       Region.regions().stream().map(region -> region.toString()).collect(Collectors.toCollection(HashSet::new));
+
+  private static final String KEY_BUCKET_NAME = "bucket_name";
+
   private S3Keys s3SourceV2Keys = new S3Keys();
 
   public S3Keys getS3SourceV2Keys() {
@@ -87,11 +92,11 @@ protected void initialize(State state) {
     }
 
     // separate the endpoint, which should be a URL without bucket name, from the domain name
-    s3SourceV2Keys.setEndpoint("https://" + getEndpointFromHost(url.getHost()));
+    s3SourceV2Keys.setEndpoint("https://" + getEndpoint(parameters, url.getHost()));
     s3SourceV2Keys.setPrefix(url.getPath().substring(1));
 
     // separate the bucket name from URI domain name
-    s3SourceV2Keys.setBucket(url.getHost().split("\\.")[0]);
+    s3SourceV2Keys.setBucket(getBucketName(parameters, url.getHost()));
     s3SourceV2Keys.setFilesPattern(MSTAGE_SOURCE_FILES_PATTERN.get(state));
     s3SourceV2Keys.setMaxKeys(MSTAGE_S3_LIST_MAX_KEYS.get(state));
 
@@ -130,4 +135,34 @@ private String getEndpointFromHost(String host) {
     segments.remove(0);
     return Joiner.on(".").join(segments);
   }
+
+  /**
+   * Resolves the bucket name: the explicit bucket_name parameter if present, else the first label of the host.
+   * @param parameters JsonObject containing ms.source.s3.parameters
+   * @param host hostname with bucket name in the beginning
+   * @return the bucket name
+   */
+  @VisibleForTesting
+  protected String getBucketName(JsonObject parameters, String host) {
+    if (parameters.has(KEY_BUCKET_NAME)) {
+      return parameters.get(KEY_BUCKET_NAME).getAsString();
+    }
+    return host.split("\\.")[0];
+  }
+
+  /**
+   * Derives the endpoint from the host; a configured bucket_name is stripped only as the leading host prefix.
+   * @param parameters JsonObject containing ms.source.s3.parameters
+   * @param host hostname with bucket name in the beginning
+   * @return the endpoint, i.e. the host with the bucket name removed, as computed by getEndpointFromHost
+   */
+  @VisibleForTesting
+  protected String getEndpoint(JsonObject parameters, String host) {
+    if (parameters.has(KEY_BUCKET_NAME)) {
+      String bucketName = parameters.get(KEY_BUCKET_NAME).getAsString().toLowerCase(Locale.ROOT);
+      host = host.toLowerCase(Locale.ROOT);
+      host = host.startsWith(bucketName + ".") ? host.substring(bucketName.length()) : host;
+    }
+    return getEndpointFromHost(host);
+  }
 }
diff --git a/cdi-core/src/test/java/com/linkedin/cdi/source/S3SourceV2Test.java b/cdi-core/src/test/java/com/linkedin/cdi/source/S3SourceV2Test.java
index 07f72ff..eb2a9f4 100644
--- a/cdi-core/src/test/java/com/linkedin/cdi/source/S3SourceV2Test.java
+++ b/cdi-core/src/test/java/com/linkedin/cdi/source/S3SourceV2Test.java
@@ -4,6 +4,7 @@
 
 package com.linkedin.cdi.source;
 
+import com.google.gson.Gson;
 import com.google.gson.JsonObject;
 import java.io.UnsupportedEncodingException;
 import org.apache.gobblin.configuration.WorkUnitState;
@@ -36,4 +37,45 @@ public void testInitialization() throws UnsupportedEncodingException {
     Assert.assertEquals(source.getS3SourceV2Keys().getMaxKeys(), new Integer(1000));
     Assert.assertEquals(source.getS3SourceV2Keys().getConnectionTimeout(), new Integer(30));
   }
+
+  @Test
+  public void testBucketName() {
+    S3SourceV2 s3SourceV2 = new S3SourceV2();
+    JsonObject parameters =
+        new Gson().fromJson("{\"region\" : \"us-east-2\", \"bucket_name\" : \"collect-us-west-2.tealium.com\"}",
+            JsonObject.class);
+    String host = "collect-us-west-2.tealium.com.s3.amazonaws.com";
+    String bucketName = s3SourceV2.getBucketName(parameters, host);
+    Assert.assertEquals(bucketName, "collect-us-west-2.tealium.com");
+  }
+
+  @Test
+  public void testBucketNameWithoutBucketParameter() {
+    S3SourceV2 s3SourceV2 = new S3SourceV2();
+    JsonObject parameters = new Gson().fromJson("{\"region\" : \"us-east-2\"}", JsonObject.class);
+    String host = "collect-us-west-2.s3.amazonaws.com";
+    String bucketName = s3SourceV2.getBucketName(parameters, host);
+    Assert.assertEquals(bucketName, "collect-us-west-2");
+  }
+
+  @Test
+  public void testEndpoint() {
+    S3SourceV2 s3SourceV2 = new S3SourceV2();
+    JsonObject parameters =
+        new Gson().fromJson("{\"region\" : \"us-east-2\", \"bucket_name\" : \"colleCt-us-west-2.tealium.com\"}",
+            JsonObject.class);
+    String host = "collect-us-west-2.tealium.com.s3.amazonaws.com";
+    String endpoint = s3SourceV2.getEndpoint(parameters, host);
+    Assert.assertEquals(endpoint, "s3.amazonaws.com");
+  }
+
+  @Test
+  public void testEndpointWithoutPeriod() {
+    S3SourceV2 s3SourceV2 = new S3SourceV2();
+    JsonObject parameters =
+        new Gson().fromJson("{\"region\" : \"us-east-2\", \"bucket_name\" : \"collect-us-west-2\"}", JsonObject.class);
+    String host = "collect-us-west-2.s3.amazonaws.com";
+    String endpoint = s3SourceV2.getEndpoint(parameters, host);
+    Assert.assertEquals(endpoint, "s3.amazonaws.com");
+  }
 }
diff --git a/docs/parameters/ms.source.s3.parameters.md b/docs/parameters/ms.source.s3.parameters.md
index 0ca874b..7bcb33a 100644
--- a/docs/parameters/ms.source.s3.parameters.md
+++ b/docs/parameters/ms.source.s3.parameters.md
@@ -23,6 +23,7 @@ It can have the following attributes:
 - **write_timeout_seconds**: integer, write time out in seconds
 - **connection_timeout_seconds**: Sets the socket to timeout after failing to establish a connection with the server after milliseconds.
 - **connection_max_idle_millis**: Sets the socket to timeout after timeout milliseconds of inactivity on the socket.
+- **bucket_name**: string, the bucket name; optional unless the name contains a period (`.`), in which case it cannot be derived from the URI host and must be set here.
 
 ### Example
 