Skip to content

Commit

Permalink
adding addiional params findingIds, startTime and endTime as findings…
Browse files Browse the repository at this point in the history
… API enhancement

Signed-off-by: Riya Saxena <[email protected]>
  • Loading branch information
riysaxen-amzn committed Feb 29, 2024
1 parent 0eb7dce commit 6b2a15c
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package org.opensearch.securityanalytics.action;

import java.io.IOException;
import java.time.Instant;
import java.util.List;
import java.util.Locale;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
Expand All @@ -18,6 +20,9 @@

public class GetFindingsRequest extends ActionRequest {

private List<String> findingIds;
private Instant startTime;
private Instant endTime;
private String logType;
private String detectorId;
private Table table;
Expand All @@ -34,16 +39,24 @@ public GetFindingsRequest(StreamInput sin) throws IOException {
this(
sin.readOptionalString(),
sin.readOptionalString(),
Table.readFrom(sin), sin.readOptionalString(), sin.readOptionalString()
Table.readFrom(sin),
sin.readOptionalString(),
sin.readOptionalString(),
sin.readOptionalStringList(),
sin.readOptionalInstant(),
sin.readOptionalInstant()
);
}

public GetFindingsRequest(String detectorId, String logType, Table table, String severity, String detectionType) {
public GetFindingsRequest(String detectorId, String logType, Table table, String severity, String detectionType, List<String> findingIds, Instant startTime, Instant endTime) {
this.detectorId = detectorId;
this.logType = logType;
this.table = table;
this.severity = severity;
this.detectionType = detectionType;
this.findingIds = findingIds;
this.startTime = startTime;
this.endTime = endTime;
}

@Override
Expand All @@ -53,6 +66,10 @@ public ActionRequestValidationException validate() {
validationException = addValidationError(String.format(Locale.getDefault(),
"detector_id is missing"),
validationException);
} else if(startTime != null && endTime != null && startTime.isAfter(endTime)) {
validationException = addValidationError(String.format(Locale.getDefault(),
"startTime should be less than endTime"),
validationException);
}
return validationException;
}
Expand All @@ -64,6 +81,9 @@ public void writeTo(StreamOutput out) throws IOException {
table.writeTo(out);
out.writeOptionalString(severity);
out.writeOptionalString(detectionType);
out.writeOptionalStringCollection(findingIds);
out.writeOptionalInstant(startTime);
out.writeOptionalInstant(endTime);
}

public String getDetectorId() {
Expand All @@ -85,4 +105,16 @@ public String getLogType() {
public Table getTable() {
return table;
}

public List<String> getFindingIds() {
return findingIds;
}

public Instant getStartTime() {
return startTime;
}

public Instant getEndTime() {
return endTime;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.securityanalytics.findings;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -104,6 +105,9 @@ public void onFailure(Exception e) {
table,
null,
null,
null,
null,
null,
getFindingsResponseListener
);
}
Expand All @@ -130,6 +134,9 @@ public void getFindingsByMonitorIds(
Table table,
String severity,
String detectionType,
List<String> findingIds,
Instant startTime,
Instant endTime,
ActionListener<GetFindingsResponse> listener
) {
org.opensearch.commons.alerting.action.GetFindingsRequest req =
Expand All @@ -138,9 +145,8 @@ public void getFindingsByMonitorIds(
table,
null,
findingIndexName,
monitorIds, severity, detectionType
monitorIds, severity, detectionType,findingIds, startTime, endTime
);

AlertingPluginInterface.INSTANCE.getFindings((NodeClient) client, req, new ActionListener<>() {
@Override
public void onResponse(
Expand Down Expand Up @@ -176,6 +182,9 @@ public void getFindings(
Table table,
String severity,
String detectionType,
List<String> findingIds,
Instant startTime,
Instant endTime,
ActionListener<GetFindingsResponse> listener
) {
if (detectors.size() == 0) {
Expand All @@ -202,6 +211,9 @@ public void getFindings(
table,
severity,
detectionType,
findingIds,
startTime,
endTime,
new ActionListener<>() {
@Override
public void onResponse(GetFindingsResponse getFindingsResponse) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
package org.opensearch.securityanalytics.resthandler;

import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.opensearch.client.node.NodeClient;
Expand Down Expand Up @@ -42,6 +45,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
String searchString = request.param("searchString", "");
String severity = request.param("severity", null);
String detectionType = request.param("detectionType", null);
List<String> findingIds = null;
if (request.param("findingIds") != null) {
findingIds = Arrays.asList(request.param("findingIds").split(","));
}
Instant startTime = null;
if (request.param("startTime") != null) {
startTime = Instant.parse(request.param("startTime"));
}
Instant endTime= null;
if (request.param("endTime") != null) {
endTime = Instant.parse(request.param("endTime"));
}

Table table = new Table(
sortOrder,
Expand All @@ -57,7 +72,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
detectorType,
table,
severity,
detectionType
detectionType,
findingIds,
startTime,
endTime
);

return channel -> client.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ public void onResponse(SearchResponse searchResponse) {
findingsRequest.getTable(),
findingsRequest.getSeverity(),
findingsRequest.getDetectionType(),
findingsRequest.getFindingIds(),
findingsRequest.getStartTime(),
findingsRequest.getEndTime(),
findingsResponseActionListener
);
} catch (IOException e) {
Expand Down
132 changes: 130 additions & 2 deletions src/test/java/org/opensearch/securityanalytics/findings/FindingIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.securityanalytics.findings;

import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -701,20 +702,147 @@ public void testGetFindings_bySearchString_success() throws IOException {

client().performRequest(new Request("POST", "_refresh"));

// Call GetFindings API for first detector by severity
// Call GetFindings API for first detector by searchString 'high'
Map<String, String> params = new HashMap<>();
params.put("searchString", "high");
Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
Map<String, Object> getFindingsBody = entityAsMap(getFindingsResponse);
Assert.assertEquals(1, getFindingsBody.get("total_findings"));
// Call GetFindings API for second detector by severity
// Call GetFindings API for second detector by searchString 'critical'
params.clear();
params.put("searchString", "critical");
getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
getFindingsBody = entityAsMap(getFindingsResponse);
Assert.assertEquals(1, getFindingsBody.get("total_findings"));
}

public void testGetFindings_byStartTimeAndEndTime_success() throws IOException {
String index1 = createTestIndex(randomIndex(), windowsIndexMapping());

// Execute CreateMappingsAction to add alias mapping for index
Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI);
// both req params and req body are supported
createMappingRequest.setJsonEntity(
"{ \"index_name\":\"" + index1 + "\"," +
" \"rule_topic\":\"" + randomDetectorType() + "\", " +
" \"partial\":true" +
"}"
);

Response response = client().performRequest(createMappingRequest);
assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode());

// index 2
String index2 = createTestIndex("windows1", windowsIndexMapping());

// Execute CreateMappingsAction to add alias mapping for index
createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI);
// both req params and req body are supported
createMappingRequest.setJsonEntity(
"{ \"index_name\":\"" + index2 + "\"," +
" \"rule_topic\":\"windows\", " +
" \"partial\":true" +
"}"
);

response = client().performRequest(createMappingRequest);
assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode());
// Detector 1 - WINDOWS
String randomDocRuleId = createRule(randomRule());
List<DetectorRule> detectorRules = List.of(new DetectorRule(randomDocRuleId));
DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules,
emptyList());
Detector detector1 = randomDetectorWithTriggers(
getPrePackagedRules("windows"),
List.of(new DetectorTrigger(null, "test-trigger", "1", List.of("windows"), List.of(), List.of(), List.of(), List.of(), List.of())),
"windows",
input
);

Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector1));
Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));

Map<String, Object> responseBody = asMap(createResponse);
String createdId = responseBody.get("_id").toString();

String request = "{\n" +
" \"query\" : {\n" +
" \"match\":{\n" +
" \"_id\": \"" + createdId + "\"\n" +
" }\n" +
" }\n" +
"}";
List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request);
SearchHit hit = hits.get(0);
String monitorId1 = ((List<String>) ((Map<String, Object>) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0);
// Detector 2 - CRITICAL Severity Netflow
String randomDocRuleId2 = createRule(randomRuleWithCriticalSeverity());
List<DetectorRule> detectorRules2 = List.of(new DetectorRule(randomDocRuleId2));
DetectorInput inputNetflow = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules2,
emptyList());
Detector detector2 = randomDetectorWithTriggers(
getPrePackagedRules("windows1"),
List.of(new DetectorTrigger(null, "test-trigger", "0", List.of("windows1"), List.of(), List.of(), List.of(), List.of(), List.of())),
"windows",
inputNetflow
);

createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector2));
Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));

