From f8267b4a949902d4534dbe248f45f473a5939e69 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 21:21:54 -0700 Subject: [PATCH] Fix race condition in PageListener (#1351) (#1352) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix race condition in PageListener This PR - Introduced an `AtomicInteger` called `pagesInFlight` to track the number of pages currently being processed.  - Incremented `pagesInFlight` before processing each page and decremented it after processing is complete - Adjusted the condition in `scheduleImputeHCTask` to check both `pagesInFlight.get() == 0` (all pages have been processed) and `sentOutPages.get() == receivedPages.get()` (all responses have been received) before scheduling the `imputeHC` task.  - Removed the previous final check in `onResponse` that decided when to schedule `imputeHC`, relying instead on the updated counters for accurate synchronization. These changes address the race condition where `sentOutPages` might not have been incremented in time before checking whether to schedule the `imputeHC` task. By accurately tracking the number of in-flight pages and sent pages, we ensure that `imputeHC` is executed only after all pages have been fully processed and all responses have been received. Testing done: 1. Reproduced the race condition by starting two detectors with imputation. This causes an out of order illegal argument exception from RCF due to this race condition. Also verified the change fixed the problem. 2. added an IT for the above scenario. * make sure increment before schedule --------- (cherry picked from commit f62885a7e2dee69cd067667c68e540451aa4e883) Signed-off-by: Kaituo Li Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] --- .../transport/ResultBulkTransportAction.java | 1 - .../timeseries/transport/ResultProcessor.java | 34 ++++++--- .../AbstractMissingSingleFeatureTestCase.java | 3 +- .../java/org/opensearch/ad/e2e/MissingIT.java | 35 +++++++-- .../ad/e2e/MissingMultiFeatureIT.java | 75 ++++++++++++++++++- .../ad/e2e/PreviewMissingSingleFeatureIT.java | 4 +- 6 files changed, 127 insertions(+), 25 deletions(-) diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java index 20b3e9fba..61efd6104 100644 --- a/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java @@ -85,7 +85,6 @@ protected void doExecute(Task task, ResultBulkRequestType request, ActionListene // all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure). long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes(); float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits; - @SuppressWarnings("rawtypes") List results = request.getResults(); if (results == null || results.size() < 1) { diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java index f412ce84e..2b7ebca65 100644 --- a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java +++ b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java @@ -210,6 +210,10 @@ class PageListener implements ActionListener { private String taskId; private AtomicInteger receivedPages; private AtomicInteger sentOutPages; + // By introducing pagesInFlight and incrementing it in the main thread before asynchronous processing begins, + // we ensure that the count of in-flight pages is accurate at all times. This allows us to reliably determine + // when all pages have been processed. + private AtomicInteger pagesInFlight; PageListener(PageIterator pageIterator, Config config, long dataStartTime, long dataEndTime, String taskId) { this.pageIterator = pageIterator; @@ -220,14 +224,21 @@ class PageListener implements ActionListener { this.taskId = taskId; this.receivedPages = new AtomicInteger(); this.sentOutPages = new AtomicInteger(); + this.pagesInFlight = new AtomicInteger(); } @Override public void onResponse(CompositeRetriever.Page entityFeatures) { + // Increment pagesInFlight to track the processing of this page + pagesInFlight.incrementAndGet(); + // start processing next page after sending out features for previous page if (pageIterator.hasNext()) { pageIterator.next(this); + } else if (config.getImputationOption() != null) { + scheduleImputeHCTask(); } + if (entityFeatures != null && false == entityFeatures.isEmpty()) { LOG .info( @@ -309,19 +320,15 @@ public void onResponse(CompositeRetriever.Page entityFeatures) { } catch (Exception e) { LOG.error("Unexpected exception", e); handleException(e); + } finally { + // Decrement pagesInFlight after processing is complete + pagesInFlight.decrementAndGet(); } }); - } - - if (!pageIterator.hasNext() && config.getImputationOption() != null) { - if (sentOutPages.get() > 0) { - // at least 1 page sent out. Wait until all responses are back. - scheduleImputeHCTask(); - } else { - // no data in current interval. Send out impute request right away. - imputeHC(dataStartTime, dataEndTime, configId, taskId); - } - + } else { + // No entity features to process + // Decrement pagesInFlight immediately + pagesInFlight.decrementAndGet(); } } @@ -358,7 +365,10 @@ private void scheduleImputeHCTask() { @Override public void run() { - if (sentOutPages.get() == receivedPages.get()) { + // By using pagesInFlight in the condition within scheduleImputeHCTask, we ensure that imputeHC + // is executed only after all pages have been processed (pagesInFlight.get() == 0) and all + // responses have been received (sentOutPages.get() == receivedPages.get()). + if (pagesInFlight.get() == 0 && sentOutPages.get() == receivedPages.get()) { if (!sent.get()) { // since we don't know when cancel will succeed, need sent to ensure imputeHC is only called once sent.set(true); diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractMissingSingleFeatureTestCase.java b/src/test/java/org/opensearch/ad/e2e/AbstractMissingSingleFeatureTestCase.java index 4cce93b12..4a6274149 100644 --- a/src/test/java/org/opensearch/ad/e2e/AbstractMissingSingleFeatureTestCase.java +++ b/src/test/java/org/opensearch/ad/e2e/AbstractMissingSingleFeatureTestCase.java @@ -28,7 +28,8 @@ protected String genDetector( long windowDelayMinutes, boolean hc, ImputationMethod imputation, - long trainTimeMillis + long trainTimeMillis, + String name ) { StringBuilder sb = new StringBuilder(); // common part diff --git a/src/test/java/org/opensearch/ad/e2e/MissingIT.java b/src/test/java/org/opensearch/ad/e2e/MissingIT.java index f6f459d67..673e9d91b 100644 --- a/src/test/java/org/opensearch/ad/e2e/MissingIT.java +++ b/src/test/java/org/opensearch/ad/e2e/MissingIT.java @@ -78,15 +78,27 @@ protected TrainResult createAndStartRealTimeDetector( List data, ImputationMethod imputation, boolean hc, - long trainTimeMillis + long trainTimeMillis, + String name ) throws Exception { - TrainResult trainResult = createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis); + TrainResult trainResult = createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, name); List result = startRealTimeDetector(trainResult, numberOfEntities, intervalMinutes, true); recordLastSeenFromResult(result); return trainResult; } + protected TrainResult createAndStartRealTimeDetector( + int numberOfEntities, + int trainTestSplit, + List data, + ImputationMethod imputation, + boolean hc, + long trainTimeMillis + ) throws Exception { + return createAndStartRealTimeDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, "test"); + } + protected TrainResult createAndStartHistoricalDetector( int numberOfEntities, int trainTestSplit, @@ -115,12 +127,13 @@ protected TrainResult createDetector( List data, ImputationMethod imputation, boolean hc, - long trainTimeMillis + long trainTimeMillis, + String name ) throws Exception { Instant trainTime = Instant.ofEpochMilli(trainTimeMillis); Duration windowDelay = getWindowDelay(trainTimeMillis); - String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), hc, imputation, trainTimeMillis); + String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), hc, imputation, trainTimeMillis, name); RestClient client = client(); String detectorId = createDetector(client, detector); @@ -129,6 +142,17 @@ protected TrainResult createDetector( return new TrainResult(detectorId, data, trainTestSplit * numberOfEntities, windowDelay, trainTime, "timestamp"); } + protected TrainResult createDetector( + int numberOfEntities, + int trainTestSplit, + List data, + ImputationMethod imputation, + boolean hc, + long trainTimeMillis + ) throws Exception { + return createDetector(numberOfEntities, trainTestSplit, data, imputation, hc, trainTimeMillis, "test"); + } + protected Duration getWindowDelay(long trainTimeMillis) { /* * AD accepts windowDelay in the unit of minutes. Thus, we need to convert the delay in minutes. This will @@ -156,7 +180,8 @@ protected abstract String genDetector( long windowDelayMinutes, boolean hc, ImputationMethod imputation, - long trainTimeMillis + long trainTimeMillis, + String name ); protected abstract AbstractSyntheticDataTest.GenData genData( diff --git a/src/test/java/org/opensearch/ad/e2e/MissingMultiFeatureIT.java b/src/test/java/org/opensearch/ad/e2e/MissingMultiFeatureIT.java index 0b5708c5a..2602f9541 100644 --- a/src/test/java/org/opensearch/ad/e2e/MissingMultiFeatureIT.java +++ b/src/test/java/org/opensearch/ad/e2e/MissingMultiFeatureIT.java @@ -135,13 +135,80 @@ public void testHCPrevious() throws Exception { ); } + /** + * test we start two HC detector with zero imputation consecutively. + * We expect there is no out of order error from RCF. + * @throws Exception + */ + public void testDoubleHCZero() throws Exception { + lastSeen.clear(); + int numberOfEntities = 2; + + AbstractSyntheticDataTest.MISSING_MODE mode = AbstractSyntheticDataTest.MISSING_MODE.NO_MISSING_DATA; + ImputationMethod method = ImputationMethod.ZERO; + + AbstractSyntheticDataTest.GenData dataGenerated = genData(trainTestSplit, numberOfEntities, mode); + + // only ingest train data to avoid validation error as we use latest data time as starting point. + // otherwise, we will have too many missing points. + ingestUniformSingleFeatureData( + trainTestSplit + numberOfEntities * 6, // we only need a few to verify and trigger train. + dataGenerated.data + ); + + TrainResult trainResult1 = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + true, + dataGenerated.testStartTime, + "test1" + ); + + TrainResult trainResult2 = createAndStartRealTimeDetector( + numberOfEntities, + trainTestSplit, + dataGenerated.data, + method, + true, + dataGenerated.testStartTime, + "test2" + ); + + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult1.windowDelay, + trainResult1.detectorId, + numberOfEntities, + mode, + method, + 3, + true + ); + + runTest( + dataGenerated.testStartTime, + dataGenerated, + trainResult2.windowDelay, + trainResult2.detectorId, + numberOfEntities, + mode, + method, + 3, + true + ); + } + @Override protected String genDetector( int trainTestSplit, long windowDelayMinutes, boolean hc, ImputationMethod imputation, - long trainTimeMillis + long trainTimeMillis, + String name ) { StringBuilder sb = new StringBuilder(); @@ -185,7 +252,7 @@ protected String genDetector( // common part sb .append( - "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + "{ \"name\": \"%s\", \"description\": \"test\", \"time_field\": \"timestamp\"" + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_id\": \"feature2\", \"feature_name\": \"feature 2\", \"feature_enabled\": " + "\"true\", \"aggregation_query\": { \"Feature2\": { \"avg\": { \"field\": \"data\" } } } }," + featureWithFilter @@ -226,9 +293,9 @@ protected String genDetector( sb.append("\"schema_version\": 0}"); if (hc) { - return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1, categoricalField); + return String.format(Locale.ROOT, sb.toString(), name, datasetName, intervalMinutes, trainTestSplit - 1, categoricalField); } else { - return String.format(Locale.ROOT, sb.toString(), datasetName, intervalMinutes, trainTestSplit - 1); + return String.format(Locale.ROOT, sb.toString(), name, datasetName, intervalMinutes, trainTestSplit - 1); } } diff --git a/src/test/java/org/opensearch/ad/e2e/PreviewMissingSingleFeatureIT.java b/src/test/java/org/opensearch/ad/e2e/PreviewMissingSingleFeatureIT.java index 6b0273c0a..61b5ed77f 100644 --- a/src/test/java/org/opensearch/ad/e2e/PreviewMissingSingleFeatureIT.java +++ b/src/test/java/org/opensearch/ad/e2e/PreviewMissingSingleFeatureIT.java @@ -35,7 +35,7 @@ public void testSingleStream() throws Exception { ); Duration windowDelay = getWindowDelay(dataGenerated.testStartTime); - String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), false, method, dataGenerated.testStartTime); + String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), false, method, dataGenerated.testStartTime, "test"); Instant begin = Instant.ofEpochMilli(dataGenerated.data.get(0).get("timestamp").getAsLong()); Instant end = Instant.ofEpochMilli(dataGenerated.data.get(dataGenerated.data.size() - 1).get("timestamp").getAsLong()); @@ -63,7 +63,7 @@ public void testHC() throws Exception { ); Duration windowDelay = getWindowDelay(dataGenerated.testStartTime); - String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), true, method, dataGenerated.testStartTime); + String detector = genDetector(trainTestSplit, windowDelay.toMinutes(), true, method, dataGenerated.testStartTime, "test"); Instant begin = Instant.ofEpochMilli(dataGenerated.data.get(0).get("timestamp").getAsLong()); Instant end = Instant.ofEpochMilli(dataGenerated.data.get(dataGenerated.data.size() - 1).get("timestamp").getAsLong());