diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java index 4b4005984d..7ccf527deb 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -56,9 +56,7 @@ import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; -import software.amazon.awssdk.services.dynamodb.model.AttributeAction; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; -import software.amazon.awssdk.services.dynamodb.model.AttributeValueUpdate; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; @@ -128,7 +126,10 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe item.put(SOURCE, AttributeValue.builder().m(sourceMap).build()); item.put(SEQ_NO_KEY, AttributeValue.builder().n(sequenceNumber.toString()).build()); Builder builder = PutItemRequest.builder().tableName(tableName).item(item); - if (!request.overwriteIfExists() && getItemResponse != null && getItemResponse.item() != null) { + if (!request.overwriteIfExists() + && getItemResponse != null + && getItemResponse.item() != null + && !getItemResponse.item().isEmpty()) { throw new OpenSearchStatusException("Existing data object for ID: " + request.id(), RestStatus.CONFLICT); } final PutItemRequest putItemRequest = builder.build(); @@ -216,37 +217,25 @@ public CompletionStage updateDataObjectAsync(UpdateDat Map updateItem = JsonTransformer.convertJsonObjectToDDBAttributeMap(jsonNode); updateItem.remove(HASH_KEY); updateItem.remove(RANGE_KEY); - Map updateAttributeValue = new HashMap<>(); - updateAttributeValue - .put( - SOURCE, - AttributeValueUpdate - .builder() - .action(AttributeAction.PUT) - .value(AttributeValue.builder().m(updateItem).build()) - .build() - ); Map updateKey = new HashMap<>(); updateKey.put(HASH_KEY, AttributeValue.builder().s(tenantId).build()); updateKey.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build()); - UpdateItemRequest.Builder updateItemRequestBuilder = UpdateItemRequest - .builder() - .tableName(request.index()) - .key(updateKey) - .attributeUpdates(updateAttributeValue); - updateItemRequestBuilder - .updateExpression("SET #seqNo = #seqNo + :incr") - .expressionAttributeNames(Map.of("#seqNo", SEQ_NO_KEY)) - .expressionAttributeValues(Map.of(":incr", AttributeValue.builder().n("1").build())); + UpdateItemRequest.Builder updateItemRequestBuilder = UpdateItemRequest.builder().tableName(request.index()).key(updateKey); + Map expressionAttributeNames = new HashMap<>(); + expressionAttributeNames.put("#seqNo", SEQ_NO_KEY); + expressionAttributeNames.put("#source", SOURCE); + Map expressionAttributeValues = new HashMap<>(); + expressionAttributeValues.put(":incr", AttributeValue.builder().n("1").build()); + expressionAttributeValues.put(":source", AttributeValue.builder().m(updateItem).build()); + updateItemRequestBuilder.updateExpression("SET #seqNo = #seqNo + :incr, #source = :source "); if (request.ifSeqNo() != null) { // Get current document version and put in attribute map. Ignore primary term on DDB. - updateItemRequestBuilder - .conditionExpression("#seqNo = :currentSeqNo") - .expressionAttributeNames(Map.of("#seqNo", SEQ_NO_KEY)) - .expressionAttributeValues( - Map.of(":currentSeqNo", AttributeValue.builder().n(Long.toString(request.ifSeqNo())).build()) - ); + updateItemRequestBuilder.conditionExpression("#seqNo = :currentSeqNo"); + expressionAttributeValues.put(":currentSeqNo", AttributeValue.builder().n(Long.toString(request.ifSeqNo())).build()); } + updateItemRequestBuilder + .expressionAttributeNames(expressionAttributeNames) + .expressionAttributeValues(expressionAttributeValues); UpdateItemRequest updateItemRequest = updateItemRequestBuilder.build(); UpdateItemResponse updateItemResponse = dynamoDbClient.updateItem(updateItemRequest); Long sequenceNumber = null; @@ -330,7 +319,7 @@ public CompletionStage deleteDataObjectAsync(DeleteDat */ @Override public CompletionStage searchDataObjectAsync(SearchDataObjectRequest request, Executor executor) { - List indices = Arrays.stream(request.indices()).collect(Collectors.toList()); + List indices = Arrays.stream(request.indices()).map(this::getIndexName).collect(Collectors.toList()); SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest( indices.toArray(new String[0]), @@ -340,6 +329,11 @@ public CompletionStage searchDataObjectAsync(SearchDat return this.remoteClusterIndicesClient.searchDataObjectAsync(searchDataObjectRequest, executor); } + private String getIndexName(String index) { + // System index is not supported in remote index. Replacing '.' from index name. + return index.replaceAll("\\.", ""); + } + private XContentParser createParser(String json) throws IOException { return jsonXContent.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, json); } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index d1234de60f..6c283862e5 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -457,6 +457,8 @@ public void updateDataObjectAsync_HappyCaseWithMap() throws Exception { .index(TEST_INDEX) .tenantId(TENANT_ID) .dataObject(Map.of("foo", "bar")) + .ifSeqNo(10) + .ifPrimaryTerm(10) .build(); Mockito .when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture())) @@ -470,7 +472,13 @@ public void updateDataObjectAsync_HappyCaseWithMap() throws Exception { assertEquals(TEST_INDEX, updateItemRequest.tableName()); assertEquals(TEST_ID, updateItemRequest.key().get(RANGE_KEY).s()); assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s()); - assertEquals("bar", updateItemRequest.attributeUpdates().get("_source").value().m().get("foo").s()); + assertTrue(updateItemRequest.expressionAttributeNames().containsKey("#seqNo")); + assertTrue(updateItemRequest.expressionAttributeNames().containsKey("#source")); + assertTrue(updateItemRequest.expressionAttributeValues().containsKey(":incr")); + assertTrue(updateItemRequest.expressionAttributeValues().containsKey(":source")); + assertEquals("bar", updateItemRequest.expressionAttributeValues().get(":source").m().get("foo").s()); + assertTrue(updateItemRequest.expressionAttributeValues().containsKey(":currentSeqNo")); + assertNotNull(updateItemRequest.conditionExpression()); UpdateResponse response = UpdateResponse.fromXContent(updateResponse.parser()); Assert.assertEquals(5, response.getSeqNo()); } @@ -549,7 +557,7 @@ public void searchDataObjectAsync_SystemIndex() { assertEquals(searchDataObjectResponse, searchResponse); Mockito.verify(remoteClusterIndicesClient).searchDataObjectAsync(searchDataObjectRequestArgumentCaptor.capture(), Mockito.any()); - Assert.assertEquals(".test_index", searchDataObjectRequestArgumentCaptor.getValue().indices()[0]); + Assert.assertEquals("test_index", searchDataObjectRequestArgumentCaptor.getValue().indices()[0]); } private Map getComplexDataSource() {