Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature/multi_tenancy' into mult…
Browse files Browse the repository at this point in the history
…i_tenanct_2
  • Loading branch information
arjunkumargiri committed Jul 10, 2024
2 parents 9159fe9 + a289953 commit 0c5df38
Show file tree
Hide file tree
Showing 13 changed files with 284 additions and 64 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ jobs:
strategy:
matrix:
java: [11, 17, 21]
env:
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true

name: Build and Test MLCommons Plugin on linux
if: github.repository == 'opensearch-project/ml-commons'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;

import java.io.IOException;
import java.util.Map;
Expand All @@ -19,6 +20,8 @@ public class UpdateDataObjectRequest {
private final String index;
private final String id;
private final String tenantId;
private final Long ifSeqNo;
private final Long ifPrimaryTerm;
private final ToXContentObject dataObject;

/**
Expand All @@ -28,12 +31,16 @@ public class UpdateDataObjectRequest {
* @param index the index location to update the object
* @param id the document id
* @param tenantId the tenant id
* @param ifSeqNo the sequence number to match or null if not required
* @param ifPrimaryTerm the primary term to match or null if not required
* @param dataObject the data object
*/
public UpdateDataObjectRequest(String index, String id, String tenantId, ToXContentObject dataObject) {
public UpdateDataObjectRequest(String index, String id, String tenantId, Long ifSeqNo, Long ifPrimaryTerm, ToXContentObject dataObject) {
this.index = index;
this.id = id;
this.tenantId = tenantId;
this.ifSeqNo = ifSeqNo;
this.ifPrimaryTerm = ifPrimaryTerm;
this.dataObject = dataObject;
}

Expand Down Expand Up @@ -61,6 +68,22 @@ public String tenantId() {
return this.tenantId;
}

/**
* Returns the sequence number to match, or null if no match required
* @return the ifSeqNo
*/
public Long ifSeqNo() {
return ifSeqNo;
}

/**
* Returns the primary term to match, or null if no match required
* @return the ifPrimaryTerm
*/
public Long ifPrimaryTerm() {
return ifPrimaryTerm;
}

/**
* Returns the data object
* @return the data object
Expand All @@ -84,6 +107,8 @@ public static class Builder {
private String index = null;
private String id = null;
private String tenantId = null;
private Long ifSeqNo = null;
private Long ifPrimaryTerm = null;
private ToXContentObject dataObject = null;

/**
Expand Down Expand Up @@ -111,7 +136,7 @@ public Builder id(String id) {
return this;
}

/**
/**
* Add a tenant ID to this builder
* @param tenantId the tenant id
* @return the updated builder
Expand All @@ -120,7 +145,35 @@ public Builder tenantId(String tenantId) {
this.tenantId = tenantId;
return this;
}


/**
* Only perform this update request if the document's modification was assigned the given
* sequence number. Must be used in combination with {@link #ifPrimaryTerm(long)}
* <p>
* Sequence number may be represented by a different document versioning key on non-OpenSearch data stores.
*/
public Builder ifSeqNo(long seqNo) {
if (seqNo < 0 && seqNo != UNASSIGNED_SEQ_NO) {
throw new IllegalArgumentException("sequence numbers must be non negative. got [" + seqNo + "].");
}
this.ifSeqNo = seqNo;
return this;
}

/**
* Only performs this update request if the document's last modification was assigned the given
* primary term. Must be used in combination with {@link #ifSeqNo(long)}
* <p>
* Primary term may not be relevant on non-OpenSearch data stores.
*/
public Builder ifPrimaryTerm(long term) {
if (term < 0) {
throw new IllegalArgumentException("primary term must be non negative. got [" + term + "]");
}
this.ifPrimaryTerm = term;
return this;
}

/**
* Add a data object to this builder
* @param dataObject the data object
Expand Down Expand Up @@ -150,7 +203,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
* @return A {@link UpdateDataObjectRequest}
*/
public UpdateDataObjectRequest build() {
return new UpdateDataObjectRequest(this.index, this.id, this.tenantId, this.dataObject);
if ((ifSeqNo == null) != (ifPrimaryTerm == null)) {
throw new IllegalArgumentException("Either ifSeqNo and ifPrimaryTerm must both be null or both must be non-null.");
}
return new UpdateDataObjectRequest(this.index, this.id, this.tenantId, this.ifSeqNo, this.ifPrimaryTerm, this.dataObject);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.junit.Before;
import org.junit.Test;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.core.rest.RestStatus;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,22 @@
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.sdk.UpdateDataObjectRequest.Builder;

import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;

public class UpdateDataObjectRequestTests {

private String testIndex;
private String testId;
private String testTenantId;
private Long testSeqNo;
private Long testPrimaryTerm;
private ToXContentObject testDataObject;
private Map<String, Object> testDataObjectMap;

Expand All @@ -34,6 +39,8 @@ public void setUp() {
testIndex = "test-index";
testId = "test-id";
testTenantId = "test-tenant-id";
testSeqNo = 42L;
testPrimaryTerm = 6L;
testDataObject = mock(ToXContentObject.class);
testDataObjectMap = Map.of("foo", "bar");
}
Expand All @@ -46,6 +53,8 @@ public void testUpdateDataObjectRequest() {
assertEquals(testId, request.id());
assertEquals(testTenantId, request.tenantId());
assertEquals(testDataObject, request.dataObject());
assertNull(request.ifSeqNo());
assertNull(request.ifPrimaryTerm());
}

@Test
Expand All @@ -57,4 +66,26 @@ public void testUpdateDataObjectMapRequest() {
assertEquals(testTenantId, request.tenantId());
assertEquals(testDataObjectMap, XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(XContentType.JSON, request.dataObject()), false));
}

@Test
public void testUpdateDataObjectRequestConcurrency() {
UpdateDataObjectRequest request = UpdateDataObjectRequest.builder().index(testIndex).id(testId).tenantId(testTenantId).dataObject(testDataObject)
.ifSeqNo(testSeqNo).ifPrimaryTerm(testPrimaryTerm).build();

assertEquals(testIndex, request.index());
assertEquals(testId, request.id());
assertEquals(testTenantId, request.tenantId());
assertEquals(testDataObject, request.dataObject());
assertEquals(testSeqNo, request.ifSeqNo());
assertEquals(testPrimaryTerm, request.ifPrimaryTerm());

final Builder badSeqNoBuilder = UpdateDataObjectRequest.builder();
assertThrows(IllegalArgumentException.class, () -> badSeqNoBuilder.ifSeqNo(-99));
final Builder badPrimaryTermBuilder = UpdateDataObjectRequest.builder();
assertThrows(IllegalArgumentException.class, () -> badPrimaryTermBuilder.ifPrimaryTerm(-99));
final Builder onlySeqNoBuilder = UpdateDataObjectRequest.builder().ifSeqNo(testSeqNo);
assertThrows(IllegalArgumentException.class, () -> onlySeqNoBuilder.build());
final Builder onlyPrimaryTermBuilder = UpdateDataObjectRequest.builder().ifPrimaryTerm(testPrimaryTerm);
assertThrows(IllegalArgumentException.class, () -> onlyPrimaryTermBuilder.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@

import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo;
import org.opensearch.core.common.Strings;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.xcontent.XContentParser;

import static org.junit.Assert.assertEquals;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,16 +615,6 @@ private UpdateDataObjectRequest createUpdateModelGroupRequest(
) {
modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion);
modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
/* Old code here. TODO investigate if we need to add seqNo and primaryTerm to data object request
UpdateRequest updateModelGroupRequest = new UpdateRequest();
updateModelGroupRequest
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
.setIfSeqNo(seqNo)
.setIfPrimaryTerm(primaryTerm)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.doc(modelGroupSourceMap);
*/
ToXContentObject dataObject = new ToXContentObject() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
Expand All @@ -635,7 +625,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder.endObject();
}
};
return UpdateDataObjectRequest.builder().index(ML_MODEL_GROUP_INDEX).id(modelGroupId).dataObject(dataObject).build();
return UpdateDataObjectRequest
.builder()
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
.ifSeqNo(seqNo)
.ifPrimaryTerm(primaryTerm)
.dataObject(dataObject)
.build();
}

private Boolean isModelDeployed(MLModelState mlModelState) {
Expand Down
15 changes: 2 additions & 13 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -385,26 +385,15 @@ public void registerMLRemoteModel(
if (getModelGroupResponse.isExists()) {
Map<String, Object> modelGroupSourceMap = getModelGroupResponse.getSourceAsMap();
int updatedVersion = incrementLatestVersion(modelGroupSourceMap);
/* TODO UpdateDataObjectRequest needs to track response seqNo + primary term
UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest(
modelGroupSourceMap,
modelGroupId,
getModelGroupResponse.getSeqNo(),
getModelGroupResponse.getPrimaryTerm(),
updatedVersion
);
*/
modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion);
modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest
.builder()
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
.tenantId(mlRegisterModelInput.getTenantId())
// TODO need to track these for concurrency
// .setIfSeqNo(seqNo)
// .setIfPrimaryTerm(primaryTerm)
// .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.ifSeqNo(getModelGroupResponse.getSeqNo())
.ifPrimaryTerm(getModelGroupResponse.getPrimaryTerm())
.dataObject(modelGroupSourceMap)
.build();
sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((ur, ut) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import software.amazon.awssdk.services.dynamodb.model.AttributeAction;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.AttributeValueUpdate;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
Expand All @@ -74,8 +75,9 @@ public class DDBOpenSearchClient implements SdkClient {
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final String DEFAULT_TENANT = "DEFAULT_TENANT";

private static final String HASH_KEY = "tenant_id";
private static final String RANGE_KEY = "id";
private static final String HASH_KEY = "_tenant_id";
private static final String RANGE_KEY = "_id";
private static final String SEQ_NO_KEY = "_seq_no";

private DynamoDbClient dynamoDbClient;
private RemoteClusterIndicesClient remoteClusterIndicesClient;
Expand Down Expand Up @@ -112,7 +114,11 @@ public CompletionStage<PutDataObjectResponse> putDataObjectAsync(PutDataObjectRe
item.put(RANGE_KEY, AttributeValue.builder().s(id).build());
final PutItemRequest putItemRequest = PutItemRequest.builder().tableName(tableName).item(item).build();

// TODO need to initialize/return SEQ_NO here
// If document doesn't exist, return 0
// If document exists, overwrite and increment and return SEQ_NO
dynamoDbClient.putItem(putItemRequest);
// TODO need to pass seqNo to simulated response
String simulatedIndexResponse = simulateOpenSearchResponse(request.index(), id, source, Map.of("result", "created"));
return PutDataObjectResponse.builder().id(id).parser(createParser(simulatedIndexResponse)).build();
} catch (IOException e) {
Expand All @@ -137,6 +143,7 @@ public CompletionStage<GetDataObjectResponse> getDataObjectAsync(GetDataObjectRe
.ofEntries(
Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()),
Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build())
// TODO need to fetch SEQ_NO_KEY
)
)
.build();
Expand Down Expand Up @@ -198,16 +205,33 @@ public CompletionStage<UpdateDataObjectResponse> updateDataObjectAsync(UpdateDat
Map<String, AttributeValue> updateKey = new HashMap<>();
updateKey.put(HASH_KEY, AttributeValue.builder().s(tenantId).build());
updateKey.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build());
UpdateItemRequest updateItemRequest = UpdateItemRequest
UpdateItemRequest.Builder updateItemRequestBuilder = UpdateItemRequest
.builder()
.tableName(getTableName(request.index()))
.key(updateKey)
.attributeUpdates(updateAttributeValue)
.build();
.attributeUpdates(updateAttributeValue);
if (request.ifSeqNo() != null) {
// Get current document version and put in attribute map. Ignore primary term on DDB.
int currentSeqNo = jsonNode.has(SEQ_NO_KEY) ? jsonNode.get(SEQ_NO_KEY).asInt() : 0;
updateItemRequestBuilder
.conditionExpression("#seqNo = :currentSeqNo")
.expressionAttributeNames(Map.of("#seqNo", SEQ_NO_KEY))
.expressionAttributeValues(
Map.of(":currentSeqNo", AttributeValue.builder().n(Integer.toString(currentSeqNo)).build())
);
}
UpdateItemRequest updateItemRequest = updateItemRequestBuilder.build();
dynamoDbClient.updateItem(updateItemRequest);

// TODO need to pass seqNo to simulated response
String simulatedUpdateResponse = simulateOpenSearchResponse(request.index(), request.id(), source, Map.of("found", true));
return UpdateDataObjectResponse.builder().id(request.id()).parser(createParser(simulatedUpdateResponse)).build();
} catch (ConditionalCheckFailedException ccfe) {
log.error("Document version conflict updating {} in {}: {}", request.id(), request.index(), ccfe.getMessage(), ccfe);
// Rethrow
throw new OpenSearchStatusException(
"Document version conflict updating " + request.id() + " in index " + request.index(),
RestStatus.CONFLICT
);
} catch (IOException e) {
log.error("Error updating {} in {}: {}", request.id(), request.index(), e.getMessage(), e);
// Rethrow unchecked exception on update IOException
Expand Down Expand Up @@ -239,7 +263,12 @@ public CompletionStage<DeleteDataObjectResponse> deleteDataObjectAsync(DeleteDat
.build();
return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction<DeleteDataObjectResponse>) () -> {
try {
// TODO need to return SEQ_NO here
// If document doesn't exist, increment and return highest seq no ever seen, but we would have to track seqNo here
// If document never existed, return -2 (unassigned) for seq no (probably what we have to do here)
// If document exists, increment and return SEQ_NO
dynamoDbClient.deleteItem(deleteItemRequest);
// TODO need to pass seqNo to simulated response
String simulatedDeleteResponse = simulateOpenSearchResponse(
request.index(),
request.id(),
Expand Down
Loading

0 comments on commit 0c5df38

Please sign in to comment.