Skip to content

Commit

Permalink
Adding ElasticsearchInternalServiceModelValidator to stop model deplo…
Browse files Browse the repository at this point in the history
…yment on failed validation
  • Loading branch information
dan-rubinstein committed Nov 12, 2024
1 parent f80f054 commit 24619f9
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ void chunkedInfer(
/**
* Stop the model deployment.
* The default action does nothing except acknowledge the request (true).
* @param unparsedModel The unparsed model configuration
* @param model The model configuration
* @param listener The listener
*/
default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
default void stop(Model model, ActionListener<Boolean> listener) {
listener.onResponse(true);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ private void doExecuteForked(

var service = serviceRegistry.getService(unparsedModel.service());
if (service.isPresent()) {
service.get().stop(unparsedModel, listener);
var model = service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
service.get().stop(model, listener);
} else {
listener.onFailure(
new ElasticsearchStatusException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
Expand Down Expand Up @@ -119,9 +118,7 @@ public void start(Model model, ActionListener<Boolean> finalListener) {
}

@Override
public void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {

var model = parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
public void stop(Model model, ActionListener<Boolean> listener) {
if (model instanceof ElasticsearchInternalModel esModel) {

var serviceSettings = esModel.getServiceSettings();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M

@Override
public void checkModelConfig(Model model, ActionListener<Model> listener) {
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
ModelValidatorBuilder.buildModelValidator(model.getTaskType(), true).validate(this, model, listener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.validation;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.Model;

/**
 * A {@link ModelValidator} that wraps another validator and cleans up after a failed validation.
 * <p>
 * Validation itself is forwarded unchanged to the wrapped validator. If the wrapped validator
 * reports a failure, this class asks the service to stop the model and then reports the original
 * validation failure to the caller. A successful validation is passed through untouched.
 */
public class ElasticsearchInternalServiceModelValidator implements ModelValidator {

    /** The validator that performs the actual checks; assigned once at construction. */
    private final ModelValidator modelValidator;

    /**
     * @param modelValidator the validator to delegate to
     */
    public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator) {
        this.modelValidator = modelValidator;
    }

    /**
     * Validates the model with the wrapped validator, stopping the model if validation fails.
     *
     * @param service  the service that owns the model and is asked to stop it on failure
     * @param model    the model to validate
     * @param listener notified with the validated model on success, or with the original
     *                 validation failure once the stop request has completed
     */
    @Override
    public void validate(InferenceService service, Model model, ActionListener<Model> listener) {
        modelValidator.validate(
            service,
            model,
            listener.delegateResponse((delegate, validationFailure) -> stopThenFail(service, model, delegate, validationFailure))
        );
    }

    /**
     * Requests that the service stop the model, then fails the listener with the validation failure
     * regardless of whether the stop request succeeded.
     */
    private static void stopThenFail(
        InferenceService service,
        Model model,
        ActionListener<Model> listener,
        Exception validationFailure
    ) {
        service.stop(model, ActionListener.wrap(ignored -> listener.onFailure(validationFailure), stopFailure -> {
            // The caller cares about why validation failed, so that stays the primary error. The stop
            // failure is attached rather than dropped so it remains visible in the stack trace.
            // Throwable#addSuppressed rejects self-suppression, hence the identity guard.
            if (stopFailure != validationFailure) {
                validationFailure.addSuppressed(stopFailure);
            }
            listener.onFailure(validationFailure);
        }));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@
import org.elasticsearch.inference.TaskType;

public class ModelValidatorBuilder {
/**
 * Builds the validator for the given task type, optionally wrapping it for the Elasticsearch
 * internal service.
 *
 * @param taskType                       the task type passed to {@link #buildModelValidator(TaskType)}
 * @param isElasticsearchInternalService when {@code true}, the task-type validator is wrapped in an
 *                                       {@link ElasticsearchInternalServiceModelValidator}; when
 *                                       {@code false}, it is returned as is
 * @return the validator to use for the model
 */
public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) {
    final ModelValidator baseValidator = buildModelValidator(taskType);
    return isElasticsearchInternalService ? new ElasticsearchInternalServiceModelValidator(baseValidator) : baseValidator;
}

public static ModelValidator buildModelValidator(TaskType taskType) {
if (taskType == null) {
throw new IllegalArgumentException("Task type can't be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1463,7 +1463,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() {
);

var request = (InferModelAction.Request) invocationOnMock.getArguments()[1];
assertThat(request.getId(), is("custom-model"));
assertThat(request.getId(), is(randomInferenceEntityId));
return Void.TYPE;
}).when(client).execute(eq(InferModelAction.INSTANCE), any(), any());
when(client.threadPool()).thenReturn(threadPool);
Expand Down

0 comments on commit 24619f9

Please sign in to comment.