Skip to content

Commit

Permalink
Bump RCF Version and Fix Default Rules Bug in AnomalyDetector (#1334)
Browse files Browse the repository at this point in the history
* Updated RCF version to the latest release.
* Fixed a bug in AnomalyDetector where default rules were not applied when the user provided an empty ruleset.

Testing:
* Added unit tests to cover the bug fix

Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo authored Oct 11, 2024
1 parent 5704a16 commit 2ab6dc7
Show file tree
Hide file tree
Showing 12 changed files with 437 additions and 26 deletions.
9 changes: 3 additions & 6 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ dependencies {
implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2'
implementation group: 'commons-lang', name: 'commons-lang', version: '2.6'
implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.12.0'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.1.0'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.1.0'
implementation 'software.amazon.randomcutforest:randomcutforest-core:4.1.0'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.2.0'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.2.0'
implementation 'software.amazon.randomcutforest:randomcutforest-core:4.2.0'

// we inherit jackson-core from opensearch core
implementation "com.fasterxml.jackson.core:jackson-databind:2.16.1"
Expand Down Expand Up @@ -700,9 +700,6 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
'org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler',
'org.opensearch.timeseries.transport.SingleStreamResultRequest',
'org.opensearch.timeseries.rest.handler.IndexJobActionHandler.1',
'org.opensearch.timeseries.transport.SuggestConfigParamResponse',
'org.opensearch.timeseries.transport.SuggestConfigParamRequest',
'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap',
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/ad/model/AnomalyDetector.java
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ public AnomalyDetector(

this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name();

this.rules = rules == null ? getDefaultRule() : rules;
this.rules = rules == null || rules.isEmpty() ? getDefaultRule() : rules;
}

/*
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/org/opensearch/timeseries/JobProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void process(Job jobParameter, JobExecutionContext context) {
* @param executionStartTime analysis start time
* @param executionEndTime analysis end time
* @param recorder utility to record job execution result
* @param detector associated detector accessor
* @param config associated config accessor
*/
public void runJob(
Job jobParameter,
Expand All @@ -209,7 +209,7 @@ public void runJob(
Instant executionStartTime,
Instant executionEndTime,
ExecuteResultResponseRecorderType recorder,
Config detector
Config config
) {
String configId = jobParameter.getName();
if (lock == null) {
Expand All @@ -222,7 +222,7 @@ public void runJob(
"Can't run job due to null lock",
false,
recorder,
detector
config
);
return;
}
Expand All @@ -243,7 +243,7 @@ public void runJob(
user,
roles,
recorder,
detector
config
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,18 @@ protected void executeRequest(FeatureRequest coldStartRequest, ActionListener<Vo
);
IntermediateResultType result = modelManager.getResult(currentSample, modelState, modelId, config, taskId);
resultSaver.saveResult(result, config, coldStartRequest, modelId);
}

// only load model to memory for real time analysis that has no task id
if (null == coldStartRequest.getTaskId()) {
boolean hosted = cacheProvider.hostIfPossible(configOptional.get(), modelState);
LOG
.debug(
hosted
? new ParameterizedMessage("Loaded model {}.", modelState.getModelId())
: new ParameterizedMessage("Failed to load model {}.", modelState.getModelId())
);
// only load model to memory for real time analysis that has no task id
if (null == coldStartRequest.getTaskId()) {
boolean hosted = cacheProvider.hostIfPossible(configOptional.get(), modelState);
LOG
.debug(
hosted
? new ParameterizedMessage("Loaded model {}.", modelState.getModelId())
: new ParameterizedMessage("Failed to load model {}.", modelState.getModelId())
);
}
}

} finally {
listener.onResponse(null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ public void stopJob(String configId, TransportService transportService, ActionLi
}));
}

private ActionListener<StopConfigResponse> stopConfigListener(
public ActionListener<StopConfigResponse> stopConfigListener(
String configId,
TransportService transportService,
ActionListener<JobResponse> listener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public void bulk(String resultIndexOrAlias, List<ResultType> results, String con
} catch (Exception e) {
String error = "Failed to bulk index result";
LOG.error(error, e);
listener.onFailure(new TimeSeriesException(error, e));
listener.onFailure(new TimeSeriesException(configId, error, e));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ protected String genDetector(String datasetName, int intervalMinutes, int trainT
if (relative) {
thresholdType1 = "actual_over_expected_ratio";
thresholdType2 = "expected_over_actual_ratio";
value = 0.3;
value = 0.2;
} else {
thresholdType1 = "actual_over_expected_margin";
thresholdType2 = "expected_over_actual_margin";
Expand Down
63 changes: 63 additions & 0 deletions src/test/java/org/opensearch/ad/ml/ADColdStartTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ad.ml;

import java.io.IOException;
import java.util.ArrayList;

import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.timeseries.TestHelpers;

import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;

/**
 * Tests for {@code ADColdStart.applyRule} when a detector is built without usable rules.
 *
 * Both an empty rule list and a {@code null} rule list are expected to end up with the same
 * "ignore near expected" thresholds on the resulting forest: {@code { 0, 0, 0.2, 0.2 }}.
 */
public class ADColdStartTests extends OpenSearchTestCase {
    private int baseDimensions = 1;
    private int shingleSize = 8;
    private int dimensions;

    @Override
    public void setUp() throws Exception {
        super.setUp();
        dimensions = baseDimensions * shingleSize;
    }

    /**
     * An explicitly empty rule list must still result in the 20% rule being applied.
     *
     * @throws IOException if the detector cannot be constructed
     */
    public void testEmptyRule() throws IOException {
        AnomalyDetector emptyRuleDetector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(new ArrayList<>()).build();
        assertDefaultIgnoreThresholds(emptyRuleDetector);
    }

    /**
     * A {@code null} rule list must result in the 20% rule being applied.
     *
     * @throws IOException if the detector cannot be constructed
     */
    public void testNullRule() throws IOException {
        AnomalyDetector nullRuleDetector = TestHelpers.AnomalyDetectorBuilder.newInstance(1).setRules(null).build();
        assertDefaultIgnoreThresholds(nullRuleDetector);
    }

    /**
     * Applies the detector's rules to a fresh forest builder, builds the forest and asserts that
     * its ignore-near-expected array equals {@code { 0, 0, 0.2, 0.2 }}.
     *
     * @param detector detector whose rules are applied to the builder
     */
    private void assertDefaultIgnoreThresholds(AnomalyDetector detector) {
        ThresholdedRandomCutForest.Builder forestBuilder = new ThresholdedRandomCutForest.Builder<>()
            .dimensions(dimensions)
            .shingleSize(shingleSize);
        ADColdStart.applyRule(forestBuilder, detector);

        double[] expected = { 0, 0, 0.2, 0.2 };
        double[] actual = forestBuilder.build().getPredictorCorrector().getIgnoreNearExpected();

        // Small tolerance for the floating-point comparison.
        double tolerance = 1e-6;
        assertArrayEquals("The double arrays are not equal", expected, actual, tolerance);
    }
}
29 changes: 29 additions & 0 deletions src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.apache.lucene.search.TotalHits;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.Version;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.DocWriteResponse;
Expand Down Expand Up @@ -104,6 +105,7 @@
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.index.IndexNotFoundException;
Expand Down Expand Up @@ -136,6 +138,7 @@
import org.opensearch.timeseries.transport.JobResponse;
import org.opensearch.timeseries.transport.StatsNodeResponse;
import org.opensearch.timeseries.transport.StatsNodesResponse;
import org.opensearch.timeseries.transport.StopConfigResponse;
import org.opensearch.timeseries.util.ClientUtil;
import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
import org.opensearch.transport.TransportResponseHandler;
Expand Down Expand Up @@ -1544,4 +1547,30 @@ public void testDeleteTaskDocs() {
verify(adTaskCacheManager, times(1)).addDeletedTask(anyString());
verify(function, times(1)).execute();
}

/**
 * When the stop-config response reports {@code success() == false}, the listener returned by
 * {@code stopConfigListener} is expected to stop the latest realtime task exactly once with
 * {@code TaskState.FAILED} and an {@link OpenSearchStatusException} carrying the message
 * "Failed to delete model" and status {@code INTERNAL_SERVER_ERROR}.
 */
public void testStopConfigListener_onResponse_failure() {
    // Arrange: collaborators and a response that signals failure
    String configId = randomAlphaOfLength(5);
    TransportService transportService = mock(TransportService.class);
    @SuppressWarnings("unchecked")
    ActionListener<JobResponse> jobListener = mock(ActionListener.class);
    StopConfigResponse unsuccessfulResponse = mock(StopConfigResponse.class);
    when(unsuccessfulResponse.success()).thenReturn(false);

    // Act: obtain the listener under test and feed it the failed response
    indexAnomalyDetectorJobActionHandler.stopConfigListener(configId, transportService, jobListener).onResponse(unsuccessfulResponse);

    // Assert: the realtime task is marked FAILED with the expected exception
    ArgumentCaptor<OpenSearchStatusException> failureCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
    verify(adTaskManager, times(1))
        .stopLatestRealtimeTask(eq(configId), eq(TaskState.FAILED), failureCaptor.capture(), eq(transportService), eq(jobListener));

    OpenSearchStatusException failure = failureCaptor.getValue();
    assertEquals("Failed to delete model", failure.getMessage());
    assertEquals(RestStatus.INTERNAL_SERVER_ERROR, failure.status());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.time.Clock;
import java.util.Optional;

import org.opensearch.ResourceAlreadyExistsException;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.bulk.BulkAction;
Expand All @@ -43,11 +44,13 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.TestHelpers;
import org.opensearch.timeseries.common.exception.TimeSeriesException;
import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler;
import org.opensearch.timeseries.util.ClientUtil;
import org.opensearch.timeseries.util.IndexUtils;
Expand Down Expand Up @@ -232,4 +235,127 @@ private AnomalyResult wrongAnomalyResult() {
null
);
}

/**
 * With neither the index nor the alias present and index creation answered with a
 * {@code CreateIndexResponse(true, true, ...)}, {@code bulk} is expected to go on and
 * prepare a bulk request on the client exactly once.
 */
public void testResponseIsAcknowledgedTrue() throws InterruptedException {
    String testIndex = "testIndex";
    AnomalyResult anomalyResult = mock(AnomalyResult.class);

    // Neither the index nor an alias of that name exists yet
    when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false);
    when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false);

    // Index creation completes synchronously by invoking the supplied listener's onResponse
    doAnswer(invocation -> {
        ActionListener<CreateIndexResponse> createIndexListener = invocation.getArgument(1);
        createIndexListener.onResponse(new CreateIndexResponse(true, true, testIndex));
        return null;
    }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any());

    bulkIndexHandler.bulk(testIndex, ImmutableList.of(anomalyResult), configId, listener);

    // The handler must have reached the point of preparing the bulk request
    verify(client, times(1)).prepareBulk();
}

/**
 * With index creation answered by a {@code CreateIndexResponse(false, false, ...)}, the
 * caller's listener is expected to fail once with the message
 * "Creating custom result index with mappings call not acknowledged".
 */
public void testResponseIsAcknowledgedFalse() {
    String testIndex = "testIndex";
    AnomalyResult anomalyResult = mock(AnomalyResult.class);

    when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false);
    when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false);

    // Index creation responds immediately with both flags set to false
    doAnswer(invocation -> {
        ActionListener<CreateIndexResponse> createIndexListener = invocation.getArgument(1);
        createIndexListener.onResponse(new CreateIndexResponse(false, false, testIndex));
        return null;
    }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any());

    bulkIndexHandler.bulk(testIndex, ImmutableList.of(anomalyResult), configId, listener);

    verify(listener, times(1)).onFailure(exceptionCaptor.capture());
    String failureMessage = exceptionCaptor.getValue().getMessage();
    assertEquals("Creating custom result index with mappings call not acknowledged", failureMessage);
}

/**
 * When index creation fails with {@link ResourceAlreadyExistsException} and the subsequent
 * result-index mapping validation reports {@code true}, {@code bulk} is expected to prepare
 * a bulk request on the client exactly once.
 */
public void testResourceAlreadyExistsException() {
    String testIndex = "testIndex";
    AnomalyResult anomalyResult = mock(AnomalyResult.class);

    // First existence check says "missing", the second says the index is there
    when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false, true);
    when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false, false);

    // Index creation fails because the index already exists
    doAnswer(invocation -> {
        ActionListener<CreateIndexResponse> createIndexListener = invocation.getArgument(1);
        createIndexListener.onFailure(new ResourceAlreadyExistsException("index already exists"));
        return null;
    }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any());

    // Mapping validation of the existing index succeeds
    doAnswer(invocation -> {
        ActionListener<Boolean> validationListener = invocation.getArgument(1);
        validationListener.onResponse(true);
        return null;
    }).when(anomalyDetectionIndices).validateResultIndexMapping(eq(testIndex), any());

    bulkIndexHandler.bulk(testIndex, ImmutableList.of(anomalyResult), configId, listener);

    // The handler must have reached the point of preparing the bulk request
    verify(client, times(1)).prepareBulk();
}

/**
 * When index creation fails with an exception other than "already exists" (here an
 * {@link OpenSearchRejectedExecutionException}), that same exception is expected to be
 * passed to the caller's listener exactly once.
 */
public void testOtherException() {
    String testIndex = "testIndex";
    Exception rejection = new OpenSearchRejectedExecutionException("Test exception");
    AnomalyResult anomalyResult = mock(AnomalyResult.class);

    when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenReturn(false);
    when(anomalyDetectionIndices.doesAliasExist(testIndex)).thenReturn(false);

    // Index creation fails immediately with the rejection
    doAnswer(invocation -> {
        ActionListener<CreateIndexResponse> createIndexListener = invocation.getArgument(1);
        createIndexListener.onFailure(rejection);
        return null;
    }).when(anomalyDetectionIndices).initCustomResultIndexDirectly(eq(testIndex), any());

    bulkIndexHandler.bulk(testIndex, ImmutableList.of(anomalyResult), configId, listener);

    verify(listener, times(1)).onFailure(exceptionCaptor.capture());
    assertEquals(rejection, exceptionCaptor.getValue());
}

/**
 * A {@link TimeSeriesException} thrown by the index-existence check is expected to reach the
 * caller's listener unchanged, exactly once.
 */
public void testTimeSeriesExceptionCaughtInBulk() {
    String testIndex = "testIndex";
    TimeSeriesException thrown = new TimeSeriesException("Test TimeSeriesException");
    AnomalyResult anomalyResult = mock(AnomalyResult.class);

    // The existence check blows up with the TimeSeriesException
    when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenThrow(thrown);

    bulkIndexHandler.bulk(testIndex, ImmutableList.of(anomalyResult), configId, listener);

    // The listener must receive the very exception that was thrown
    verify(listener, times(1)).onFailure(exceptionCaptor.capture());
    assertEquals(thrown, exceptionCaptor.getValue());
}

/**
 * A {@link NullPointerException} thrown by the index-existence check is expected to reach the
 * caller's listener once, as a {@link TimeSeriesException} with the message
 * "Failed to bulk index result" and the original exception as its cause.
 */
public void testExceptionCaughtInBulk() {
    String testIndex = "testIndex";
    NullPointerException thrown = new NullPointerException("Test NullPointerException");
    AnomalyResult anomalyResult = mock(AnomalyResult.class);

    // The existence check blows up with the NullPointerException
    when(anomalyDetectionIndices.doesIndexExist(testIndex)).thenThrow(thrown);

    bulkIndexHandler.bulk(testIndex, ImmutableList.of(anomalyResult), configId, listener);

    // The listener must receive a TimeSeriesException that wraps the original exception
    verify(listener, times(1)).onFailure(exceptionCaptor.capture());
    Exception reported = exceptionCaptor.getValue();
    assertTrue(reported instanceof TimeSeriesException);
    assertEquals("Failed to bulk index result", reported.getMessage());
    assertEquals(thrown, reported.getCause());
}
}
Loading

0 comments on commit 2ab6dc7

Please sign in to comment.