Skip to content

Commit

Permalink
Add rule validation in AnomalyDetector constructor (#1341)
Browse files Browse the repository at this point in the history
* Add rule validation in AnomalyDetector constructor

This commit introduces rule validation within the AnomalyDetector constructor. Any validation errors are now propagated and displayed on the frontend to ensure immediate feedback.

Testing:
* Verified that validation errors are properly propagated and shown on the frontend.
* Added UTs to cover the new validation logic.

Signed-off-by: Kaituo Li <[email protected]>

* address Amit's comments

Signed-off-by: Kaituo Li <[email protected]>

---------

Signed-off-by: Kaituo Li <[email protected]>
(cherry picked from commit 9cdbcee)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Oct 18, 2024
1 parent d6758c6 commit f9964ac
Show file tree
Hide file tree
Showing 9 changed files with 654 additions and 79 deletions.
3 changes: 0 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -699,9 +699,6 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
'org.opensearch.timeseries.transport.SuggestConfigParamResponse',
'org.opensearch.timeseries.transport.SuggestConfigParamRequest',
'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap',
'org.opensearch.timeseries.transport.ResultBulkTransportAction',
'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler',
'org.opensearch.timeseries.transport.handler.ResultIndexingHandler',
Expand Down
26 changes: 1 addition & 25 deletions src/main/java/org/opensearch/ad/ml/ADModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
Expand All @@ -42,7 +41,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.timeseries.AnalysisModelSize;
import org.opensearch.timeseries.MemoryTracker;
import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
import org.opensearch.timeseries.common.exception.TimeSeriesException;
Expand All @@ -52,7 +50,6 @@
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.SingleStreamModelIdMapper;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.settings.TimeSeriesSettings;
import org.opensearch.timeseries.util.DateUtils;
Expand All @@ -69,9 +66,7 @@
* A facade managing ML operations and models.
*/
public class ADModelManager extends
ModelManager<ThresholdedRandomCutForest, AnomalyResult, ThresholdingResult, ADIndex, ADIndexManagement, ADCheckpointDao, ADCheckpointWriteWorker, ADColdStart>
implements
AnalysisModelSize {
ModelManager<ThresholdedRandomCutForest, AnomalyResult, ThresholdingResult, ADIndex, ADIndexManagement, ADCheckpointDao, ADCheckpointWriteWorker, ADColdStart> {
protected static final String ENTITY_SAMPLE = "sp";
protected static final String ENTITY_RCF = "rcf";
protected static final String ENTITY_THRESHOLD = "th";
Expand Down Expand Up @@ -594,25 +589,6 @@ public List<ThresholdingResult> getPreviewResults(Features features, AnomalyDete
}).collect(Collectors.toList());
}

/**
 * Reports the memory footprint of every model held for one detector: the RCF partition sizes
 * returned by {@code forests}, plus one entry per thresholding model of the detector.
 * Every thresholding model is charged the same amount, {@code memoryTracker.getThresholdModelBytes()}.
 *
 * @param detectorId detector id
 * @return a map of model id to its memory size
 */
@Override
public Map<String, Long> getModelSize(String detectorId) {
    // Start from a copy of the RCF partition sizes, then layer the threshold models on top.
    Map<String, Long> sizeByModelId = new HashMap<>(forests.getModelSize(detectorId));
    for (String modelId : thresholds.keySet()) {
        boolean belongsToDetector = SingleStreamModelIdMapper.getConfigIdForModelId(modelId).equals(detectorId);
        if (belongsToDetector) {
            sizeByModelId.put(modelId, (long) memoryTracker.getThresholdModelBytes());
        }
    }
    return sizeByModelId;
}

/**
* Get a RCF model's total updates.
* @param modelId the RCF model's id
Expand Down
121 changes: 121 additions & 0 deletions src/main/java/org/opensearch/ad/model/AnomalyDetector.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -109,6 +110,7 @@ public Integer getShingleSize(Integer customShingleSize) {
@Deprecated
public static final String DETECTION_DATE_RANGE_FIELD = "detection_date_range";
public static final String RULES_FIELD = "rules";
private static final String SUPPRESSION_RULE_ISSUE_PREFIX = "Suppression Rule Error: ";

protected String detectorType;

Expand Down Expand Up @@ -229,6 +231,8 @@ public AnomalyDetector(
issueType = ValidationIssueType.CATEGORY;
}

validateRules(features, rules);

checkAndThrowValidationErrors(ValidationAspect.DETECTOR);

this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name();
Expand Down Expand Up @@ -720,4 +724,121 @@ private static Boolean onlyParseBooleanValue(XContentParser parser) throws IOExc
}
return null;
}

/**
 * Validates each condition in the list of suppression rules against the list of features.
 * Checks that:
 * - Every rule, its condition list, and every condition are non-null.
 * - The feature name of a condition exists in the list of features.
 * - The related feature is enabled.
 * - For margin/ratio threshold types, the value is not NaN and is positive.
 *
 * On the first violation this method stores a message (prefixed with
 * {@code SUPPRESSION_RULE_ISSUE_PREFIX}) in {@code errorMessage}, sets {@code issueType} to
 * {@code ValidationIssueType.RULE}, and stops. It does not throw; it only records the issue.
 *
 * @param features The list of available features. May be null or empty, which is reported
 *                 as an issue when rules are present.
 * @param rules The list of rules containing conditions to validate. Can be null or empty,
 *              in which case there is nothing to validate.
 */
private void validateRules(List<Feature> features, List<Rule> rules) {
    if (rules == null || rules.isEmpty()) {
        return; // No suppression rules to validate; consider as valid
    }

    if (features == null || features.isEmpty()) {
        // Rules reference features by name, so they cannot be checked without any features.
        recordSuppressionRuleIssue("Features are not defined while suppression rules are provided.");
        return;
    }

    // Feature name -> enabled flag, for quick lookup while walking the conditions.
    Map<String, Boolean> featureEnabledMap = new HashMap<>();
    for (Feature feature : features) {
        if (feature != null && feature.getName() != null) {
            featureEnabledMap.put(feature.getName(), feature.getEnabled());
        }
    }

    for (Rule rule : rules) {
        if (rule == null || rule.getConditions() == null) {
            recordSuppressionRuleIssue("A suppression rule or its conditions are not properly defined.");
            return;
        }

        for (Condition condition : rule.getConditions()) {
            String issue = findConditionIssue(condition, featureEnabledMap);
            if (issue != null) {
                // Only the first problem is reported.
                recordSuppressionRuleIssue(issue);
                return;
            }
        }
    }

    // All checks passed
}

/**
 * Records a suppression-rule validation failure on this detector.
 *
 * @param detail description of the problem, without the common prefix
 */
private void recordSuppressionRuleIssue(String detail) {
    this.errorMessage = SUPPRESSION_RULE_ISSUE_PREFIX + detail;
    this.issueType = ValidationIssueType.RULE;
}

/**
 * Checks a single rule condition.
 *
 * @param condition the condition to check; may be null
 * @param featureEnabledMap feature name to enabled flag for every known feature
 * @return a description of the first problem found, or null when the condition is valid
 */
private static String findConditionIssue(Condition condition, Map<String, Boolean> featureEnabledMap) {
    if (condition == null) {
        return "A condition within a suppression rule is not properly defined.";
    }

    String featureName = condition.getFeatureName();
    if (featureName == null) {
        return "A condition is missing the feature name.";
    }

    if (!featureEnabledMap.containsKey(featureName)) {
        return "Feature \"" + featureName + "\" specified in a suppression rule does not exist.";
    }

    // The map may hold a null flag (HashMap permits null values). Comparing against
    // Boolean.TRUE avoids a NullPointerException from auto-unboxing and treats an
    // unset flag as "not enabled".
    if (!Boolean.TRUE.equals(featureEnabledMap.get(featureName))) {
        return "Feature \"" + featureName + "\" specified in a suppression rule is not enabled.";
    }

    // other threshold types may not have value operand
    if (requiresPositiveValue(condition.getThresholdType())) {
        // Kept boxed so that a missing value is reported instead of failing on unboxing.
        // NOTE(review): whether getValue() can return null is not visible here; the null
        // check is harmless if it returns a primitive.
        Double value = condition.getValue();
        if (value == null || value.isNaN()) {
            return "The threshold value for feature \"" + featureName + "\" is not a valid number.";
        }

        if (value <= 0) {
            return "The threshold value for feature \"" + featureName + "\" must be a positive number.";
        }
    }

    return null;
}

/**
 * Tells whether a threshold type is one of the margin/ratio types whose condition value
 * is validated as a positive number.
 *
 * @param thresholdType the threshold type of a condition; may be null
 * @return true for the four margin/ratio threshold types, false otherwise
 */
private static boolean requiresPositiveValue(ThresholdType thresholdType) {
    return thresholdType == ThresholdType.ACTUAL_OVER_EXPECTED_MARGIN
        || thresholdType == ThresholdType.EXPECTED_OVER_ACTUAL_MARGIN
        || thresholdType == ThresholdType.ACTUAL_OVER_EXPECTED_RATIO
        || thresholdType == ThresholdType.EXPECTED_OVER_ACTUAL_RATIO;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@

package org.opensearch.timeseries.ml;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import org.opensearch.timeseries.MemoryTracker;
Expand Down Expand Up @@ -55,48 +52,4 @@ public ModelState<RCFModelType> put(String key, ModelState<RCFModelType> value)
}
return previousAssociatedState;
}

