Skip to content

Commit

Permalink
changes to add start_time and end_time filters to GetAlertsRequest (#…
Browse files Browse the repository at this point in the history
…1039) (#1074)

(cherry picked from commit 20905ce)


ignore flaky tests

Signed-off-by: Subhobrata Dey <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sbcd90 and github-actions[bot] authored Jun 12, 2024
1 parent 0ddc669 commit 41bd0c0
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.securityanalytics.action;

import java.io.IOException;
import java.time.Instant;
import java.util.Locale;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
Expand All @@ -24,29 +25,39 @@ public class GetAlertsRequest extends ActionRequest {
private String severityLevel;
private String alertState;

private Instant startTime;

private Instant endTime;

public static final String DETECTOR_ID = "detector_id";

public GetAlertsRequest(
String detectorId,
String logType,
Table table,
String severityLevel,
String alertState
String alertState,
Instant startTime,
Instant endTime
) {
super();
this.detectorId = detectorId;
this.logType = logType;
this.table = table;
this.severityLevel = severityLevel;
this.alertState = alertState;
this.startTime = startTime;
this.endTime = endTime;
}
public GetAlertsRequest(StreamInput sin) throws IOException {
this(
sin.readOptionalString(),
sin.readOptionalString(),
Table.readFrom(sin),
sin.readString(),
sin.readString()
sin.readString(),
sin.readOptionalInstant(),
sin.readOptionalInstant()
);
}

Expand All @@ -68,6 +79,8 @@ public void writeTo(StreamOutput out) throws IOException {
table.writeTo(out);
out.writeString(severityLevel);
out.writeString(alertState);
out.writeOptionalInstant(startTime);
out.writeOptionalInstant(endTime);
}

public String getDetectorId() {
Expand All @@ -89,4 +102,12 @@ public String getAlertState() {
public String getLogType() {
return logType;
}

public Instant getStartTime() {
return startTime;
}

public Instant getEndTime() {
return endTime;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import org.opensearch.commons.alerting.model.Alert;
import org.opensearch.commons.alerting.model.Table;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.securityanalytics.action.AckAlertsResponse;
import org.opensearch.securityanalytics.action.AlertDto;
import org.opensearch.securityanalytics.action.GetAlertsResponse;
Expand All @@ -29,6 +32,7 @@
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.util.SecurityAnalyticsException;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
Expand Down Expand Up @@ -66,6 +70,8 @@ public void getAlertsByDetectorId(
Table table,
String severityLevel,
String alertState,
Instant startTime,
Instant endTime,
ActionListener<GetAlertsResponse> listener
) {
this.client.execute(GetDetectorAction.INSTANCE, new GetDetectorRequest(detectorId, -3L), new ActionListener<>() {
Expand All @@ -88,6 +94,8 @@ public void onResponse(GetDetectorResponse getDetectorResponse) {
table,
severityLevel,
alertState,
startTime,
endTime,
new ActionListener<>() {
@Override
public void onResponse(GetAlertsResponse getAlertsResponse) {
Expand Down Expand Up @@ -129,9 +137,11 @@ public void getAlertsByMonitorIds(
Table table,
String severityLevel,
String alertState,
Instant startTime,
Instant endTime,
ActionListener<GetAlertsResponse> listener
) {

BoolQueryBuilder boolQueryBuilder = getBoolQueryBuilder(startTime, endTime);
org.opensearch.commons.alerting.action.GetAlertsRequest req =
new org.opensearch.commons.alerting.action.GetAlertsRequest(
table,
Expand All @@ -142,7 +152,7 @@ public void getAlertsByMonitorIds(
monitorIds,
null,
null,
null
boolQueryBuilder
);

AlertingPluginInterface.INSTANCE.getAlerts((NodeClient) client, req, new ActionListener<>() {
Expand Down Expand Up @@ -179,6 +189,8 @@ public void getAlerts(
Table table,
String severityLevel,
String alertState,
Instant startTime,
Instant endTime,
ActionListener<GetAlertsResponse> listener
) {
if (detectors.size() == 0) {
Expand All @@ -205,6 +217,8 @@ public void getAlerts(
table,
severityLevel,
alertState,
startTime,
endTime,
new ActionListener<>() {
@Override
public void onResponse(GetAlertsResponse getAlertsResponse) {
Expand Down Expand Up @@ -247,7 +261,10 @@ private AlertDto mapAlertToAlertDto(Alert alert, String detectorId) {
public void getAlerts(List<String> alertIds,
Detector detector,
Table table,
Instant startTime,
Instant endTime,
ActionListener<org.opensearch.commons.alerting.action.GetAlertsResponse> actionListener) {
BoolQueryBuilder boolQueryBuilder = getBoolQueryBuilder(startTime, endTime);
GetAlertsRequest request = new GetAlertsRequest(
table,
"ALL",
Expand All @@ -257,7 +274,7 @@ public void getAlerts(List<String> alertIds,
null,
null,
alertIds,
null);
boolQueryBuilder);
AlertingPluginInterface.INSTANCE.getAlerts(
(NodeClient) client,
request, actionListener);
Expand Down Expand Up @@ -307,4 +324,17 @@ public void onFailure(Exception e) {
}

}

private static BoolQueryBuilder getBoolQueryBuilder(Instant startTime, Instant endTime) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
if (startTime != null && endTime != null) {
long startTimeMillis = startTime.toEpochMilli();
long endTimeMillis = endTime.toEpochMilli();
QueryBuilder timeRangeQuery = QueryBuilders.rangeQuery("start_time")
.from(startTimeMillis) // Greater than or equal to start time
.to(endTimeMillis); // Less than or equal to end time
boolQueryBuilder.filter(timeRangeQuery);
}
return boolQueryBuilder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package org.opensearch.securityanalytics.resthandler;

import java.io.IOException;
import java.time.DateTimeException;
import java.time.Instant;
import java.util.List;
import java.util.Locale;
import org.opensearch.client.node.NodeClient;
Expand Down Expand Up @@ -45,6 +47,26 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
int startIndex = request.paramAsInt("startIndex", 0);
String searchString = request.param("searchString", "");

Instant startTime = null;
String startTimeParam = request.param("startTime");
if (startTimeParam != null && !startTimeParam.isEmpty()) {
try {
startTime = Instant.ofEpochMilli(Long.parseLong(startTimeParam));
} catch (NumberFormatException | NullPointerException | DateTimeException e) {
startTime = Instant.now();
}
}

Instant endTime = null;
String endTimeParam = request.param("endTime");
if (endTimeParam != null && !endTimeParam.isEmpty()) {
try {
endTime = Instant.ofEpochMilli(Long.parseLong(endTimeParam));
} catch (NumberFormatException | NullPointerException | DateTimeException e) {
endTime = Instant.now();
}
}

Table table = new Table(
sortOrder,
sortString,
Expand All @@ -59,7 +81,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
detectorType,
table,
severityLevel,
alertState
alertState,
startTime,
endTime
);

return channel -> client.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ public void onResponse(GetDetectorResponse getDetectorResponse) {
request.getAlertIds(),
getDetectorResponse.getDetector(),
new Table("asc", "id", null, 10000, 0, null),
null,
null,
getAlertsResponseStepListener
);
getAlertsResponseStepListener.whenComplete(getAlertsResponse -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ protected void doExecute(Task task, GetAlertsRequest request, ActionListener<Get
request.getTable(),
request.getSeverityLevel(),
request.getAlertState(),
request.getStartTime(),
request.getEndTime(),
actionListener
);
} else {
Expand Down Expand Up @@ -135,6 +137,8 @@ public void onResponse(SearchResponse searchResponse) {
request.getTable(),
request.getSeverityLevel(),
request.getAlertState(),
request.getStartTime(),
request.getEndTime(),
actionListener
);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ public void testGetAlerts_success() {
);

doAnswer(invocation -> {
ActionListener l = invocation.getArgument(6);
ActionListener l = invocation.getArgument(8);
l.onResponse(getAlertsResponse);
return null;
}).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class));
}).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(), any(), any(ActionListener.class));

// Call getFindingsByDetectorId
Table table = new Table(
Expand All @@ -205,7 +205,8 @@ public void testGetAlerts_success() {
0,
null
);
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() {
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), null, null,
new ActionListener<>() {
@Override
public void onResponse(GetAlertsResponse getAlertsResponse) {
assertEquals(2, (int)getAlertsResponse.getTotalAlerts());
Expand Down Expand Up @@ -258,10 +259,10 @@ public void testGetFindings_getFindingsByMonitorIdFailures() {
}).when(client).execute(eq(GetDetectorAction.INSTANCE), any(GetDetectorRequest.class), any(ActionListener.class));

doAnswer(invocation -> {
ActionListener l = invocation.getArgument(6);
ActionListener l = invocation.getArgument(8);
l.onFailure(new IllegalArgumentException("Error getting findings"));
return null;
}).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class));
}).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(), any(), any(ActionListener.class));

// Call getFindingsByDetectorId
Table table = new Table(
Expand All @@ -272,7 +273,8 @@ public void testGetFindings_getFindingsByMonitorIdFailures() {
0,
null
);
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() {
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), null, null,
new ActionListener<>() {
@Override
public void onResponse(GetAlertsResponse getAlertsResponse) {
fail("this test should've failed");
Expand Down Expand Up @@ -307,7 +309,8 @@ public void testGetFindings_getDetectorFailure() {
0,
null
);
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() {
alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), null, null,
new ActionListener<>() {
@Override
public void onResponse(GetAlertsResponse getAlertsResponse) {
fail("this test should've failed");
Expand Down
Loading

0 comments on commit 41bd0c0

Please sign in to comment.