responseBody = asMap(createResponse);
logger.info("Created response 2 : {}", responseBody.toString());

createdId = responseBody.get("_id").toString();

request = "{\n" +
" \"query\" : {\n" +
" \"match\":{\n" +
" \"_id\": \"" + createdId + "\"\n" +
" }\n" +
" }\n" +
"}";
hits = executeSearch(Detector.DETECTORS_INDEX, request);
hit = hits.get(0);
String monitorId2 = ((List<String>) ((Map<String, Object>) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0);

indexDoc(index1, "1", randomDoc());
indexDoc(index2, "2", randomDoc());
Instant startTime1 = Instant.now();
// execute monitor 1
Response executeResponse = executeAlertingMonitor(monitorId1, Collections.emptyMap());
Map<String, Object> executeResults = entityAsMap(executeResponse);
int noOfSigmaRuleMatches = ((List<Map<String, Object>>) ((Map<String, Object>) executeResults.get("input_results")).get("results")).get(0).size();
Assert.assertEquals(1, noOfSigmaRuleMatches);

Instant startTime2 = Instant.now();
// execute monitor 2
executeResponse = executeAlertingMonitor(monitorId2, Collections.emptyMap());
executeResults = entityAsMap(executeResponse);
noOfSigmaRuleMatches = ((List<Map<String, Object>>) ((Map<String, Object>) executeResults.get("input_results")).get("results")).get(0).size();
Assert.assertEquals(1, noOfSigmaRuleMatches);

client().performRequest(new Request("POST", "_refresh"));

// Call GetFindings API for first detector by startTime and endTime
Map<String, String> params = new HashMap<>();
params.put("startTime", String.valueOf(startTime1));
Instant endTime1 = Instant.now();
params.put("endTime", String.valueOf(endTime1));
Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);

