diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetFindingsRequest.java b/src/main/java/org/opensearch/securityanalytics/action/GetFindingsRequest.java index e63be0405..eb64ccad1 100644 --- a/src/main/java/org/opensearch/securityanalytics/action/GetFindingsRequest.java +++ b/src/main/java/org/opensearch/securityanalytics/action/GetFindingsRequest.java @@ -5,6 +5,8 @@ package org.opensearch.securityanalytics.action; import java.io.IOException; +import java.time.Instant; +import java.util.List; import java.util.Locale; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -18,6 +20,9 @@ public class GetFindingsRequest extends ActionRequest { + private List findingIds; + private Instant startTime; + private Instant endTime; private String logType; private String detectorId; private Table table; @@ -34,16 +39,24 @@ public GetFindingsRequest(StreamInput sin) throws IOException { this( sin.readOptionalString(), sin.readOptionalString(), - Table.readFrom(sin), sin.readOptionalString(), sin.readOptionalString() + Table.readFrom(sin), + sin.readOptionalString(), + sin.readOptionalString(), + sin.readOptionalStringList(), + sin.readOptionalInstant(), + sin.readOptionalInstant() ); } - public GetFindingsRequest(String detectorId, String logType, Table table, String severity, String detectionType) { + public GetFindingsRequest(String detectorId, String logType, Table table, String severity, String detectionType, List findingIds, Instant startTime, Instant endTime) { this.detectorId = detectorId; this.logType = logType; this.table = table; this.severity = severity; this.detectionType = detectionType; + this.findingIds = findingIds; + this.startTime = startTime; + this.endTime = endTime; } @Override @@ -53,6 +66,10 @@ public ActionRequestValidationException validate() { validationException = addValidationError(String.format(Locale.getDefault(), "detector_id is missing"), validationException); + } else if(startTime != null && endTime != null && startTime.isAfter(endTime)) { + validationException = addValidationError(String.format(Locale.getDefault(), + "startTime should be less than endTime"), + validationException); } return validationException; } @@ -64,6 +81,9 @@ public void writeTo(StreamOutput out) throws IOException { table.writeTo(out); out.writeOptionalString(severity); out.writeOptionalString(detectionType); + out.writeOptionalStringCollection(findingIds); + out.writeOptionalInstant(startTime); + out.writeOptionalInstant(endTime); } public String getDetectorId() { @@ -85,4 +105,16 @@ public String getLogType() { public Table getTable() { return table; } + + public List getFindingIds() { + return findingIds; + } + + public Instant getStartTime() { + return startTime; + } + + public Instant getEndTime() { + return endTime; + } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java b/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java index 17ca6aea1..29fbb8a4f 100644 --- a/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java +++ b/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java @@ -4,6 +4,7 @@ */ package org.opensearch.securityanalytics.findings; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -104,6 +105,9 @@ public void onFailure(Exception e) { table, null, null, + null, + null, + null, getFindingsResponseListener ); } @@ -130,6 +134,9 @@ public void getFindingsByMonitorIds( Table table, String severity, String detectionType, + List findingIds, + Instant startTime, + Instant endTime, ActionListener listener ) { org.opensearch.commons.alerting.action.GetFindingsRequest req = @@ -138,9 +145,8 @@ public void getFindingsByMonitorIds( table, null, findingIndexName, - monitorIds, severity, detectionType + monitorIds, severity, detectionType,findingIds, startTime, endTime ); - AlertingPluginInterface.INSTANCE.getFindings((NodeClient) client, req, new ActionListener<>() { @Override public void onResponse( @@ -176,6 +182,9 @@ public void getFindings( Table table, String severity, String detectionType, + List findingIds, + Instant startTime, + Instant endTime, ActionListener listener ) { if (detectors.size() == 0) { @@ -202,6 +211,9 @@ public void getFindings( table, severity, detectionType, + findingIds, + startTime, + endTime, new ActionListener<>() { @Override public void onResponse(GetFindingsResponse getFindingsResponse) { diff --git a/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetFindingsAction.java b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetFindingsAction.java index d8ca66953..f908caad4 100644 --- a/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetFindingsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetFindingsAction.java @@ -5,6 +5,9 @@ package org.opensearch.securityanalytics.resthandler; import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Locale; import org.opensearch.client.node.NodeClient; @@ -42,6 +45,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli String searchString = request.param("searchString", ""); String severity = request.param("severity", null); String detectionType = request.param("detectionType", null); + List findingIds = null; + if (request.param("findingIds") != null) { + findingIds = Arrays.asList(request.param("findingIds").split(",")); + } + Instant startTime = null; + if (request.param("startTime") != null) { + startTime = Instant.parse(request.param("startTime")); + } + Instant endTime= null; + if (request.param("endTime") != null) { + endTime = Instant.parse(request.param("endTime")); + } Table table = new Table( sortOrder, @@ -57,7 +72,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli detectorType, table, severity, - detectionType + detectionType, + findingIds, + startTime, + endTime ); return channel -> client.execute( diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetFindingsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetFindingsAction.java index e547a0d2b..96e7207a2 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetFindingsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetFindingsAction.java @@ -142,6 +142,9 @@ public void onResponse(SearchResponse searchResponse) { findingsRequest.getTable(), findingsRequest.getSeverity(), findingsRequest.getDetectionType(), + findingsRequest.getFindingIds(), + findingsRequest.getStartTime(), + findingsRequest.getEndTime(), findingsResponseActionListener ); } catch (IOException e) { diff --git a/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java b/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java index 0c89e6a2b..710089268 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java @@ -6,6 +6,7 @@ package org.opensearch.securityanalytics.findings; import java.io.IOException; +import java.time.Instant; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -701,13 +702,13 @@ public void testGetFindings_bySearchString_success() throws IOException { client().performRequest(new Request("POST", "_refresh")); - // Call GetFindings API for first detector by severity + // Call GetFindings API for first detector by searchString 'high' Map params = new HashMap<>(); params.put("searchString", "high"); Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); Map getFindingsBody = entityAsMap(getFindingsResponse); Assert.assertEquals(1, getFindingsBody.get("total_findings")); - // Call GetFindings API for second detector by severity + // Call GetFindings API for second detector by searchString 'critical' params.clear(); params.put("searchString", "critical"); getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); @@ -715,6 +716,133 @@ public void testGetFindings_bySearchString_success() throws IOException { Assert.assertEquals(1, getFindingsBody.get("total_findings")); } + public void testGetFindings_byStartTimeAndEndTime_success() throws IOException { + String index1 = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index1 + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + + // index 2 + String index2 = createTestIndex("windows1", windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index2 + "\"," + + " \"rule_topic\":\"windows\", " + + " \"partial\":true" + + "}" + ); + + response = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + // Detector 1 - WINDOWS + String randomDocRuleId = createRule(randomRule()); + List detectorRules = List.of(new DetectorRule(randomDocRuleId)); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + emptyList()); + Detector detector1 = randomDetectorWithTriggers( + getPrePackagedRules("windows"), + List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("windows"), List.of(), List.of(), List.of(), List.of(), List.of())), + "windows", + input + ); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector1)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + String createdId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + String monitorId1 = ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + // Detector 2 - CRITICAL Severity Netflow + String randomDocRuleId2 = createRule(randomRuleWithCriticalSeverity()); + List detectorRules2 = List.of(new DetectorRule(randomDocRuleId2)); + DetectorInput inputNetflow = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules2, + emptyList()); + Detector detector2 = randomDetectorWithTriggers( + getPrePackagedRules("windows1"), + List.of(new DetectorTrigger(null, "test-trigger", "0", List.of("windows1"), List.of(), List.of(), List.of(), List.of(), List.of())), + "windows", + inputNetflow + ); + + createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector2)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + responseBody = asMap(createResponse); + logger.info("Created response 2 : {}", responseBody.toString()); + + createdId = responseBody.get("_id").toString(); + + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + hits = executeSearch(Detector.DETECTORS_INDEX, request); + hit = hits.get(0); + String monitorId2 = ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + + indexDoc(index1, "1", randomDoc()); + indexDoc(index2, "2", randomDoc()); + Instant startTime1 = Instant.now(); + // execute monitor 1 + Response executeResponse = executeAlertingMonitor(monitorId1, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(1, noOfSigmaRuleMatches); + + Instant startTime2 = Instant.now(); + // execute monitor 2 + executeResponse = executeAlertingMonitor(monitorId2, Collections.emptyMap()); + executeResults = entityAsMap(executeResponse); + noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(1, noOfSigmaRuleMatches); + + client().performRequest(new Request("POST", "_refresh")); + + // Call GetFindings API for first detector by startTime and endTime + Map params = new HashMap<>(); + params.put("startTime", String.valueOf(startTime1)); + Instant endTime1 = Instant.now(); + params.put("endTime", String.valueOf(endTime1)); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + + Map getFindingsBody = entityAsMap(getFindingsResponse); + Assert.assertEquals(2, getFindingsBody.get("total_findings")); + // Call GetFindings API for second detector by startTime and endTime + params.clear(); + params.put("startTime", String.valueOf(startTime2)); + Instant endTime2 = Instant.now(); + params.put("endTime", String.valueOf(endTime2)); + getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + getFindingsBody = entityAsMap(getFindingsResponse); + Assert.assertEquals(1, getFindingsBody.get("total_findings")); + } + public void testGetFindings_rolloverByMaxAge_success() throws IOException, InterruptedException { updateClusterSetting(FINDING_HISTORY_ROLLOVER_PERIOD.getKey(), "1s"); diff --git a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java index 77cccc0ce..57fa793b5 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java @@ -142,7 +142,7 @@ public void testGetFindings_success() { ActionListener l = invocation.getArgument(4); l.onResponse(getFindingsResponse); return null; - }).when(findingsService).getFindingsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class)); + }).when(findingsService).getFindingsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(), any(), any(), any(ActionListener.class)); // Call getFindingsByDetectorId Table table = new Table( @@ -209,7 +209,7 @@ public void testGetFindings_getFindingsByMonitorIdFailure() { ActionListener l = invocation.getArgument(4); l.onFailure(new IllegalArgumentException("Error getting findings")); return null; - }).when(findingsService).getFindingsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class)); + }).when(findingsService).getFindingsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(), any(), any(), any(ActionListener.class)); // Call getFindingsByDetectorId Table table = new Table(