fix for detectors with sigma aggregation rules (opensearch-project#1372)
Signed-off-by: Subhobrata Dey <[email protected]>
sbcd90 authored Oct 23, 2024
1 parent b185440 commit 6f543b5
Showing 5 changed files with 166 additions and 8 deletions.
@@ -115,13 +115,22 @@ public void upsertIndexTemplateWithAliasMappings(

upsertComponentTemplateStepListener.whenComplete( acknowledgedResponse -> {

// Find template which matches input index best
// Find the template which matches the input index best. Starts by directly matching the input index and,
// if not found, matches the current write index.
String templateName =
MetadataIndexTemplateService.findV2Template(
state.metadata(),
normalizeIndexName(indexName),
false
);
if (templateName == null) {
templateName =
MetadataIndexTemplateService.findV2Template(
state.metadata(),
normalizeIndexName(cin),
false
);
}

if (templateName == null) {
// If we find conflicting templates(regardless of priority) and that template was created by us,
@@ -181,8 +190,8 @@ public void upsertIndexTemplateWithAliasMappings(
template = state.metadata().templatesV2().get(templateName);
if (template.composedOf().contains(componentName) == false) {
List<String> newComposedOf = new ArrayList<>(template.composedOf());
List<String> indexPatterns = List.of(computeIndexPattern(indexName));
;
List<String> indexPatterns = new ArrayList<>(template.indexPatterns());
indexPatterns.add(computeIndexPattern(indexName));
newComposedOf.add(componentName);

try {
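
In effect, the upsert path now falls back to the alias's current write index when no composable template matches the input index name directly. A minimal sketch of that lookup order, assuming `cin` in the diff resolves to the alias's current write index (it is defined in surrounding code not shown here):

    // Illustrative only; mirrors the fallback introduced above rather than reproducing the class verbatim.
    private String resolveTemplateName(ClusterState state, String indexName, String currentWriteIndex) {
        String templateName = MetadataIndexTemplateService.findV2Template(
                state.metadata(), normalizeIndexName(indexName), false);
        if (templateName == null) {
            // Rollover-backed aliases often match a template only through their concrete write index.
            templateName = MetadataIndexTemplateService.findV2Template(
                    state.metadata(), normalizeIndexName(currentWriteIndex), false);
        }
        return templateName;
    }
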
@@ -864,12 +864,16 @@ private IndexMonitorRequest createDocLevelMonitorMatchAllRequest(
}
}
tags.removeIf(Objects::isNull);

// If queryFieldNames is not passed, the alerting doc-level monitor fetches the entire log doc.
List<String> queryFieldNames = List.of("_id");
DocLevelQuery docLevelQuery = new DocLevelQuery(
monitorName,
monitorName + "doc",
Collections.emptyList(),
actualQuery,
new ArrayList<>(tags)
new ArrayList<>(tags),
queryFieldNames
);
docLevelQueries.add(docLevelQuery);

@@ -1042,6 +1046,8 @@ public void onResponse(GetIndexMappingsResponse getIndexMappingsResponse) {
boolQueryBuilder.must(timeRangeFilter);
searchSourceBuilder.query(boolQueryBuilder);
}
// Query hits are not needed from this part of the query; only the aggregations are evaluated.
searchSourceBuilder.size(0);
List<SearchInput> bucketLevelMonitorInputs = new ArrayList<>();
bucketLevelMonitorInputs.add(new SearchInput(indices, searchSourceBuilder));

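
Taken together, these monitor-request changes reduce what each execution pulls back: the doc-level query now fetches only `_id`, and the bucket-level search input returns aggregation buckets without document hits. A minimal sketch of an aggregation-only search source of this kind, assuming a terms aggregation named `result_agg` over an illustrative group-by field:

    import org.opensearch.index.query.QueryBuilders;
    import org.opensearch.search.aggregations.AggregationBuilders;
    import org.opensearch.search.builder.SearchSourceBuilder;

    SearchSourceBuilder searchSource = new SearchSourceBuilder()
            .query(QueryBuilders.matchAllQuery())   // stand-in for the rule query plus the time-range filter
            .size(0)                                 // no hits; only bucket counts feed the bucket-level trigger
            .aggregation(AggregationBuilders.terms("result_agg").field("api.operation"));
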
@@ -1728,7 +1728,8 @@ protected void createComposableIndexTemplate(String templateName, List<String> i
Collectors.joining(",", "\"", "\"")) +
"]," +
(componentTemplateName == null ? ("\"template\": {\"mappings\": {" + mappings + "}},") : "") +
(componentTemplateName != null ? ("\"composed_of\": [\"" + componentTemplateName + "\"],") : "") +
(componentTemplateName != null ? ("\"composed_of\": [\"" + componentTemplateName + "\"],\"template\": {" +
"\"settings\": {\"index\":{\"mapping\":{\"total_fields\":{\"limit\":\"5000\"}},\"number_of_shards\":\"18\",\"number_of_replicas\":\"1\"}}},") : "") +
"\"priority\":" + priority +
"}";
Response response = makeRequest(
@@ -1809,6 +1810,14 @@ protected void createDatastreamAPI(String datastreamName) throws IOException {
assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode());
}

protected void createIndexAliasApi(String indexAlias, String indexName) throws IOException {
Request request = new Request("POST", "_aliases");
request.setJsonEntity("{\"actions\":[{\"add\":{\"index\":\"" + indexName + "\",\"alias\":\"" + indexAlias + "\", " +
"\"is_write_index\": true}}]}");
Response response = client().performRequest(request);
assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode());
}


protected void deleteDatastreamAPI(String datastreamName) throws IOException {
Request request = new Request("DELETE", "_data_stream/" + datastreamName);
@@ -1851,6 +1860,34 @@ protected void createSampleDatastream(String datastreamName, String mappings, bo
createDatastreamAPI(datastreamName);
}

protected void createSampleIndexTemplate(String indexPattern, String mappings, boolean useComponentTemplate) throws IOException {
String indexName = indexPattern.substring(0, indexPattern.length() - 1);
String componentTemplateMappings = "\"properties\": {" +
" \"netflow.destination_transport_port\":{ \"type\": \"long\" }," +
" \"netflow.destination_ipv4_address\":{ \"type\": \"ip\" }" +
"}";

if (mappings != null) {
componentTemplateMappings = mappings;
}

if (useComponentTemplate) {
// Set up the component template
createComponentTemplateWithMappings(
"my_ds_component_template-" + indexName,
componentTemplateMappings
);
}
createComposableIndexTemplate(
"my_index_template_ds-" + indexName,
List.of(indexPattern),
useComponentTemplate ? "my_ds_component_template-" + indexName : null,
mappings,
false,
2
);
}

protected void restoreAlertsFindingsIMSettings() throws IOException {
updateClusterSetting(ALERT_HISTORY_ROLLOVER_PERIOD.getKey(), "720m");
updateClusterSetting(ALERT_HISTORY_MAX_DOCS.getKey(), "100000");
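
For reference, the call made by the new rollover test resolves the createSampleIndexTemplate parameters roughly as follows (template names are derived by stripping the trailing wildcard from the pattern):

    // createSampleIndexTemplate("cloudtrail*", cloudtrailOcsfMappings(), true) sets up:
    //   - component template "my_ds_component_template-cloudtrail" holding the field mappings
    //   - composable template "my_index_template_ds-cloudtrail" with index pattern "cloudtrail*",
    //     composed of that component template, at priority 2
    createSampleIndexTemplate("cloudtrail*", cloudtrailOcsfMappings(), true);
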
@@ -1159,7 +1159,7 @@ public static String randomCloudtrailAggrRuleWithDotFields() {
" - lambda.amazonaws.com\n" +
" api.operation: \n" +
" - Invoke\n" +
" timeframe: 20m\n" +
" timeframe: 1m\n" +
" tags:\n" +
" - attack.privilege_escalation\n" +
" - attack.t1078";
@@ -2308,7 +2308,7 @@ public static String randomCloudtrailOcsfDoc() {
" },\n" +
" \"status\": \"Success\",\n" +
" \"status_id\": 1,\n" +
" \"time\": 1702952105000,\n" +
" \"time\": " + System.currentTimeMillis() + ",\n" +
" \"type_name\": \"Account Change: Detach Policy\",\n" +
" \"type_uid\": 300108,\n" +
" \"unmapped\": {\n" +
@@ -12,6 +12,7 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.common.settings.Settings;
import org.opensearch.commons.alerting.model.Monitor.MonitorType;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;
@@ -2063,7 +2064,10 @@ public void testCreateDetectorWithCloudtrailAggrRuleWithDotFields() throws IOExc

@SuppressWarnings("unchecked")
public void testCreateDetectorWithCloudtrailAggrRuleWithEcsFields() throws IOException {
String index = createTestIndex("cloudtrail", cloudtrailOcsfMappings());
String index = "cloudtrail";
String indexAlias = "test_alias";

createIndex(index, Settings.EMPTY, cloudtrailOcsfMappings(), "\"" + indexAlias + "\":{\"is_write_index\": true}");

// Execute CreateMappingsAction to add alias mapping for index
Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI);
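
The ECS-fields test now creates its backing index with an inline write alias instead of calling createTestIndex; the fourth argument to createIndex is spliced into the request's aliases object, so the creation body is roughly the following sketch (mappings elided; cloudtrailOcsfMappings() supplies them in the test):

    // Sketch of the index-creation body assembled by the REST test framework for the call above.
    String createIndexBody = "{"
            + "\"mappings\": { /* cloudtrailOcsfMappings() */ },"
            + "\"aliases\": { \"test_alias\": { \"is_write_index\": true } }"
            + "}";
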
@@ -2153,6 +2157,108 @@ public void testCreateDetectorWithCloudtrailAggrRuleWithEcsFields() throws IOExc
assertEquals(1, getFindingsBody.get("total_findings"));
}

@SuppressWarnings("unchecked")
public void testCreateDetectorWithCloudtrailAggrRuleWithRolloverIndexAliases() throws IOException, InterruptedException {
createSampleIndexTemplate("cloudtrail*", cloudtrailOcsfMappings(), true);
String index = createTestIndex("cloudtrail-000001", "");
createIndexAliasApi("ocsf_ct", "cloudtrail-000001");

// Execute CreateMappingsAction to add alias mapping for index
Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI);
// both req params and req body are supported
createMappingRequest.setJsonEntity(
"{\n" +
" \"index_name\": \"ocsf_ct\",\n" +
" \"rule_topic\": \"cloudtrail\",\n" +
" \"partial\": true,\n" +
" \"alias_mappings\": {\n" +
" \"properties\": {\n" +
" \"timestamp\": {\n" +
" \"path\": \"time\",\n" +
" \"type\": \"alias\"\n" +
" }\n" +
" }\n" +
" }\n" +
"}"
);

Response createMappingResponse = client().performRequest(createMappingRequest);

assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode());
indexDoc(index, "0", randomCloudtrailOcsfDoc());

String rule = randomCloudtrailAggrRuleWithDotFields();

Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", "cloudtrail"),
new StringEntity(rule), new BasicHeader("Content-Type", "application/json"));
Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse));
Map<String, Object> responseBody = asMap(createResponse);
String createdId = responseBody.get("_id").toString();

DetectorInput input = new DetectorInput("cloudtrail detector for security analytics", List.of("ocsf_ct"), List.of(new DetectorRule(createdId)),
List.of());
Detector detector = randomDetectorWithInputsAndTriggers(List.of(input),
List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of(), List.of(), List.of())));

createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector));
Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));

responseBody = asMap(createResponse);

createdId = responseBody.get("_id").toString();
int createdVersion = Integer.parseInt(responseBody.get("_version").toString());
Assert.assertNotEquals("response is missing Id", Detector.NO_ID, createdId);
Assert.assertTrue("incorrect version", createdVersion > 0);
Assert.assertEquals("Incorrect Location header", String.format(Locale.getDefault(), "%s/%s", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, createdId), createResponse.getHeader("Location"));
Assert.assertFalse(((Map<String, Object>) responseBody.get("detector")).containsKey("rule_topic_index"));
Assert.assertFalse(((Map<String, Object>) responseBody.get("detector")).containsKey("findings_index"));
Assert.assertFalse(((Map<String, Object>) responseBody.get("detector")).containsKey("alert_index"));

String detectorTypeInResponse = (String) ((Map<String, Object>)responseBody.get("detector")).get("detector_type");
Assert.assertEquals("Detector type incorrect", randomDetectorType().toLowerCase(Locale.ROOT), detectorTypeInResponse);

String request = "{\n" +
" \"query\" : {\n" +
" \"match\":{\n" +
" \"_id\": \"" + createdId + "\"\n" +
" }\n" +
" }\n" +
"}";
List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request);
SearchHit hit = hits.get(0);

String workflowId = ((List<String>) ((Map<String, Object>) hit.getSourceAsMap().get("detector")).get("workflow_ids")).get(0);

indexDoc("ocsf_ct", "1", randomCloudtrailOcsfDoc());
indexDoc("ocsf_ct", "2", randomCloudtrailOcsfDoc());
executeAlertingWorkflow(workflowId, Collections.emptyMap());

Map<String, String> params = new HashMap<>();
params.put("detector_id", createdId);
Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
Map<String, Object> getFindingsBody = entityAsMap(getFindingsResponse);

// Assert findings
assertNotNull(getFindingsBody);
assertEquals(1, getFindingsBody.get("total_findings"));

doRollover("ocsf_ct");
Thread.sleep(90000);

indexDoc("ocsf_ct", "4", randomCloudtrailOcsfDoc());
indexDoc("ocsf_ct", "5", randomCloudtrailOcsfDoc());
executeAlertingWorkflow(workflowId, Collections.emptyMap());

params = new HashMap<>();
params.put("detector_id", createdId);
getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
getFindingsBody = entityAsMap(getFindingsResponse);

// Assert findings
assertNotNull(getFindingsBody);
assertEquals(2, getFindingsBody.get("total_findings"));
}

private static void assertRuleMonitorFinding(Map<String, Object> executeResults, String ruleId, int expectedDocCount, List<String> expectedTriggerResult) {
List<Map<String, Object>> buckets = ((List<Map<String, Object>>) (((Map<String, Object>) ((Map<String, Object>) ((Map<String, Object>) ((List<Object>) ((Map<String, Object>) executeResults.get("input_results")).get("results")).get(0)).get("aggregations")).get("result_agg")).get("buckets")));
Integer docCount = buckets.stream().mapToInt(it -> (Integer) it.get("doc_count")).sum();
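
doRollover is an existing helper from the shared test utilities and is not part of this diff; presumably it issues the standard rollover call against the write alias, roughly:

    // Hypothetical equivalent of doRollover("ocsf_ct"); the real helper lives in the test base class.
    Request rolloverRequest = new Request("POST", "/ocsf_ct/_rollover");
    Response rolloverResponse = client().performRequest(rolloverRequest);
    assertEquals(HttpStatus.SC_OK, rolloverResponse.getStatusLine().getStatusCode());
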
