Skip to content

Commit

Permalink
alerts in correlations notification service added
Browse files Browse the repository at this point in the history
Signed-off-by: Riya Saxena <[email protected]>
  • Loading branch information
riysaxen-amzn committed Jun 3, 2024
1 parent 8603d5a commit 29ba163
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
Expand Down Expand Up @@ -54,6 +55,7 @@
import org.opensearch.securityanalytics.action.*;
import org.opensearch.securityanalytics.correlation.index.codec.CorrelationCodecService;
import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService;
import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService;
import org.opensearch.securityanalytics.correlation.index.mapper.CorrelationVectorFieldMapper;
import org.opensearch.securityanalytics.correlation.index.query.CorrelationQueryBuilder;
import org.opensearch.securityanalytics.indexmanagment.DetectorIndexManagementService;
Expand Down Expand Up @@ -167,12 +169,13 @@ public Collection<Object> createComponents(Client client,
TIFJobUpdateService tifJobUpdateService = new TIFJobUpdateService(clusterService, tifJobParameterService, threatIntelFeedDataService, builtInTIFMetadataLoader);
TIFLockService threatIntelLockService = new TIFLockService(clusterService, client);
CorrelationAlertService correlationAlertService = new CorrelationAlertService(client, xContentRegistry);
NotificationService notificationServiceService = new NotificationService((NodeClient)client, scriptService);
TIFJobRunner.getJobRunnerInstance().initialize(clusterService, tifJobUpdateService, tifJobParameterService, threatIntelLockService, threadPool, detectorThreatIntelService);

return List.of(
detectorIndices, correlationIndices, correlationRuleIndices, ruleTopicIndices, customLogTypeIndices, ruleIndices,
mapperService, indexTemplateManager, builtinLogTypeLoader, builtInTIFMetadataLoader, threatIntelFeedDataService, detectorThreatIntelService,
tifJobUpdateService, tifJobParameterService, threatIntelLockService, correlationAlertService);
tifJobUpdateService, tifJobParameterService, threatIntelLockService, correlationAlertService, notificationServiceService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,15 @@ public class JoinEngine {

private final CorrelationAlertService correlationAlertService;

private final NotificationService notificationService;

private volatile TimeValue indexTimeout;

private static final Logger log = LogManager.getLogger(JoinEngine.class);

public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry,
long corrTimeWindow, TimeValue indexTimeout, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction,
LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService) {
LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService, NotificationService notificationService) {
this.client = client;
this.request = request;
this.xContentRegistry = xContentRegistry;
Expand All @@ -90,6 +92,7 @@ public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRe
this.logTypeService = logTypeService;
this.enableAutoCorrelations = enableAutoCorrelations;
this.correlationAlertService = correlationAlertService;
this.notificationService = notificationService;
}

public void onSearchDetectorResponse(Detector detector, Finding finding) {
Expand Down Expand Up @@ -551,7 +554,7 @@ private void getCorrelatedFindings(String detectorType, Map<String, List<String>
}

if (!correlatedFindings.isEmpty()) {
CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(client, correlationAlertService);
CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(client, correlationAlertService, notificationService);
correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout);
correlationRuleScheduler.shutdown();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.commons.alerting.model.CorrelationAlert;
import org.opensearch.securityanalytics.util.CorrelationIndices;

import java.io.IOException;
import java.time.Instant;
import java.util.List;
Expand Down Expand Up @@ -73,7 +72,7 @@ public void getActiveAlerts(String ruleId, long currentTime, ActionListener<Corr
listener.onResponse(new CorrelationAlertsList(Collections.emptyList(), 0));
} else {
listener.onResponse(new CorrelationAlertsList(
parseCorrelationAlerts(searchResponse),
Collections.emptyList(),
searchResponse.getHits() != null && searchResponse.getHits().getTotalHits() != null ?
(int) searchResponse.getHits().getTotalHits().value : 0)
);
Expand Down Expand Up @@ -125,13 +124,12 @@ public List<CorrelationAlert> parseCorrelationAlerts(final SearchResponse respon
LoggingDeprecationHandler.INSTANCE,
hit.getSourceAsString()
);

xcp.nextToken();
CorrelationAlert correlationAlert = CorrelationAlert.parse(xcp, hit.getId(), hit.getVersion());
alerts.add(correlationAlert);
}
return alerts;
}
// Helper method to convert User object to map
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,35 @@
import org.opensearch.securityanalytics.model.CorrelationQuery;
import org.opensearch.securityanalytics.model.CorrelationRule;
import org.opensearch.securityanalytics.model.CorrelationRuleTrigger;

import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService;
import org.opensearch.securityanalytics.correlation.alert.notifications.CorrelationAlertContext;
import org.opensearch.client.node.NodeClient;
import org.opensearch.commons.alerting.model.action.Action;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.securityanalytics.util.SecurityAnalyticsException;
import java.time.Instant;
import java.util.UUID;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.opensearch.script.ScriptService;

public class CorrelationRuleScheduler {

private final Logger log = LogManager.getLogger(CorrelationRuleScheduler.class);
private final Client client;
private final CorrelationAlertService correlationAlertService;
private final NotificationService notificationService;
private final ExecutorService executorService;
private static ScriptService scriptService;

public CorrelationRuleScheduler(Client client, CorrelationAlertService correlationAlertService) {
public CorrelationRuleScheduler(Client client, CorrelationAlertService correlationAlertService, NotificationService notificationService) {
this.client = client;
this.scriptService = scriptService;
this.correlationAlertService = correlationAlertService;
this.notificationService = notificationService;
this.executorService = Executors.newCachedThreadPool();
}

Expand All @@ -44,7 +54,7 @@ public void schedule(List<CorrelationRule> correlationRules, Map<String, List<St
findingIds.addAll(categoryFindingIds);
}
}
scheduleRule(rule, findingIds, indexTimeout);
scheduleRule(rule, findingIds, indexTimeout, sourceFinding);
}
}
}
Expand All @@ -53,10 +63,11 @@ public void shutdown() {
executorService.shutdown();
}

