diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index 29a78fe81d..94d3443b59 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -24,6 +24,8 @@ jobs: strategy: matrix: java: [11, 17, 21] + env: + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true name: Build and Test MLCommons Plugin on linux if: github.repository == 'opensearch-project/ml-commons' diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java index 99cb757bc4..aee473b5f9 100644 --- a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java @@ -10,6 +10,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import java.io.IOException; import java.util.Map; @@ -19,6 +20,8 @@ public class UpdateDataObjectRequest { private final String index; private final String id; private final String tenantId; + private final Long ifSeqNo; + private final Long ifPrimaryTerm; private final ToXContentObject dataObject; /** @@ -28,12 +31,16 @@ public class UpdateDataObjectRequest { * @param index the index location to update the object * @param id the document id * @param tenantId the tenant id + * @param ifSeqNo the sequence number to match or null if not required + * @param ifPrimaryTerm the primary term to match or null if not required * @param dataObject the data object */ - public UpdateDataObjectRequest(String index, String id, String tenantId, ToXContentObject dataObject) { + public UpdateDataObjectRequest(String index, String id, String tenantId, Long ifSeqNo, Long ifPrimaryTerm, ToXContentObject dataObject) { this.index = index; this.id = id; this.tenantId = tenantId; + this.ifSeqNo = ifSeqNo; + this.ifPrimaryTerm = ifPrimaryTerm; this.dataObject = dataObject; } @@ -61,6 +68,22 @@ public String tenantId() { return this.tenantId; } + /** + * Returns the sequence number to match, or null if no match required + * @return the ifSeqNo + */ + public Long ifSeqNo() { + return ifSeqNo; + } + + /** + * Returns the primary term to match, or null if no match required + * @return the ifPrimaryTerm + */ + public Long ifPrimaryTerm() { + return ifPrimaryTerm; + } + /** * Returns the data object * @return the data object @@ -84,6 +107,8 @@ public static class Builder { private String index = null; private String id = null; private String tenantId = null; + private Long ifSeqNo = null; + private Long ifPrimaryTerm = null; private ToXContentObject dataObject = null; /** @@ -111,7 +136,7 @@ public Builder id(String id) { return this; } - /** + /** * Add a tenant ID to this builder * @param tenantId the tenant id * @return the updated builder @@ -120,7 +145,35 @@ public Builder tenantId(String tenantId) { this.tenantId = tenantId; return this; } - + + /** + * Only perform this update request if the document's modification was assigned the given + * sequence number. Must be used in combination with {@link #ifPrimaryTerm(long)} + *

+ * Sequence number may be represented by a different document versioning key on non-OpenSearch data stores. + */ + public Builder ifSeqNo(long seqNo) { + if (seqNo < 0 && seqNo != UNASSIGNED_SEQ_NO) { + throw new IllegalArgumentException("sequence numbers must be non negative. got [" + seqNo + "]."); + } + this.ifSeqNo = seqNo; + return this; + } + + /** + * Only performs this update request if the document's last modification was assigned the given + * primary term. Must be used in combination with {@link #ifSeqNo(long)} + *

+ * Primary term may not be relevant on non-OpenSearch data stores. + */ + public Builder ifPrimaryTerm(long term) { + if (term < 0) { + throw new IllegalArgumentException("primary term must be non negative. got [" + term + "]"); + } + this.ifPrimaryTerm = term; + return this; + } + /** * Add a data object to this builder * @param dataObject the data object @@ -150,7 +203,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws * @return A {@link UpdateDataObjectRequest} */ public UpdateDataObjectRequest build() { - return new UpdateDataObjectRequest(this.index, this.id, this.tenantId, this.dataObject); + if ((ifSeqNo == null) != (ifPrimaryTerm == null)) { + throw new IllegalArgumentException("Either ifSeqNo and ifPrimaryTerm must both be null or both must be non-null."); + } + return new UpdateDataObjectRequest(this.index, this.id, this.tenantId, this.ifSeqNo, this.ifPrimaryTerm, this.dataObject); } } } diff --git a/common/src/test/java/org/opensearch/sdk/SdkClientUtilsTests.java b/common/src/test/java/org/opensearch/sdk/SdkClientUtilsTests.java index af5e274733..307084b848 100644 --- a/common/src/test/java/org/opensearch/sdk/SdkClientUtilsTests.java +++ b/common/src/test/java/org/opensearch/sdk/SdkClientUtilsTests.java @@ -10,7 +10,6 @@ import org.junit.Before; import org.junit.Test; -import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.rest.RestStatus; diff --git a/common/src/test/java/org/opensearch/sdk/UpdateDataObjectRequestTests.java b/common/src/test/java/org/opensearch/sdk/UpdateDataObjectRequestTests.java index b31ad53691..578f4557f7 100644 --- a/common/src/test/java/org/opensearch/sdk/UpdateDataObjectRequestTests.java +++ b/common/src/test/java/org/opensearch/sdk/UpdateDataObjectRequestTests.java @@ -15,10 +15,13 @@ import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.sdk.UpdateDataObjectRequest.Builder; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; public class UpdateDataObjectRequestTests { @@ -26,6 +29,8 @@ public class UpdateDataObjectRequestTests { private String testIndex; private String testId; private String testTenantId; + private Long testSeqNo; + private Long testPrimaryTerm; private ToXContentObject testDataObject; private Map testDataObjectMap; @@ -34,6 +39,8 @@ public void setUp() { testIndex = "test-index"; testId = "test-id"; testTenantId = "test-tenant-id"; + testSeqNo = 42L; + testPrimaryTerm = 6L; testDataObject = mock(ToXContentObject.class); testDataObjectMap = Map.of("foo", "bar"); } @@ -46,6 +53,8 @@ public void testUpdateDataObjectRequest() { assertEquals(testId, request.id()); assertEquals(testTenantId, request.tenantId()); assertEquals(testDataObject, request.dataObject()); + assertNull(request.ifSeqNo()); + assertNull(request.ifPrimaryTerm()); } @Test @@ -57,4 +66,26 @@ public void testUpdateDataObjectMapRequest() { assertEquals(testTenantId, request.tenantId()); assertEquals(testDataObjectMap, XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(XContentType.JSON, request.dataObject()), false)); } + + @Test + public void testUpdateDataObjectRequestConcurrency() { + UpdateDataObjectRequest request = UpdateDataObjectRequest.builder().index(testIndex).id(testId).tenantId(testTenantId).dataObject(testDataObject) + .ifSeqNo(testSeqNo).ifPrimaryTerm(testPrimaryTerm).build(); + + assertEquals(testIndex, request.index()); + assertEquals(testId, request.id()); + assertEquals(testTenantId, request.tenantId()); + assertEquals(testDataObject, request.dataObject()); + assertEquals(testSeqNo, request.ifSeqNo()); + assertEquals(testPrimaryTerm, request.ifPrimaryTerm()); + + final Builder badSeqNoBuilder = UpdateDataObjectRequest.builder(); + assertThrows(IllegalArgumentException.class, () -> badSeqNoBuilder.ifSeqNo(-99)); + final Builder badPrimaryTermBuilder = UpdateDataObjectRequest.builder(); + assertThrows(IllegalArgumentException.class, () -> badPrimaryTermBuilder.ifPrimaryTerm(-99)); + final Builder onlySeqNoBuilder = UpdateDataObjectRequest.builder().ifSeqNo(testSeqNo); + assertThrows(IllegalArgumentException.class, () -> onlySeqNoBuilder.build()); + final Builder onlyPrimaryTermBuilder = UpdateDataObjectRequest.builder().ifPrimaryTerm(testPrimaryTerm); + assertThrows(IllegalArgumentException.class, () -> onlyPrimaryTermBuilder.build()); + } } diff --git a/common/src/test/java/org/opensearch/sdk/UpdateDataObjectResponseTests.java b/common/src/test/java/org/opensearch/sdk/UpdateDataObjectResponseTests.java index f77941f5b3..7762305f5f 100644 --- a/common/src/test/java/org/opensearch/sdk/UpdateDataObjectResponseTests.java +++ b/common/src/test/java/org/opensearch/sdk/UpdateDataObjectResponseTests.java @@ -10,9 +10,6 @@ import org.junit.Before; import org.junit.Test; -import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; -import org.opensearch.core.common.Strings; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.XContentParser; import static org.junit.Assert.assertEquals; diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index 22ff3f3254..e8db491cc9 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -615,16 +615,6 @@ private UpdateDataObjectRequest createUpdateModelGroupRequest( ) { modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); - /* Old code here. TODO investigate if we need to add seqNo and primaryTerm to data object request - UpdateRequest updateModelGroupRequest = new UpdateRequest(); - updateModelGroupRequest - .index(ML_MODEL_GROUP_INDEX) - .id(modelGroupId) - .setIfSeqNo(seqNo) - .setIfPrimaryTerm(primaryTerm) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .doc(modelGroupSourceMap); - */ ToXContentObject dataObject = new ToXContentObject() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { @@ -635,7 +625,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder.endObject(); } }; - return UpdateDataObjectRequest.builder().index(ML_MODEL_GROUP_INDEX).id(modelGroupId).dataObject(dataObject).build(); + return UpdateDataObjectRequest + .builder() + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + .ifSeqNo(seqNo) + .ifPrimaryTerm(primaryTerm) + .dataObject(dataObject) + .build(); } private Boolean isModelDeployed(MLModelState mlModelState) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 96e756d739..187a0ccbc1 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -385,15 +385,6 @@ public void registerMLRemoteModel( if (getModelGroupResponse.isExists()) { Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); int updatedVersion = incrementLatestVersion(modelGroupSourceMap); - /* TODO UpdateDataObjectRequest needs to track response seqNo + primary term - UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( - modelGroupSourceMap, - modelGroupId, - getModelGroupResponse.getSeqNo(), - getModelGroupResponse.getPrimaryTerm(), - updatedVersion - ); - */ modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest @@ -401,10 +392,8 @@ public void registerMLRemoteModel( .index(ML_MODEL_GROUP_INDEX) .id(modelGroupId) .tenantId(mlRegisterModelInput.getTenantId()) - // TODO need to track these for concurrency - // .setIfSeqNo(seqNo) - // .setIfPrimaryTerm(primaryTerm) - // .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .ifSeqNo(getModelGroupResponse.getSeqNo()) + .ifPrimaryTerm(getModelGroupResponse.getPrimaryTerm()) .dataObject(modelGroupSourceMap) .build(); sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((ur, ut) -> { diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java index 6e42e2cbcb..d4aca6cb6b 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -58,6 +58,7 @@ import software.amazon.awssdk.services.dynamodb.model.AttributeAction; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValueUpdate; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; @@ -74,8 +75,9 @@ public class DDBOpenSearchClient implements SdkClient { private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); private static final String DEFAULT_TENANT = "DEFAULT_TENANT"; - private static final String HASH_KEY = "tenant_id"; - private static final String RANGE_KEY = "id"; + private static final String HASH_KEY = "_tenant_id"; + private static final String RANGE_KEY = "_id"; + private static final String SEQ_NO_KEY = "_seq_no"; private DynamoDbClient dynamoDbClient; private RemoteClusterIndicesClient remoteClusterIndicesClient; @@ -112,7 +114,11 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe item.put(RANGE_KEY, AttributeValue.builder().s(id).build()); final PutItemRequest putItemRequest = PutItemRequest.builder().tableName(tableName).item(item).build(); + // TODO need to initialize/return SEQ_NO here + // If document doesn't exist, return 0 + // If document exists, overwrite and increment and return SEQ_NO dynamoDbClient.putItem(putItemRequest); + // TODO need to pass seqNo to simulated response String simulatedIndexResponse = simulateOpenSearchResponse(request.index(), id, source, Map.of("result", "created")); return PutDataObjectResponse.builder().id(id).parser(createParser(simulatedIndexResponse)).build(); } catch (IOException e) { @@ -137,6 +143,7 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe .ofEntries( Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build()) + // TODO need to fetch SEQ_NO_KEY ) ) .build(); @@ -198,16 +205,33 @@ public CompletionStage updateDataObjectAsync(UpdateDat Map updateKey = new HashMap<>(); updateKey.put(HASH_KEY, AttributeValue.builder().s(tenantId).build()); updateKey.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build()); - UpdateItemRequest updateItemRequest = UpdateItemRequest + UpdateItemRequest.Builder updateItemRequestBuilder = UpdateItemRequest .builder() .tableName(getTableName(request.index())) .key(updateKey) - .attributeUpdates(updateAttributeValue) - .build(); + .attributeUpdates(updateAttributeValue); + if (request.ifSeqNo() != null) { + // Get current document version and put in attribute map. Ignore primary term on DDB. + int currentSeqNo = jsonNode.has(SEQ_NO_KEY) ? jsonNode.get(SEQ_NO_KEY).asInt() : 0; + updateItemRequestBuilder + .conditionExpression("#seqNo = :currentSeqNo") + .expressionAttributeNames(Map.of("#seqNo", SEQ_NO_KEY)) + .expressionAttributeValues( + Map.of(":currentSeqNo", AttributeValue.builder().n(Integer.toString(currentSeqNo)).build()) + ); + } + UpdateItemRequest updateItemRequest = updateItemRequestBuilder.build(); dynamoDbClient.updateItem(updateItemRequest); - + // TODO need to pass seqNo to simulated response String simulatedUpdateResponse = simulateOpenSearchResponse(request.index(), request.id(), source, Map.of("found", true)); return UpdateDataObjectResponse.builder().id(request.id()).parser(createParser(simulatedUpdateResponse)).build(); + } catch (ConditionalCheckFailedException ccfe) { + log.error("Document version conflict updating {} in {}: {}", request.id(), request.index(), ccfe.getMessage(), ccfe); + // Rethrow + throw new OpenSearchStatusException( + "Document version conflict updating " + request.id() + " in index " + request.index(), + RestStatus.CONFLICT + ); } catch (IOException e) { log.error("Error updating {} in {}: {}", request.id(), request.index(), e.getMessage(), e); // Rethrow unchecked exception on update IOException @@ -239,7 +263,12 @@ public CompletionStage deleteDataObjectAsync(DeleteDat .build(); return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { + // TODO need to return SEQ_NO here + // If document doesn't exist, increment and return highest seq no ever seen, but we would have to track seqNo here + // If document never existed, return -2 (unassigned) for seq no (probably what we have to do here) + // If document exists, increment and return SEQ_NO dynamoDbClient.deleteItem(deleteItemRequest); + // TODO need to pass seqNo to simulated response String simulatedDeleteResponse = simulateOpenSearchResponse( request.index(), request.id(), diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java index 7bc2bd9468..84ee7e8b47 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java @@ -41,6 +41,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -129,17 +130,28 @@ public CompletionStage updateDataObjectAsync(UpdateDat return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { log.info("Updating {} from {}", request.id(), request.index()); - UpdateResponse updateResponse = client - .update( - new UpdateRequest(request.index(), request.id()).doc(request.dataObject().toXContent(sourceBuilder, EMPTY_PARAMS)) - ) - .actionGet(); + UpdateRequest updateRequest = new UpdateRequest(request.index(), request.id()) + .doc(request.dataObject().toXContent(sourceBuilder, EMPTY_PARAMS)); + if (request.ifSeqNo() != null) { + updateRequest.setIfSeqNo(request.ifSeqNo()); + } + if (request.ifPrimaryTerm() != null) { + updateRequest.setIfPrimaryTerm(request.ifPrimaryTerm()); + } + UpdateResponse updateResponse = client.update(updateRequest).actionGet(); if (updateResponse == null) { log.info("Null UpdateResponse"); return UpdateDataObjectResponse.builder().id(request.id()).parser(null).build(); } log.info("Update status for id {}: {}", updateResponse.getId(), updateResponse.getResult()); return UpdateDataObjectResponse.builder().id(updateResponse.getId()).parser(createParser(updateResponse)).build(); + } catch (VersionConflictEngineException vcee) { + log.error("Document version conflict updating {} in {}: {}", request.id(), request.index(), vcee.getMessage(), vcee); + // Rethrow + throw new OpenSearchStatusException( + "Document version conflict updating " + request.id() + " in index " + request.index(), + RestStatus.CONFLICT + ); } catch (IOException e) { // Rethrow unchecked exception on XContent parsing error throw new OpenSearchStatusException( diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java index 3473ad20ee..145341f69a 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java @@ -25,6 +25,7 @@ import org.opensearch.client.json.JsonpMapper; import org.opensearch.client.json.JsonpSerializable; import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.opensearch._types.OpenSearchException; import org.opensearch.client.opensearch.core.DeleteRequest; import org.opensearch.client.opensearch.core.DeleteResponse; import org.opensearch.client.opensearch.core.GetRequest; @@ -34,6 +35,7 @@ import org.opensearch.client.opensearch.core.SearchRequest; import org.opensearch.client.opensearch.core.SearchResponse; import org.opensearch.client.opensearch.core.UpdateRequest; +import org.opensearch.client.opensearch.core.UpdateRequest.Builder; import org.opensearch.client.opensearch.core.UpdateResponse; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; @@ -130,15 +132,30 @@ public CompletionStage updateDataObjectAsync(UpdateDat Map docMap = JsonXContent.jsonXContent .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, builder.toString()) .map(); - UpdateRequest, ?> updateRequest = new UpdateRequest.Builder, Map>() - .index(request.index()) - .id(request.id()) - .doc(docMap) - .build(); + Builder, Map> updateRequestBuilder = + new UpdateRequest.Builder, Map>() + .index(request.index()) + .id(request.id()) + .doc(docMap); + if (request.ifSeqNo() != null) { + updateRequestBuilder.ifSeqNo(request.ifSeqNo()); + } + if (request.ifPrimaryTerm() != null) { + updateRequestBuilder.ifPrimaryTerm(request.ifPrimaryTerm()); + } + UpdateRequest, ?> updateRequest = updateRequestBuilder.build(); log.info("Updating {} in {}", request.id(), request.index()); UpdateResponse> updateResponse = openSearchClient.update(updateRequest, MAP_DOCTYPE); log.info("Update status for id {}: {}", updateResponse.id(), updateResponse.result()); return UpdateDataObjectResponse.builder().id(updateResponse.id()).parser(createParser(updateResponse)).build(); + } catch (OpenSearchException ose) { + String errorType = ose.status() == RestStatus.CONFLICT.getStatus() ? "Document Version Conflict" : "Failed"; + log.error("{} updating {} in {}: {}", errorType, request.id(), request.index(), ose.getMessage(), ose); + // Rethrow + throw new OpenSearchStatusException( + errorType + " updating " + request.id() + " in index " + request.index(), + RestStatus.fromCode(ose.status()) + ); } catch (IOException e) { log.error("Error updating {} in {}: {}", request.id(), request.index(), e.getMessage(), e); // Rethrow unchecked exception on update IOException diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index b7fe5841a4..5a393dbc14 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -9,6 +9,7 @@ package org.opensearch.ml.sdkclient; +import static org.mockito.Mockito.when; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; @@ -30,6 +31,7 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; @@ -41,6 +43,7 @@ import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; @@ -65,6 +68,7 @@ import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; @@ -76,6 +80,9 @@ public class DDBOpenSearchClientTests extends OpenSearchTestCase { + private static final String HASH_KEY = "_tenant_id"; + private static final String RANGE_KEY = "_id"; + private static final String TEST_ID = "123"; private static final String TENANT_ID = "TEST_TENANT_ID"; private static final String TEST_INDEX = "test_index"; @@ -146,8 +153,8 @@ public void testPutDataObject_HappyCase() throws IOException { PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); Assert.assertEquals(TEST_INDEX, putItemRequest.tableName()); - Assert.assertEquals(TEST_ID, putItemRequest.item().get("id").s()); - Assert.assertEquals(TENANT_ID, putItemRequest.item().get("tenant_id").s()); + Assert.assertEquals(TEST_ID, putItemRequest.item().get(RANGE_KEY).s()); + Assert.assertEquals(TENANT_ID, putItemRequest.item().get(HASH_KEY).s()); Assert.assertEquals("foo", putItemRequest.item().get("data").s()); } @@ -192,7 +199,7 @@ public void testPutDataObject_NullTenantId_SetsDefaultTenantId() throws IOExcept Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); - Assert.assertEquals("DEFAULT_TENANT", putItemRequest.item().get("tenant_id").s()); + Assert.assertEquals("DEFAULT_TENANT", putItemRequest.item().get(HASH_KEY).s()); } @Test @@ -206,7 +213,7 @@ public void testPutDataObject_NullId_SetsDefaultTenantId() throws IOException { Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); - Assert.assertNotNull(putItemRequest.item().get("id").s()); + Assert.assertNotNull(putItemRequest.item().get(RANGE_KEY).s()); Assert.assertNotNull(response.id()); } @@ -237,8 +244,8 @@ public void testGetDataObject_HappyCase() throws IOException { Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); Assert.assertEquals(TEST_INDEX, getItemRequest.tableName()); - Assert.assertEquals(TENANT_ID, getItemRequest.key().get("tenant_id").s()); - Assert.assertEquals(TEST_ID, getItemRequest.key().get("id").s()); + Assert.assertEquals(TENANT_ID, getItemRequest.key().get(HASH_KEY).s()); + Assert.assertEquals(TEST_ID, getItemRequest.key().get(RANGE_KEY).s()); Assert.assertEquals(TEST_ID, response.id()); Assert.assertEquals("foo", response.source().get("data")); XContentParser parser = response.parser(); @@ -319,7 +326,7 @@ public void testGetDataObject_UseDefaultTenantIdIfNull() throws IOException { sdkClient.getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join(); Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); - Assert.assertEquals("DEFAULT_TENANT", getItemRequest.key().get("tenant_id").s()); + Assert.assertEquals("DEFAULT_TENANT", getItemRequest.key().get(HASH_KEY).s()); } @Test @@ -343,8 +350,8 @@ public void testDeleteDataObject_HappyCase() throws IOException { .join(); DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); Assert.assertEquals(TEST_INDEX, deleteItemRequest.tableName()); - Assert.assertEquals(TENANT_ID, deleteItemRequest.key().get("tenant_id").s()); - Assert.assertEquals(TEST_ID, deleteItemRequest.key().get("id").s()); + Assert.assertEquals(TENANT_ID, deleteItemRequest.key().get(HASH_KEY).s()); + Assert.assertEquals(TEST_ID, deleteItemRequest.key().get(RANGE_KEY).s()); Assert.assertEquals(TEST_ID, deleteResponse.id()); DeleteResponse deleteActionResponse = DeleteResponse.fromXContent(deleteResponse.parser()); @@ -361,7 +368,7 @@ public void testDeleteDataObject_NullTenantId_UsesDefaultTenantId() { Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())).thenReturn(DeleteItemResponse.builder().build()); sdkClient.deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join(); DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); - Assert.assertEquals("DEFAULT_TENANT", deleteItemRequest.key().get("tenant_id").s()); + Assert.assertEquals("DEFAULT_TENANT", deleteItemRequest.key().get(HASH_KEY).s()); } @Test @@ -382,8 +389,8 @@ public void updateDataObjectAsync_HappyCase() { UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue(); assertEquals(TEST_ID, updateRequest.id()); assertEquals(TEST_INDEX, updateItemRequest.tableName()); - assertEquals(TEST_ID, updateItemRequest.key().get("id").s()); - assertEquals(TENANT_ID, updateItemRequest.key().get("tenant_id").s()); + assertEquals(TEST_ID, updateItemRequest.key().get(RANGE_KEY).s()); + assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s()); assertEquals("foo", updateItemRequest.attributeUpdates().get("data").value().s()); } @@ -406,10 +413,9 @@ public void updateDataObjectAsync_HappyCaseWithMap() { UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue(); assertEquals(TEST_ID, updateRequest.id()); assertEquals(TEST_INDEX, updateItemRequest.tableName()); - assertEquals(TEST_ID, updateItemRequest.key().get("id").s()); - assertEquals(TENANT_ID, updateItemRequest.key().get("tenant_id").s()); + assertEquals(TEST_ID, updateItemRequest.key().get(RANGE_KEY).s()); + assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s()); assertEquals("bar", updateItemRequest.attributeUpdates().get("foo").value().s()); - } @Test @@ -424,7 +430,30 @@ public void updateDataObjectAsync_NullTenantId_UsesDefaultTenantId() { Mockito.when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture())).thenReturn(UpdateItemResponse.builder().build()); sdkClient.updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join(); UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue(); - assertEquals(TENANT_ID, updateItemRequest.key().get("tenant_id").s()); + assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s()); + } + + public void testUpdateDataObject_VersionCheck() throws IOException { + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .ifSeqNo(5) + .ifPrimaryTerm(2) + .build(); + + ConditionalCheckFailedException conflictException = ConditionalCheckFailedException.builder().build(); + when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture())).thenThrow(conflictException); + + CompletableFuture future = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + Throwable cause = ce.getCause(); + assertEquals(OpenSearchStatusException.class, cause.getClass()); + assertEquals(RestStatus.CONFLICT, ((OpenSearchStatusException) cause).status()); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java index 9d28e61650..9175cb5462 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java @@ -62,6 +62,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.index.get.GetResult; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; @@ -428,6 +429,34 @@ public void testUpdateDataObject_Exception() throws IOException { assertEquals("test", cause.getMessage()); } + public void testUpdateDataObject_VersionCheck() throws IOException { + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .ifSeqNo(5) + .ifPrimaryTerm(2) + .build(); + + ArgumentCaptor updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + VersionConflictEngineException conflictException = new VersionConflictEngineException( + new ShardId(TEST_INDEX, "_na_", 0), + TEST_ID, + "test" + ); + when(mockedClient.update(updateRequestCaptor.capture())).thenThrow(conflictException); + + CompletableFuture future = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + Throwable cause = ce.getCause(); + assertEquals(OpenSearchStatusException.class, cause.getClass()); + assertEquals(RestStatus.CONFLICT, ((OpenSearchStatusException) cause).status()); + } + public void testDeleteDataObject() throws IOException { DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder().index(TEST_INDEX).id(TEST_ID).build(); diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java index 7dbe7062d6..2557dfd16f 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java @@ -32,6 +32,9 @@ import org.opensearch.action.DocWriteResponse; import org.opensearch.client.json.jackson.JacksonJsonpMapper; import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.opensearch._types.ErrorCause; +import org.opensearch.client.opensearch._types.ErrorResponse; +import org.opensearch.client.opensearch._types.OpenSearchException; import org.opensearch.client.opensearch._types.Result; import org.opensearch.client.opensearch._types.ShardStatistics; import org.opensearch.client.opensearch.core.DeleteRequest; @@ -54,6 +57,7 @@ import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; @@ -403,6 +407,35 @@ public void testtUpdateDataObject_Exception() throws IOException { assertEquals(OpenSearchStatusException.class, ce.getCause().getClass()); } + public void testUpdateDataObject_VersionCheck() throws IOException { + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .ifSeqNo(5) + .ifPrimaryTerm(2) + .build(); + + ArgumentCaptor> updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + OpenSearchException conflictException = new OpenSearchException( + new ErrorResponse.Builder() + .status(RestStatus.CONFLICT.getStatus()) + .error(new ErrorCause.Builder().type("test").reason("test").build()) + .build() + ); + when(mockedOpenSearchClient.update(updateRequestCaptor.capture(), any())).thenThrow(conflictException); + + CompletableFuture future = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + Throwable cause = ce.getCause(); + assertEquals(OpenSearchStatusException.class, cause.getClass()); + assertEquals(RestStatus.CONFLICT, ((OpenSearchStatusException) cause).status()); + } + public void testDeleteDataObject() throws IOException { DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder().index(TEST_INDEX).id(TEST_ID).build();