From 29ba163a13221c5c5fc896c8ca8211db75e9ba04 Mon Sep 17 00:00:00 2001 From: Riya Saxena Date: Mon, 3 Jun 2024 09:43:53 -0700 Subject: [PATCH] alerts in correlations notification service added Signed-off-by: Riya Saxena --- .../SecurityAnalyticsPlugin.java | 5 +- .../correlation/JoinEngine.java | 7 +- .../alert/CorrelationAlertService.java | 6 +- .../alert/CorrelationRuleScheduler.java | 42 +++++++-- .../CorrelationAlertContext.java | 23 +++-- .../notifications/NotificationService.java | 94 ++++++++++++------- .../model/CorrelationRuleTrigger.java | 44 ++++----- .../TransportCorrelateFindingAction.java | 8 +- 8 files changed, 146 insertions(+), 83 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java index f10527612..f18f75639 100644 --- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java @@ -17,6 +17,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.core.action.ActionResponse; import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; @@ -54,6 +55,7 @@ import org.opensearch.securityanalytics.action.*; import org.opensearch.securityanalytics.correlation.index.codec.CorrelationCodecService; import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; +import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; import org.opensearch.securityanalytics.correlation.index.mapper.CorrelationVectorFieldMapper; import org.opensearch.securityanalytics.correlation.index.query.CorrelationQueryBuilder; import org.opensearch.securityanalytics.indexmanagment.DetectorIndexManagementService; @@ -167,12 +169,13 @@ public Collection createComponents(Client client, TIFJobUpdateService tifJobUpdateService = new TIFJobUpdateService(clusterService, tifJobParameterService, threatIntelFeedDataService, builtInTIFMetadataLoader); TIFLockService threatIntelLockService = new TIFLockService(clusterService, client); CorrelationAlertService correlationAlertService = new CorrelationAlertService(client, xContentRegistry); + NotificationService notificationServiceService = new NotificationService((NodeClient)client, scriptService); TIFJobRunner.getJobRunnerInstance().initialize(clusterService, tifJobUpdateService, tifJobParameterService, threatIntelLockService, threadPool, detectorThreatIntelService); return List.of( detectorIndices, correlationIndices, correlationRuleIndices, ruleTopicIndices, customLogTypeIndices, ruleIndices, mapperService, indexTemplateManager, builtinLogTypeLoader, builtInTIFMetadataLoader, threatIntelFeedDataService, detectorThreatIntelService, - tifJobUpdateService, tifJobParameterService, threatIntelLockService, correlationAlertService); + tifJobUpdateService, tifJobParameterService, threatIntelLockService, correlationAlertService, notificationServiceService); } @Override diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index 82354a419..20a6de813 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -74,13 +74,15 @@ public class JoinEngine { private final CorrelationAlertService correlationAlertService; + private final NotificationService notificationService; + private volatile TimeValue indexTimeout; private static final Logger log = LogManager.getLogger(JoinEngine.class); public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry, long corrTimeWindow, TimeValue indexTimeout, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction, - LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService) { + LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService, NotificationService notificationService) { this.client = client; this.request = request; this.xContentRegistry = xContentRegistry; @@ -90,6 +92,7 @@ public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRe this.logTypeService = logTypeService; this.enableAutoCorrelations = enableAutoCorrelations; this.correlationAlertService = correlationAlertService; + this.notificationService = notificationService; } public void onSearchDetectorResponse(Detector detector, Finding finding) { @@ -551,7 +554,7 @@ private void getCorrelatedFindings(String detectorType, Map } if (!correlatedFindings.isEmpty()) { - CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(client, correlationAlertService); + CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(client, correlationAlertService, notificationService); correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout); correlationRuleScheduler.shutdown(); } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java index 062bae02c..4bc67b72c 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java @@ -25,7 +25,6 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.commons.alerting.model.CorrelationAlert; import org.opensearch.securityanalytics.util.CorrelationIndices; - import java.io.IOException; import java.time.Instant; import java.util.List; @@ -73,7 +72,7 @@ public void getActiveAlerts(String ruleId, long currentTime, ActionListener parseCorrelationAlerts(final SearchResponse respon LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString() ); - + xcp.nextToken(); CorrelationAlert correlationAlert = CorrelationAlert.parse(xcp, hit.getId(), hit.getVersion()); alerts.add(correlationAlert); } return alerts; } - // Helper method to convert User object to map } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java index b3690afb6..bc0ca1809 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java @@ -11,7 +11,12 @@ import org.opensearch.securityanalytics.model.CorrelationQuery; import org.opensearch.securityanalytics.model.CorrelationRule; import org.opensearch.securityanalytics.model.CorrelationRuleTrigger; - +import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; +import org.opensearch.securityanalytics.correlation.alert.notifications.CorrelationAlertContext; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.model.action.Action; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; import java.time.Instant; import java.util.UUID; import java.util.List; @@ -19,17 +24,22 @@ import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import org.opensearch.script.ScriptService; public class CorrelationRuleScheduler { private final Logger log = LogManager.getLogger(CorrelationRuleScheduler.class); private final Client client; private final CorrelationAlertService correlationAlertService; + private final NotificationService notificationService; private final ExecutorService executorService; + private static ScriptService scriptService; - public CorrelationRuleScheduler(Client client, CorrelationAlertService correlationAlertService) { + public CorrelationRuleScheduler(Client client, CorrelationAlertService correlationAlertService, NotificationService notificationService) { this.client = client; + this.scriptService = scriptService; this.correlationAlertService = correlationAlertService; + this.notificationService = notificationService; this.executorService = Executors.newCachedThreadPool(); } @@ -44,7 +54,7 @@ public void schedule(List correlationRules, Map findingIds, TimeValue indexTimeout) { + private void scheduleRule(CorrelationRule correlationRule, List findingIds, TimeValue indexTimeout, String sourceFindingId) { long startTime = Instant.now().toEpochMilli(); long endTime = startTime + correlationRule.getCorrTimeWindow(); - executorService.submit(new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, indexTimeout)); + RuleTask ruleTask = new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, notificationService, indexTimeout, sourceFindingId); + executorService.submit(ruleTask); } private class RuleTask implements Runnable { @@ -65,15 +76,19 @@ private class RuleTask implements Runnable { private final long endTime; private final List correlatedFindingIds; private final CorrelationAlertService correlationAlertService; + private final NotificationService notificationService; private final TimeValue indexTimeout; + private final String sourceFindingId; - public RuleTask(CorrelationRule correlationRule, List correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, TimeValue indexTimeout) { + public RuleTask(CorrelationRule correlationRule, List correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, NotificationService notificationService, TimeValue indexTimeout, String sourceFindingId) { this.correlationRule = correlationRule; this.correlatedFindingIds = correlatedFindingIds; this.startTime = startTime; this.endTime = endTime; this.correlationAlertService = correlationAlertService; + this.notificationService = notificationService; this.indexTimeout = indexTimeout; + this.sourceFindingId = sourceFindingId; } @Override @@ -86,6 +101,19 @@ public void run() { public void onResponse(CorrelationAlertsList correlationAlertsList) { if (correlationAlertsList.getTotalAlerts() == 0) { addCorrelationAlertIntoIndex(); + List actions = correlationRule.getCorrelationTrigger().getActions(); + for (Action action : actions) { + CorrelationAlertContext ctx = new CorrelationAlertContext(correlatedFindingIds, correlationRule.getName(), correlationRule.getCorrTimeWindow(), sourceFindingId); + String transfomedSubject = notificationService.compileTemplate(ctx, action.getSubjectTemplate()); + String transformedMessage = notificationService.compileTemplate(ctx, action.getMessageTemplate()); + try { + notificationService.sendNotification(action.getDestinationId(), correlationRule.getCorrelationTrigger().getSeverity(), transfomedSubject, transformedMessage); + } catch (Exception e) { + log.error("Failed while sending a notification: " + e.toString()); + new SecurityAnalyticsException("Failed to send notification", RestStatus.INTERNAL_SERVER_ERROR, e); + } + + } } else { for (CorrelationAlert correlationAlert: correlationAlertsList.getCorrelationAlertList()) { updateCorrelationAlert(correlationAlert); @@ -96,10 +124,12 @@ public void onResponse(CorrelationAlertsList correlationAlertsList) { @Override public void onFailure(Exception e) { log.error("Failed to search active correlation alert", e); + new SecurityAnalyticsException("Failed to search active correlation alert", RestStatus.INTERNAL_SERVER_ERROR, e); } }); } catch (Exception e) { log.error("Failed to fetch active alerts in the time window", e); + new SecurityAnalyticsException("Failed to get active alerts in the correlationRuletimewindow", RestStatus.INTERNAL_SERVER_ERROR, e); } } } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java index 4ffde3e6c..90b8ded25 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java @@ -6,13 +6,17 @@ import java.util.List; import java.util.Map; -public abstract class CorrelationAlertContext { +public class CorrelationAlertContext { - private final CorrelationRule correlationRule; private final List correlatedFindingIds; - protected CorrelationAlertContext(CorrelationRule correlationRule, List correlatedFindingIds) { - this.correlationRule = correlationRule; + private final String sourceFinding; + private final String correlationRuleName; + private final long timeWindow; + public CorrelationAlertContext(List correlatedFindingIds, String correlationRuleName, long timeWindow, String sourceFinding) { this.correlatedFindingIds = correlatedFindingIds; + this.correlationRuleName = correlationRuleName; + this.timeWindow = timeWindow; + this.sourceFinding = sourceFinding; } /** @@ -21,16 +25,11 @@ protected CorrelationAlertContext(CorrelationRule correlationRule, List */ public Map asTemplateArg() { Map templateArg = new HashMap<>(); - templateArg.put("correlationRule", correlationRule); templateArg.put("correlatedFindingIds", correlatedFindingIds); + templateArg.put("sourceFinding", sourceFinding); + templateArg.put("correlationRuleName", correlationRuleName); + templateArg.put("timeWindow", timeWindow); return templateArg; } - public CorrelationRule getCorrelationRule() { - return correlationRule; - } - - public List getCorrelatedFindingIds() { - return correlatedFindingIds; - } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java index 483ecd522..5782bdf5d 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java @@ -8,29 +8,45 @@ import org.opensearch.commons.notifications.model.ChannelMessage; import org.opensearch.commons.notifications.model.EventSource; import org.opensearch.commons.notifications.model.SeverityType; +import org.opensearch.commons.notifications.model.NotificationConfigInfo; +import org.opensearch.commons.notifications.action.GetNotificationConfigRequest; +import org.opensearch.commons.notifications.action.GetNotificationConfigResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.securityanalytics.util.SecurityAnalyticsException; - -import org.opensearch.script.Script; import org.opensearch.script.ScriptService; -import org.opensearch.script.TemplateScript; import java.io.IOException; import java.util.HashMap; import java.util.List; +import java.util.ArrayList; import java.util.Map; +import java.util.Set; +import java.util.HashSet; +import java.util.Collections; +import org.opensearch.script.Script; +import org.opensearch.script.TemplateScript; +import org.opensearch.commons.notifications.model.SeverityType; public class NotificationService { private static final Logger logger = LogManager.getLogger(NotificationService.class); private static ScriptService scriptService; + private final NodeClient client; + + public NotificationService(NodeClient client, ScriptService scriptService) { + this.client = client; + this.scriptService = scriptService; + } /** * Extension function for publishing a notification to a channel in the Notification plugin. */ - public static void sendNotification(NodeClient client, String configId, String severity, List channelIds) throws IOException { - ChannelMessage message = generateMessage(configId); - NotificationsPluginInterface.INSTANCE.sendNotification(client, new EventSource(configId, configId, SeverityType.CRITICAL, channelIds), message, channelIds, new ActionListener() { + public void sendNotification(String configId, String severity, String subject, String notificationMessageText) throws IOException { + ChannelMessage message = generateMessage(notificationMessageText); + List channelIds = new ArrayList<>(); + channelIds.add(configId); + SeverityType severityType = SeverityType.Companion.fromTagOrDefault(severity); + NotificationsPluginInterface.INSTANCE.sendNotification(client, new EventSource(subject, configId, severityType, Collections.emptyList()), message, channelIds, new ActionListener() { @Override public void onResponse(SendNotificationResponse sendNotificationResponse) { if(sendNotificationResponse.getStatus() == RestStatus.OK) { @@ -44,45 +60,53 @@ public void onResponse(SendNotificationResponse sendNotificationResponse) { @Override public void onFailure(Exception e) { logger.error("Failed while sending a notification: " + e.toString()); - new SecurityAnalyticsException("Failed to send notification", RestStatus.INTERNAL_SERVER_ERROR, e); } }); } - public static String compileTemplate(Script template, CorrelationAlertContext ctx) { - TemplateScript.Factory factory = scriptService.compile(template, TemplateScript.CONTEXT); - Map params = new HashMap<>(template.getParams()); - params.put("ctx", ctx.asTemplateArg()); - TemplateScript templateScript = factory.newInstance(params); - return templateScript.execute(); - } - public static ChannelMessage generateMessage(String configId) { - return new ChannelMessage( - getMessageTextDescription(configId), - getMessageHtmlDescription(configId), - null - ); - } + /** + * Gets a NotificationConfigInfo object by ID if it exists. + */ + public GetNotificationConfigResponse getNotificationConfigInfo(String id) { - public static EventSource generateEventSource(String configId, String severity, List tags) { - return new EventSource( - getMessageTitle(configId), - configId, - SeverityType.INFO, - tags - ); - } + Set idSet = new HashSet(); + idSet.add(id); + GetNotificationConfigRequest getNotificationConfigRequest = new GetNotificationConfigRequest(idSet, 0, 10, null, null, new HashMap<>()); + GetNotificationConfigResponse configResp = null; + NotificationsPluginInterface.INSTANCE.getNotificationConfig(client, getNotificationConfigRequest, new ActionListener() { + @Override + public void onResponse(GetNotificationConfigResponse getNotificationConfigResponse) { + if (getNotificationConfigResponse.getStatus() == RestStatus.OK) { + getNotificationConfigResponse = configResp; + } else { + logger.error("Successfully sent a notification, Notification Event: " + getNotificationConfigResponse); + } + } - private static String getMessageTitle(String configId) { - return "Test Message Title-" + configId; // TODO: change as per spec + @Override + public void onFailure(Exception e) { + logger.error("Notification config [" + id + "] was not found"); + new SecurityAnalyticsException("Failed to fetch notification config", RestStatus.INTERNAL_SERVER_ERROR, e); + } + }); + logger.info("Notification config response is: {} ", configResp); + return configResp; } - private static String getMessageTextDescription(String configId) { - return "Test message content body for config id " + configId; // TODO: change as per spec + public static ChannelMessage generateMessage(String message) { + return new ChannelMessage( + message, + null, + null + ); } - private static String getMessageHtmlDescription(String configId) { - return "
Test Message

