diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java index 30f6352a4..f10527612 100644 --- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java @@ -61,7 +61,6 @@ import org.opensearch.securityanalytics.logtype.LogTypeService; import org.opensearch.securityanalytics.mapper.IndexTemplateManager; import org.opensearch.securityanalytics.mapper.MapperService; -import org.opensearch.securityanalytics.model.CorrelationAlert; import org.opensearch.securityanalytics.model.CustomLogType; import org.opensearch.securityanalytics.model.ThreatIntelFeedData; import org.opensearch.securityanalytics.resthandler.*; @@ -167,13 +166,13 @@ public Collection createComponents(Client client, TIFJobParameterService tifJobParameterService = new TIFJobParameterService(client, clusterService); TIFJobUpdateService tifJobUpdateService = new TIFJobUpdateService(clusterService, tifJobParameterService, threatIntelFeedDataService, builtInTIFMetadataLoader); TIFLockService threatIntelLockService = new TIFLockService(clusterService, client); - + CorrelationAlertService correlationAlertService = new CorrelationAlertService(client, xContentRegistry); TIFJobRunner.getJobRunnerInstance().initialize(clusterService, tifJobUpdateService, tifJobParameterService, threatIntelLockService, threadPool, detectorThreatIntelService); return List.of( detectorIndices, correlationIndices, correlationRuleIndices, ruleTopicIndices, customLogTypeIndices, ruleIndices, mapperService, indexTemplateManager, builtinLogTypeLoader, builtInTIFMetadataLoader, threatIntelFeedDataService, detectorThreatIntelService, - tifJobUpdateService, tifJobParameterService, threatIntelLockService, new CorrelationAlertService(client, clusterService, xContentRegistry)); + tifJobUpdateService, tifJobParameterService, threatIntelLockService, correlationAlertService); } @Override @@ -241,7 +240,6 @@ public ScheduledJobParser getJobParser() { public List getNamedXContent() { return List.of( Detector.XCONTENT_REGISTRY, - CorrelationAlert.XCONTENT_REGISTRY, DetectorInput.XCONTENT_REGISTRY, Rule.XCONTENT_REGISTRY, CustomLogType.XCONTENT_REGISTRY, diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index 84447e7b5..82354a419 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -72,18 +72,24 @@ public class JoinEngine { private final LogTypeService logTypeService; + private final CorrelationAlertService correlationAlertService; + + private volatile TimeValue indexTimeout; + private static final Logger log = LogManager.getLogger(JoinEngine.class); public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry, - long corrTimeWindow, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction, - LogTypeService logTypeService, boolean enableAutoCorrelations) { + long corrTimeWindow, TimeValue indexTimeout, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction, + LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService) { this.client = client; this.request = request; this.xContentRegistry = xContentRegistry; this.corrTimeWindow = corrTimeWindow; + this.indexTimeout = indexTimeout; this.correlateFindingAction = correlateFindingAction; this.logTypeService = logTypeService; this.enableAutoCorrelations = enableAutoCorrelations; + this.correlationAlertService = correlationAlertService; } public void onSearchDetectorResponse(Detector detector, Finding finding) { @@ -544,12 +550,11 @@ private void getCorrelatedFindings(String detectorType, Map ++idx; } - CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(); - correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId()); - log.info("Source correlated findings: {}", request.getFinding().getId()); - log.info("Get correlated findings: {}", correlatedFindings); - log.info("Source correlated findings: {}", request.getFinding().getId()); - log.info("Index correlated findings: {}", idx); + if (!correlatedFindings.isEmpty()) { + CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(client, correlationAlertService); + correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout); + correlationRuleScheduler.shutdown(); + } for (Map.Entry> autoCorrelation: autoCorrelations.entrySet()) { if (correlatedFindings.containsKey(autoCorrelation.getKey())) { diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java index 1294e646c..062bae02c 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java @@ -6,111 +6,132 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; -import org.opensearch.cluster.ClusterState; -import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.commons.alerting.model.Table; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.sort.FieldSortBuilder; -import org.opensearch.search.sort.SortBuilders; -import org.opensearch.search.sort.SortOrder; -import org.opensearch.securityanalytics.model.CorrelationAlert; +import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.securityanalytics.util.CorrelationIndices; import java.io.IOException; +import java.time.Instant; +import java.util.List; import java.util.ArrayList; import java.util.Collections; -import java.util.List; -import java.util.Objects; public class CorrelationAlertService { - public static final String CORRELATION_ALERT_INDEX = ".opensearch-sap-correlations-alerts"; private static final Logger log = LogManager.getLogger(CorrelationAlertService.class); - private final Client client; - private final ClusterService clusterService; + private final NamedXContentRegistry xContentRegistry; + private final Client client; - public CorrelationAlertService(Client client, ClusterService clusterService, NamedXContentRegistry xContentRegistry) { + public CorrelationAlertService(Client client, NamedXContentRegistry xContentRegistry) { this.client = client; - this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; } - public void getCorrelationAlerts(ActionListener listener,Table table, - String severityLevel, - String alertState) { - try { - if (false == correlationAlertsIndexExists()) { - listener.onResponse(new CorrelationAlertsList(Collections.emptyList(), 0)); - } else { - FieldSortBuilder sortBuilder = SortBuilders - .fieldSort(table.getSortString()) - .order(SortOrder.fromString(table.getSortOrder())); - if (null != table.getMissing() && false == table.getMissing().isEmpty()) { - sortBuilder.missing(table.getMissing()); - } - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); - if (false == Objects.equals(severityLevel, "ALL")) { - queryBuilder.filter(QueryBuilders.termQuery("severity", severityLevel)); - } - if (false == Objects.equals(alertState, "ALL")) { - queryBuilder.filter(QueryBuilders.termQuery("state", alertState)); - } - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .version(true) - .seqNoAndPrimaryTerm(true) - .query(queryBuilder) - .sort(sortBuilder) - .size(table.getSize()) - .from(table.getStartIndex()); - - SearchRequest searchRequest = new SearchRequest(CORRELATION_ALERT_INDEX).source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap( searchResponse -> { - if (0 == searchResponse.getHits().getHits().length) { + /** + * Searches for active Alerts in the correlation alerts index within a specified time range. + * + * @param ruleId The correlation rule ID to filter the alerts + * @param currentTime The current time of the search range + * @return The search response containing active alerts + */ + public void getActiveAlerts(String ruleId, long currentTime, ActionListener listener) { + Instant currentTimeDate = Instant.ofEpochMilli(currentTime); + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("correlation_rule_id", ruleId)) + .must(QueryBuilders.rangeQuery("start_time").lte(currentTimeDate)) + .must(QueryBuilders.rangeQuery("end_time").gte(currentTimeDate)) + .must(QueryBuilders.termQuery("state", "ACTIVE")); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .seqNoAndPrimaryTerm(true) + .version(true) + .size(10000) // set the size to 10,000 + .query(queryBuilder); + + SearchRequest searchRequest = new SearchRequest(CorrelationIndices.CORRELATION_ALERT_INDEX) + .source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap( + searchResponse -> { + if (searchResponse.getHits().getTotalHits().equals(0)) { listener.onResponse(new CorrelationAlertsList(Collections.emptyList(), 0)); } else { - listener.onResponse( new CorrelationAlertsList( + listener.onResponse(new CorrelationAlertsList( parseCorrelationAlerts(searchResponse), searchResponse.getHits() != null && searchResponse.getHits().getTotalHits() != null ? - (int) searchResponse.getHits().getTotalHits().value : 0) + (int) searchResponse.getHits().getTotalHits().value : 0) ); } }, - e -> { - log.error("Search request to fetch correlation alerts failed", e); - listener.onFailure(e); - } - )); - } - } catch (Exception e) { - log.error("Unexpected error when fetch correlation alerts", e); - listener.onFailure(e); - } + e -> { + log.error("Search request to fetch correlation alerts failed", e); + listener.onFailure(e); + } + )); } - public boolean correlationAlertsIndexExists() { - ClusterState clusterState = clusterService.state(); - return clusterState.getRoutingTable().hasIndex(CORRELATION_ALERT_INDEX); + + public void indexCorrelationAlert(CorrelationAlert correlationAlert, TimeValue indexTimeout, ActionListener listener) { + // Convert CorrelationAlert to a map + try { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("correlated_finding_ids", correlationAlert.getCorrelatedFindingIds()); + builder.field("correlation_rule_id", correlationAlert.getCorrelationRuleId()); + builder.field("correlation_rule_name", correlationAlert.getCorrelationRuleName()); + builder.field("id", correlationAlert.getId()); + builder.field("user", correlationAlert.getUser()); // Convert User object to map + builder.field("schema_version", correlationAlert.getSchemaVersion()); + builder.field("severity", correlationAlert.getSeverity()); + builder.field("state", correlationAlert.getState()); + builder.field("trigger_name", correlationAlert.getTriggerName()); + builder.field("version", correlationAlert.getVersion()); + builder.field("start_time", correlationAlert.getStartTime()); + builder.field("end_time", correlationAlert.getEndTime()); + builder.field("action_execution_results", correlationAlert.getActionExecutionResults()); + builder.field("error_message", correlationAlert.getErrorMessage()); + builder.field("acknowledged_time", correlationAlert.getAcknowledgedTime()); + builder.endObject(); + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_ALERT_INDEX) + .id(correlationAlert.getId()) + .source(builder) + .timeout(indexTimeout); + + client.index(indexRequest, listener); + } catch (IOException ex) { + log.error("Exception while adding alerts in .opensearch-sap-correlation-alerts index", ex); + } } public List parseCorrelationAlerts(final SearchResponse response) throws IOException { List alerts = new ArrayList<>(); for (SearchHit hit : response.getHits()) { - XContentParser xcp = XContentType.JSON.xContent().createParser(xContentRegistry, - LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString()); - CorrelationAlert correlationAlert = CorrelationAlert.docParse(xcp, hit.getId(), hit.getVersion()); + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + hit.getSourceAsString() + ); + + CorrelationAlert correlationAlert = CorrelationAlert.parse(xcp, hit.getId(), hit.getVersion()); alerts.add(correlationAlert); } return alerts; } + // Helper method to convert User object to map } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java index 88cdd2cc3..a6cdda9a6 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java @@ -4,7 +4,7 @@ */ package org.opensearch.securityanalytics.correlation.alert; -import org.opensearch.securityanalytics.model.CorrelationAlert; +import org.opensearch.commons.alerting.model.CorrelationAlert; import java.util.List; diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java index de4b6d166..b3690afb6 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java @@ -2,22 +2,38 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.alerting.model.Alert; +import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.core.action.ActionListener; import org.opensearch.securityanalytics.model.CorrelationQuery; import org.opensearch.securityanalytics.model.CorrelationRule; import org.opensearch.securityanalytics.model.CorrelationRuleTrigger; import java.time.Instant; -import java.util.*; -import java.util.concurrent.TimeUnit; +import java.util.UUID; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; public class CorrelationRuleScheduler { - private static final Logger log = LogManager.getLogger(CorrelationRuleScheduler.class); + private final Logger log = LogManager.getLogger(CorrelationRuleScheduler.class); + private final Client client; + private final CorrelationAlertService correlationAlertService; + private final ExecutorService executorService; - public void schedule(List correlationRules, Map> correlatedFindings, String sourceFinding) { - // Create a map of correlation rule to list of finding IDs - Map> correlationRuleToFindingIds = new HashMap<>(); + public CorrelationRuleScheduler(Client client, CorrelationAlertService correlationAlertService) { + this.client = client; + this.correlationAlertService = correlationAlertService; + this.executorService = Executors.newCachedThreadPool(); + } + + public void schedule(List correlationRules, Map> correlatedFindings, String sourceFinding, TimeValue indexTimeout) { for (CorrelationRule rule : correlationRules) { CorrelationRuleTrigger trigger = rule.getCorrelationTrigger(); if (trigger != null) { @@ -28,60 +44,121 @@ public void schedule(List correlationRules, Map findingIds) { - Timer timer = new Timer(); + + public void shutdown() { + executorService.shutdown(); + } + + private void scheduleRule(CorrelationRule correlationRule, List findingIds, TimeValue indexTimeout) { long startTime = Instant.now().toEpochMilli(); - long endTime = startTime + TimeUnit.MINUTES.toMillis(correlationRule.getCorrTimeWindow()); // Assuming time window is based on ruleId -// timer.schedule(new RuleTask(this.correlationAlertService, this.notificationService, correlationRule, findingIds, startTime, endTime), 0, 60000); // Check every minute + long endTime = startTime + correlationRule.getCorrTimeWindow(); + executorService.submit(new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, indexTimeout)); } - static class RuleTask extends TimerTask { - private final CorrelationAlertService alertService; - private final NotificationService notificationService; + private class RuleTask implements Runnable { private final CorrelationRule correlationRule; private final long startTime; private final long endTime; private final List correlatedFindingIds; + private final CorrelationAlertService correlationAlertService; + private final TimeValue indexTimeout; - - public RuleTask(CorrelationAlertService alertService, NotificationService notificationService, CorrelationRule correlationRule, List correlatedFindingIds, long startTime, long endTime) { - this.alertService = alertService; - this.notificationService = notificationService; + public RuleTask(CorrelationRule correlationRule, List correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, TimeValue indexTimeout) { + this.correlationRule = correlationRule; + this.correlatedFindingIds = correlatedFindingIds; this.startTime = startTime; this.endTime = endTime; - this.correlatedFindingIds = correlatedFindingIds; - this.correlationRule = correlationRule; + this.correlationAlertService = correlationAlertService; + this.indexTimeout = indexTimeout; } @Override public void run() { long currentTime = Instant.now().toEpochMilli(); -// if (currentTime >= startTime && currentTime <= endTime) { // Within time window -// try { -// List activeAlertIds = alertService.getActiveAlertsList(correlationRule.getId(), startTime, endTime); -// if (activeAlertIds.isEmpty()) { -// Map correlationAlert = Map.of( -// "start_time", startTime, -// "end_time", endTime, -// "correlation_rule_id", correlationRule.getId(), -// "severity", correlationRule.getCorrelationTrigger().getSeverity() -// // add more fields; -// ); -// alertService.indexAlert(correlationAlert); -// //notificationService.sendNotification(alert); -// } else { -// alertService.updateActiveAlerts(activeAlertIds); -// } -// } catch (IOException e) { -// throw new RuntimeException(e); -// } -// } + if (currentTime >= startTime && currentTime <= endTime) { + try { + correlationAlertService.getActiveAlerts(correlationRule.getId(), currentTime, new ActionListener<>() { + @Override + public void onResponse(CorrelationAlertsList correlationAlertsList) { + if (correlationAlertsList.getTotalAlerts() == 0) { + addCorrelationAlertIntoIndex(); + } else { + for (CorrelationAlert correlationAlert: correlationAlertsList.getCorrelationAlertList()) { + updateCorrelationAlert(correlationAlert); + } + } + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to search active correlation alert", e); + } + }); + } catch (Exception e) { + log.error("Failed to fetch active alerts in the time window", e); + } + } + } + + private void addCorrelationAlertIntoIndex() { + CorrelationAlert correlationAlert = new CorrelationAlert( + correlatedFindingIds, + correlationRule.getId(), + correlationRule.getName(), + UUID.randomUUID().toString(), + 1L, + 1, + null, + correlationRule.getCorrelationTrigger().getName(), + Alert.State.ACTIVE, + Instant.ofEpochMilli(startTime), + Instant.ofEpochMilli(endTime), + null, + null, + correlationRule.getCorrelationTrigger().getSeverity(), + new ArrayList<>() + ); + insertCorrelationAlert(correlationAlert); + } + + private void updateCorrelationAlert(CorrelationAlert correlationAlert) { + CorrelationAlert newCorrelationAlert = new CorrelationAlert( + correlatedFindingIds, + correlationAlert.getCorrelationRuleId(), + correlationAlert.getCorrelationRuleName(), + correlationAlert.getId(), + 1L, + 1, + correlationAlert.getUser(), + correlationRule.getCorrelationTrigger().getName(), + Alert.State.ACTIVE, + Instant.ofEpochMilli(startTime), + Instant.ofEpochMilli(endTime), + null, + null, + correlationRule.getCorrelationTrigger().getSeverity(), + new ArrayList<>() + ); + insertCorrelationAlert(newCorrelationAlert); + } + + private void insertCorrelationAlert(CorrelationAlert correlationAlert) { + correlationAlertService.indexCorrelationAlert(correlationAlert, indexTimeout, new ActionListener<>() { + @Override + public void onResponse(IndexResponse indexResponse) { + log.info("Successfully updated the index .opensearch-sap-correlation-alerts: {}", indexResponse); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to index correlation alert", e); + } + }); } } } + diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java new file mode 100644 index 000000000..4ffde3e6c --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java @@ -0,0 +1,36 @@ +package org.opensearch.securityanalytics.correlation.alert.notifications; + +import org.opensearch.securityanalytics.model.CorrelationRule; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public abstract class CorrelationAlertContext { + + private final CorrelationRule correlationRule; + private final List correlatedFindingIds; + protected CorrelationAlertContext(CorrelationRule correlationRule, List correlatedFindingIds) { + this.correlationRule = correlationRule; + this.correlatedFindingIds = correlatedFindingIds; + } + + /** + * Mustache templates need special permissions to reflectively introspect field names. To avoid doing this we + * translate the context to a Map of Strings to primitive types, which can be accessed without reflection. + */ + public Map asTemplateArg() { + Map templateArg = new HashMap<>(); + templateArg.put("correlationRule", correlationRule); + templateArg.put("correlatedFindingIds", correlatedFindingIds); + return templateArg; + } + + public CorrelationRule getCorrelationRule() { + return correlationRule; + } + + public List getCorrelatedFindingIds() { + return correlatedFindingIds; + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java index e19d39941..483ecd522 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java @@ -1,4 +1,88 @@ package org.opensearch.securityanalytics.correlation.alert.notifications; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.notifications.NotificationsPluginInterface; +import org.opensearch.commons.notifications.action.*; +import org.opensearch.commons.notifications.model.ChannelMessage; +import org.opensearch.commons.notifications.model.EventSource; +import org.opensearch.commons.notifications.model.SeverityType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; + +import org.opensearch.script.Script; +import org.opensearch.script.ScriptService; +import org.opensearch.script.TemplateScript; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + public class NotificationService { + + private static final Logger logger = LogManager.getLogger(NotificationService.class); + + private static ScriptService scriptService; + /** + * Extension function for publishing a notification to a channel in the Notification plugin. + */ + public static void sendNotification(NodeClient client, String configId, String severity, List channelIds) throws IOException { + ChannelMessage message = generateMessage(configId); + NotificationsPluginInterface.INSTANCE.sendNotification(client, new EventSource(configId, configId, SeverityType.CRITICAL, channelIds), message, channelIds, new ActionListener() { + @Override + public void onResponse(SendNotificationResponse sendNotificationResponse) { + if(sendNotificationResponse.getStatus() == RestStatus.OK) { + logger.info("Successfully sent a notification, Notification Event: " + sendNotificationResponse.getNotificationEvent()); + } + else { + logger.error("Successfully sent a notification, Notification Event: " + sendNotificationResponse.getNotificationEvent()); + } + + } + @Override + public void onFailure(Exception e) { + logger.error("Failed while sending a notification: " + e.toString()); + new SecurityAnalyticsException("Failed to send notification", RestStatus.INTERNAL_SERVER_ERROR, e); + } + }); + } + public static String compileTemplate(Script template, CorrelationAlertContext ctx) { + TemplateScript.Factory factory = scriptService.compile(template, TemplateScript.CONTEXT); + Map params = new HashMap<>(template.getParams()); + params.put("ctx", ctx.asTemplateArg()); + TemplateScript templateScript = factory.newInstance(params); + return templateScript.execute(); + } + + public static ChannelMessage generateMessage(String configId) { + return new ChannelMessage( + getMessageTextDescription(configId), + getMessageHtmlDescription(configId), + null + ); + } + + public static EventSource generateEventSource(String configId, String severity, List tags) { + return new EventSource( + getMessageTitle(configId), + configId, + SeverityType.INFO, + tags + ); + } + + private static String getMessageTitle(String configId) { + return "Test Message Title-" + configId; // TODO: change as per spec + } + + private static String getMessageTextDescription(String configId) { + return "Test message content body for config id " + configId; // TODO: change as per spec + } + + private static String getMessageHtmlDescription(String configId) { + return "
Test Message

Test Message for config id " + configId + "

"; // TODO: change as per spec + } + } diff --git a/src/main/java/org/opensearch/securityanalytics/model/CorrelationAlert.java b/src/main/java/org/opensearch/securityanalytics/model/CorrelationAlert.java deleted file mode 100644 index 893729cea..000000000 --- a/src/main/java/org/opensearch/securityanalytics/model/CorrelationAlert.java +++ /dev/null @@ -1,359 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.securityanalytics.model; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.commons.alerting.model.ActionExecutionResult; -import org.opensearch.commons.alerting.model.Alert; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.ParseField; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.core.xcontent.XContentParserUtils; - -import java.io.IOException; -import java.time.Instant; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -/** - * Model for docs store in .opensearch-sap-correlation-alerts index. - * Correlation alerts are created when a detector finding triggers correlation - */ -public class CorrelationAlert implements Writeable, ToXContentObject { - - private static final Logger log = LogManager.getLogger(CorrelationAlert.class); - private static final String ID_FIELD = "id"; - private static final String TRIGGER_TIME_FIELD = "trigger_time"; - private static final String ACKNOWLEDGED_TIME_FIELD = "acknowledged_time"; - private static final String ACTION_EXECUTION_RESULTS_FIELD = "action_execution_results"; - private static final String VERSION_FIELD = "version"; - private static final String SCHEMA_VERSION_FIELD = "schema_version"; - private static final String TRIGGER_NAME_FIELD = "trigger_name"; - private static final String ERROR_MESSAGE_FIELD = "error_message"; - private static final String CORRELATED_FINDING_IDS_FIELD = "correlated_finding_ids"; - private static final String CORRELATED_RULE_NAME_FIELD = "correlated_rule_name"; - private static final String CORRELATION_RULE_ID_FIELD = "correlation_rule_id"; - private static final String USER_FIELD = "user"; - private static final String SEVERITY_FIELD = "severity"; - private static final String STATE_FIELD = "state"; - public static final String NO_ID = ""; - public static final Long NO_VERSION = 1L; - public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - CorrelationAlert.class, - new ParseField(ID_FIELD), - xcp -> parse(xcp, null, null) - ); - - private String id; - private Instant triggerTime; - private final Instant acknowledgedTime; - private final List actionExecutionResults; - private Long version; - private final Long schemaVersion; - private final String triggerName; - private final String errorMessage; - private final List correlatedFindingIds; - private final String correlationRuleId; - private final String correlationRuleName; - private final User user; - private final String severity; - private final Alert.State state; - - - public CorrelationAlert(String triggerId, Instant acknowledgedTime, Instant triggerTime, - String correlationRuleId, - List actionExecutionResults, Long version, Long schemaVersion, - String triggerName, String errorMessage, - List correlatedFindingIds, String correlationRuleName, User user, - String severity, Alert.State state) { - this.id = triggerId; - this.acknowledgedTime = acknowledgedTime; - this.actionExecutionResults = actionExecutionResults; - this.version = version; - this.schemaVersion = schemaVersion; - this.triggerName = triggerName; - this.triggerTime = triggerTime; - this.errorMessage = errorMessage; - this.correlatedFindingIds = correlatedFindingIds; - this.correlationRuleName = correlationRuleName; - this.correlationRuleId = correlationRuleId; - this.user = user; - this.severity = severity; - this.state = state; - } - - public CorrelationAlert(StreamInput sin) throws IOException { - this( - sin.readString(), - sin.readOptionalInstant(), - sin.readInstant(), - sin.readString(), - sin.readList(ActionExecutionResult::new), - sin.readLong(), - sin.readLong(), - sin.readString(), - sin.readOptionalString(), - sin.readStringList(), - sin.readString(), - sin.readBoolean() ? new User(sin) : null, - sin.readString(), - sin.readEnum(Alert.State.class) - ); - } - - public static CorrelationAlert docParse(XContentParser xcp, String id, Long version) throws IOException { - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.nextToken(), xcp); - XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, xcp.nextToken(), xcp); - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.nextToken(), xcp); - CorrelationAlert correlationAlert = xcp.namedObject(CorrelationAlert.class, xcp.currentName(), null); - XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, xcp.nextToken(), xcp); - - correlationAlert.setId(id); - correlationAlert.setVersion(version); - return correlationAlert; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(id); - out.writeOptionalInstant(acknowledgedTime); - out.writeInstant(triggerTime); - out.writeString(correlationRuleId); - out.writeCollection(actionExecutionResults); - out.writeLong(version); - out.writeLong(schemaVersion); - out.writeString(triggerName); - out.writeOptionalString(errorMessage); - out.writeStringCollection(correlatedFindingIds); - out.writeString(correlationRuleName); - out.writeBoolean(user != null); - if (user != null) { - user.writeTo(out); - } - out.writeString(severity); - out.writeEnum(state); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return createXContentBuilder(builder, params, true); - } - - public XContentBuilder toXContentWithUser(XContentBuilder builder, Params params) throws IOException { - return createXContentBuilder(builder, params, false); - } - - private XContentBuilder createXContentBuilder(XContentBuilder builder, Params params, boolean secure) throws IOException { - builder.startObject() - .field(ID_FIELD, id) - .field(ACKNOWLEDGED_TIME_FIELD, acknowledgedTime) - .field(TRIGGER_TIME_FIELD, triggerTime) - .field(CORRELATED_FINDING_IDS_FIELD, correlatedFindingIds) - .field(ACTION_EXECUTION_RESULTS_FIELD, actionExecutionResults) - .field(VERSION_FIELD, version) - .field(SCHEMA_VERSION_FIELD, schemaVersion) - .field(TRIGGER_NAME_FIELD, triggerName) - .field(ERROR_MESSAGE_FIELD, errorMessage) - .field(CORRELATION_RULE_ID_FIELD, correlationRuleId) - .field(CORRELATED_RULE_NAME_FIELD, correlationRuleName); - if (!secure) { - if (user == null) { - builder.nullField(USER_FIELD); - } else { - builder.field(USER_FIELD, user); - } - } - builder.field(SEVERITY_FIELD, severity); - builder.field(STATE_FIELD, state); - return builder; - } - - public static CorrelationAlert parse(XContentParser xcp, String id, Long version) throws IOException { - if (id == null) { - id = NO_ID; - } - if (version == null) { - version = NO_VERSION; - } - Instant acknowledgedTime = null; - Instant triggerTime = null; - Instant endTime = null; - List actionExecutionResults = new ArrayList<>(); - Long schemaVersion = NO_VERSION; - String triggerName = ""; - String triggerId = ""; - String errorMessage = ""; - List correlatedFindingIds = new ArrayList<>(); - String correlationRuleName = ""; - String correlationRuleId = ""; - User user = null; - String severity = ""; - Alert.State state = null; - ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); - while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = xcp.currentName(); - xcp.nextToken(); - - switch (fieldName) { - case ID_FIELD: - id = xcp.text(); - break; - case TRIGGER_TIME_FIELD: - if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { - triggerTime = null; - } else if (xcp.currentToken().isValue()) { - triggerTime = Instant.ofEpochMilli(xcp.longValue()); - } else { - XContentParserUtils.throwUnknownToken(xcp.currentToken(), xcp.getTokenLocation()); - triggerTime = null; - } - break; - case ACKNOWLEDGED_TIME_FIELD: - if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { - acknowledgedTime = null; - } else if (xcp.currentToken().isValue()) { - acknowledgedTime = Instant.ofEpochMilli(xcp.longValue()); - } else { - XContentParserUtils.throwUnknownToken(xcp.currentToken(), xcp.getTokenLocation()); - acknowledgedTime = null; - } - break; - case ACTION_EXECUTION_RESULTS_FIELD: - ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); - while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { - actionExecutionResults.add(ActionExecutionResult.parse(xcp)); - } - break; - case VERSION_FIELD: - version = xcp.longValue(); - break; - case SCHEMA_VERSION_FIELD: - schemaVersion = xcp.longValue(); - break; - case TRIGGER_NAME_FIELD: - triggerName = xcp.text(); - break; - case ERROR_MESSAGE_FIELD: - errorMessage = xcp.text(); - break; - case CORRELATED_FINDING_IDS_FIELD: - ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); - while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { - String correlatedFindingId = xcp.text(); - correlatedFindingIds.add(correlatedFindingId); - } - break; - case CORRELATED_RULE_NAME_FIELD: - correlationRuleName = xcp.text(); - break; - case CORRELATION_RULE_ID_FIELD: - correlationRuleId = xcp.text(); - break; - case USER_FIELD: - if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { - user = null; - } else { - user = User.parse(xcp); - } - break; - case SEVERITY_FIELD: - severity = xcp.text(); - break; - case STATE_FIELD: - state = Alert.State.valueOf(xcp.text()); - break; - } - } - return new CorrelationAlert( - id, - acknowledgedTime, - triggerTime, - correlationRuleId, - actionExecutionResults, - version, - schemaVersion, - triggerName, - errorMessage, - correlatedFindingIds, - correlationRuleName, - user, - severity, - state - ); - } - - public String getId() { - return id; - } - - public Instant getTriggerTime() { - return triggerTime; - } - - public Instant getAcknowledgedTime() { - return acknowledgedTime; - } - - public String getCorrelationRuleId() { - return correlationRuleId; - } - - public String getCorrelationRuleName() { - return correlationRuleName; - } - - public List getActionExecutionResults() { - return actionExecutionResults; - } - - public Long getVersion() { - return version; - } - - public Long getSchemaVersion() { - return schemaVersion; - } - - public String getTriggerName() { - return triggerName; - } - - public String getErrorMessage() { - return errorMessage; - } - - public List getCorrelatedFindingIds() { - return correlatedFindingIds; - } - - public User getUser() { - return user; - } - - public String getSeverity() { - return severity; - } - - public Alert.State getState() { - return state; - } - - private void setVersion(Long version) { - this.version = version; - } - - private void setId(String id) { - this.id = id; - } -} diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index 130206e54..3a8f96f97 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -49,6 +49,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.correlation.JoinEngine; import org.opensearch.securityanalytics.correlation.VectorEmbeddingsEngine; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; import org.opensearch.securityanalytics.logtype.LogTypeService; import org.opensearch.securityanalytics.model.CustomLogType; import org.opensearch.securityanalytics.model.Detector; @@ -99,6 +100,8 @@ public class TransportCorrelateFindingAction extends HandledTransportAction { if (createIndexResponse.isAcknowledged()) { - IndexUtils.correlationMetadataIndexUpdated(); + IndexUtils.correlationAlertIndexUpdated(); } else { correlateFindingAction.onFailures(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); } @@ -212,7 +216,7 @@ public class AsyncCorrelateFindingAction { this.response =new AtomicReference<>(); - this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, this, logTypeService, enableAutoCorrelation); + this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, correlationAlertService); this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this); } diff --git a/src/main/resources/mappings/correlation_alert_mapping.json b/src/main/resources/mappings/correlation_alert_mapping.json index 2a7acabda..585a036c6 100644 --- a/src/main/resources/mappings/correlation_alert_mapping.json +++ b/src/main/resources/mappings/correlation_alert_mapping.json @@ -20,9 +20,6 @@ } } }, - "trigger_time": { - "type": "date" - }, "error_message": { "type": "text" },