Map<String, Object> getFindingsBody = entityAsMap(getFindingsResponse);
Assert.assertEquals(2, getFindingsBody.get("total_findings"));
// Call GetFindings API for second detector by startTime and endTime
params.clear();
params.put("startTime", String.valueOf(startTime2));
Instant endTime2 = Instant.now();
params.put("endTime", String.valueOf(endTime2));
getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
getFindingsBody = entityAsMap(getFindingsResponse);
Assert.assertEquals(1, getFindingsBody.get("total_findings"));
}

public void testGetFindings_rolloverByMaxAge_success() throws IOException, InterruptedException {

updateClusterSetting(FINDING_HISTORY_ROLLOVER_PERIOD.getKey(), "1s");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public void testGetFindings_success() {
ActionListener l = invocation.getArgument(4);
l.onResponse(getFindingsResponse);
return null;
}).when(findingsService).getFindingsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class));
}).when(findingsService).getFindingsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(), any(), any(), any(ActionListener.class));

// Call getFindingsByDetectorId
Table table = new Table(
Expand Down Expand Up @@ -209,7 +209,7 @@ public void testGetFindings_getFindingsByMonitorIdFailure() {
ActionListener l = invocation.getArgument(4);
l.onFailure(new IllegalArgumentException("Error getting findings"));
return null;
}).when(findingsService).getFindingsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class));
}).when(findingsService).getFindingsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(), any(), any(), any(ActionListener.class));

// Call getFindingsByDetectorId
Table table = new Table(
Expand Down

0 comments on commit 6b2a15c

Please sign in to comment.