/**
 * Collects the estimated memory size of each model in this map that belongs to a config.
 * Entries whose state currently holds no model are left out of the result.
 *
 * @param configId config Id
 * @return a map of model id to its memory size
 */
public Map<String, Long> getModelSize(String configId) {
    Map<String, Long> sizeByModelId = new HashMap<>();
    super.forEach((modelId, state) -> {
        // Skip models owned by other configs.
        if (!SingleStreamModelIdMapper.getConfigIdForModelId(modelId).equals(configId)) {
            return;
        }
        state.getModel().ifPresent(model -> sizeByModelId.put(modelId, memoryTracker.estimateTRCFModelSize(model)));
    });
    return sizeByModelId;
}

/**
 * Tells whether this map holds at least one entry whose model id maps to the given config.
 * @param configId Config Id
 * @return `true` if the model exists, `false` otherwise.
 */
public boolean doesModelExist(String configId) {
    for (String modelId : super.keySet()) {
        if (SingleStreamModelIdMapper.getConfigIdForModelId(modelId).equals(configId)) {
            // One match is enough.
            return true;
        }
    }
    return false;
}

/**
 * Stores the given state under the model id when the state carries a model and the
 * memory tracker allows hosting it.
 *
 * @param modelId id to store the state under
 * @param toUpdate state to host; may be null
 * @return true if the state was stored; false if it is null, holds no model,
 *         or hosting is not allowed
 */