private void scheduleRule(CorrelationRule correlationRule, List<String> findingIds, TimeValue indexTimeout) {
private void scheduleRule(CorrelationRule correlationRule, List<String> findingIds, TimeValue indexTimeout, String sourceFindingId) {
long startTime = Instant.now().toEpochMilli();
long endTime = startTime + correlationRule.getCorrTimeWindow();
executorService.submit(new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, indexTimeout));
RuleTask ruleTask = new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, notificationService, indexTimeout, sourceFindingId);
executorService.submit(ruleTask);
}

private class RuleTask implements Runnable {
Expand All @@ -65,15 +76,19 @@ private class RuleTask implements Runnable {
private final long endTime;
private final List<String> correlatedFindingIds;
private final CorrelationAlertService correlationAlertService;
private final NotificationService notificationService;
private final TimeValue indexTimeout;
private final String sourceFindingId;

public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, TimeValue indexTimeout) {
public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, NotificationService notificationService, TimeValue indexTimeout, String sourceFindingId) {
this.correlationRule = correlationRule;
this.correlatedFindingIds = correlatedFindingIds;
this.startTime = startTime;
this.endTime = endTime;
this.correlationAlertService = correlationAlertService;
this.notificationService = notificationService;
this.indexTimeout = indexTimeout;
this.sourceFindingId = sourceFindingId;
}

@Override
Expand All @@ -86,6 +101,19 @@ public void run() {
public void onResponse(CorrelationAlertsList correlationAlertsList) {
if (correlationAlertsList.getTotalAlerts() == 0) {
addCorrelationAlertIntoIndex();
List<Action> actions = correlationRule.getCorrelationTrigger().getActions();
for (Action action : actions) {
CorrelationAlertContext ctx = new CorrelationAlertContext(correlatedFindingIds, correlationRule.getName(), correlationRule.getCorrTimeWindow(), sourceFindingId);
String transfomedSubject = notificationService.compileTemplate(ctx, action.getSubjectTemplate());
String transformedMessage = notificationService.compileTemplate(ctx, action.getMessageTemplate());
try {
notificationService.sendNotification(action.getDestinationId(), correlationRule.getCorrelationTrigger().getSeverity(), transfomedSubject, transformedMessage);
} catch (Exception e) {
log.error("Failed while sending a notification: " + e.toString());
new SecurityAnalyticsException("Failed to send notification", RestStatus.INTERNAL_SERVER_ERROR, e);
}

}
} else {
for (CorrelationAlert correlationAlert: correlationAlertsList.getCorrelationAlertList()) {
updateCorrelationAlert(correlationAlert);
Expand All @@ -96,10 +124,12 @@ public void onResponse(CorrelationAlertsList correlationAlertsList) {
@Override
public void onFailure(Exception e) {
log.error("Failed to search active correlation alert", e);
new SecurityAnalyticsException("Failed to search active correlation alert", RestStatus.INTERNAL_SERVER_ERROR, e);
}
});
} catch (Exception e) {
log.error("Failed to fetch active alerts in the time window", e);
new SecurityAnalyticsException("Failed to get active alerts in the correlationRuletimewindow", RestStatus.INTERNAL_SERVER_ERROR, e);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
import java.util.List;
import java.util.Map;

public abstract class CorrelationAlertContext {
public class CorrelationAlertContext {

private final CorrelationRule correlationRule;
private final List<String> correlatedFindingIds;
protected CorrelationAlertContext(CorrelationRule correlationRule, List<String> correlatedFindingIds) {
this.correlationRule = correlationRule;
private final String sourceFinding;
private final String correlationRuleName;
private final long timeWindow;
public CorrelationAlertContext(List<String> correlatedFindingIds, String correlationRuleName, long timeWindow, String sourceFinding) {
this.correlatedFindingIds = correlatedFindingIds;
this.correlationRuleName = correlationRuleName;
this.timeWindow = timeWindow;
this.sourceFinding = sourceFinding;
}

/**
Expand All @@ -21,16 +25,11 @@ protected CorrelationAlertContext(CorrelationRule correlationRule, List<String>
*/
public Map<String, Object> asTemplateArg() {
Map<String, Object> templateArg = new HashMap<>();
templateArg.put("correlationRule", correlationRule);
templateArg.put("correlatedFindingIds", correlatedFindingIds);
templateArg.put("sourceFinding", sourceFinding);
templateArg.put("correlationRuleName", correlationRuleName);
templateArg.put("timeWindow", timeWindow);
return templateArg;
}

public CorrelationRule getCorrelationRule() {
return correlationRule;
}

public List<String> getCorrelatedFindingIds() {
return correlatedFindingIds;
}
}
Loading

0 comments on commit 29ba163

Please sign in to comment.