Test Message for config id " + configId + "

"; // TODO: change as per spec + public static String compileTemplate(CorrelationAlertContext ctx, Script template) { + TemplateScript.Factory factory = scriptService.compile(template, TemplateScript.CONTEXT); + Map params = new HashMap<>(template.getParams()); + params.put("ctx", ctx.asTemplateArg()); + TemplateScript templateScript = factory.newInstance(params); + return templateScript.execute(); } } diff --git a/src/main/java/org/opensearch/securityanalytics/model/CorrelationRuleTrigger.java b/src/main/java/org/opensearch/securityanalytics/model/CorrelationRuleTrigger.java index b1a5b77b4..3426c7eb1 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/CorrelationRuleTrigger.java +++ b/src/main/java/org/opensearch/securityanalytics/model/CorrelationRuleTrigger.java @@ -166,26 +166,28 @@ public String getSeverity() { } public List getActions() { - List transformedActions = new ArrayList<>(); - - if (actions != null) { - for (Action action : actions) { - String subjectTemplate = action.getSubjectTemplate() != null ? action.getSubjectTemplate().getIdOrCode() : ""; - subjectTemplate = subjectTemplate.replace("{{ctx.detector", "{{ctx.monitor"); - - action.getMessageTemplate(); - String messageTemplate = action.getMessageTemplate().getIdOrCode(); - messageTemplate = messageTemplate.replace("{{ctx.detector", "{{ctx.monitor"); - - Action transformedAction = new Action(action.getName(), action.getDestinationId(), - new Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG, subjectTemplate, Collections.emptyMap()), - new Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG, messageTemplate, Collections.emptyMap()), - action.getThrottleEnabled(), action.getThrottle(), - action.getId(), action.getActionExecutionPolicy()); - - transformedActions.add(transformedAction); - } - } - return transformedActions; +// List transformedActions = new ArrayList<>(); +// +// if (actions != null) { +// for (Action action : actions) { +// String subjectTemplate = action.getSubjectTemplate() != null ? action.getSubjectTemplate().getIdOrCode() : ""; +// CorrelationContext ctx = CorrelationContext(rule, sourceFindingId); +// no +// +// action.getMessageTemplate(); +// String messageTemplate = action.getMessageTemplate().getIdOrCode(); +// messageTemplate = messageTemplate.replace("{{ctx.detector", "{{ctx.monitor"); +// +// Action transformedAction = new Action(action.getName(), action.getDestinationId(), +// new Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG, subjectTemplate, Collections.emptyMap()), +// new Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG, messageTemplate, Collections.emptyMap()), +// action.getThrottleEnabled(), action.getThrottle(), +// action.getId(), action.getActionExecutionPolicy()); +// +// transformedActions.add(transformedAction); +// } +// } + return actions; } + } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index 3a8f96f97..26c6b0e5b 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -50,6 +50,7 @@ import org.opensearch.securityanalytics.correlation.JoinEngine; import org.opensearch.securityanalytics.correlation.VectorEmbeddingsEngine; import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; +import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; import org.opensearch.securityanalytics.logtype.LogTypeService; import org.opensearch.securityanalytics.model.CustomLogType; import org.opensearch.securityanalytics.model.Detector; @@ -102,6 +103,8 @@ public class TransportCorrelateFindingAction extends HandledTransportAction(); - this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, correlationAlertService); + this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, correlationAlertService, notificationService); this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this); }