public boolean hostIfPossible(String modelId, ModelState<RCFModelType> toUpdate) {
    if (toUpdate == null) {
        return false;
    }
    if (!toUpdate.getModel().isPresent()) {
        return false;
    }
    if (!memoryTracker.isHostingAllowed(modelId, toUpdate.getModel().get())) {
        return false;
    }
    // Calls the superclass put directly, exactly as the original did.
    super.put(modelId, toUpdate);
    return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ public enum ValidationIssueType implements Name {
SUBAGGREGATION(SearchTopForecastResultRequest.SUBAGGREGATIONS_FIELD),
RECENCY_EMPHASIS(Config.RECENCY_EMPHASIS_FIELD),
DESCRIPTION(Config.DESCRIPTION_FIELD),
HISTORY(Config.HISTORY_INTERVAL_FIELD);
HISTORY(Config.HISTORY_INTERVAL_FIELD),
RULE(AnomalyDetector.RULES_FIELD);

private String name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ public class SuggestConfigParamRequest extends ActionRequest {
public SuggestConfigParamRequest(StreamInput in) throws IOException {
super(in);
context = in.readEnum(AnalysisType.class);
if (context.isAD()) {
if (getContext().isAD()) {
config = new AnomalyDetector(in);
} else if (context.isForecast()) {
} else if (getContext().isForecast()) {
config = new Forecaster(in);
} else {
throw new UnsupportedOperationException("This method is not supported");
Expand All @@ -55,7 +55,7 @@ public SuggestConfigParamRequest(AnalysisType context, Config config, String par
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeEnum(context);
out.writeEnum(getContext());
config.writeTo(out);
out.writeString(param);
out.writeTimeValue(requestTimeout);
Expand All @@ -77,4 +77,8 @@ public String getParam() {
/**
 * Returns the request timeout carried by this request.
 *
 * @return the request timeout
 */
public TimeValue getRequestTimeout() {
    return this.requestTimeout;
}

/**
 * Returns the analysis type this request was created or deserialized with.
 *
 * @return the analysis context
 */
public AnalysisType getContext() {
    return this.context;
}
}
Loading

0 comments on commit f9964ac

Please sign in to comment.