diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 5503fe876b..b74c922418 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -9,6 +9,7 @@ import com.google.common.collect.ImmutableMap.Builder; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import lombok.EqualsAndHashCode; @@ -30,15 +31,20 @@ public class TrendlineOperator extends PhysicalPlan { @Getter private final List computations; @EqualsAndHashCode.Exclude private final List accumulators; @EqualsAndHashCode.Exclude private final Map fieldToIndexMap; + @EqualsAndHashCode.Exclude private final HashSet aliases; public TrendlineOperator(PhysicalPlan input, List computations) { this.input = input; this.computations = computations; this.accumulators = computations.stream().map(TrendlineOperator::createAccumulator).toList(); fieldToIndexMap = new HashMap<>(computations.size()); + aliases = new HashSet<>(computations.size()); for (int i = 0; i < computations.size(); ++i) { - - fieldToIndexMap.put(computations.get(i).getDataField().getChild().get(0).toString(), i); + final Trendline.TrendlineComputation computation = computations.get(i); + fieldToIndexMap.put(computation.getDataField().getChild().get(0).toString(), i); + if (computation.getAlias() != null) { + aliases.add(computation.getAlias()); + } } } @@ -61,36 +67,36 @@ public boolean hasNext() { public ExprValue next() { final ExprValue result; final ExprValue next = input.next(); - consumeInputTuple(next); - final Map inputStruct = ExprValueUtils.getTupleValue(next); + final Map inputStruct = consumeInputTuple(next); final Builder mapBuilder = new Builder<>(); mapBuilder.putAll(inputStruct); // Add calculated trendline values, which might overwrite existing fields from the input. for (int i = 0; i < accumulators.size(); ++i) { final ExprValue calculateResult = accumulators.get(i).calculate(); - if (null != calculateResult) { - if (null != computations.get(i).getAlias()) { - mapBuilder.put(computations.get(i).getAlias(), calculateResult); - } else { - mapBuilder.put( - computations.get(i).getDataField().getChild().get(0).toString(), calculateResult); - } + final String field = + null != computations.get(i).getAlias() + ? computations.get(i).getAlias() + : computations.get(i).getDataField().getChild().get(0).toString(); + if (calculateResult != null) { + mapBuilder.put(field, calculateResult); } } + result = ExprTupleValue.fromExprValueMap(mapBuilder.buildKeepingLast()); return result; } - private void consumeInputTuple(ExprValue inputValue) { + private Map consumeInputTuple(ExprValue inputValue) { final Map tupleValue = ExprValueUtils.getTupleValue(inputValue); for (String bindName : tupleValue.keySet()) { final Integer index = fieldToIndexMap.get(bindName); - if (index == null) { - continue; + if (index != null) { + accumulators.get(index).accumulate(tupleValue.get(bindName)); } - accumulators.get(index).accumulate(tupleValue.get(bindName)); } + tupleValue.keySet().removeAll(aliases); + return tupleValue; } private static TrendlineAccumulator createAccumulator( diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java new file mode 100644 index 0000000000..98e33c09a6 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; + +public class TrendlineCommandIT extends PPLIntegTestCase { + + @Override + public void init() throws IOException { + loadIndex(Index.BANK); + } + + @Test + public void testTrendline() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as" + + " balance_trend | fields balance_trend", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); + } + + @Test + public void testTrendlineMultipleFields() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as" + + " balance_trend sma(2, account_number) as account_number_trend | fields" + + " balance_trend, account_number_trend", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(null, null), rows(44313.0, 28.5), rows(39882.5, 13.0)); + } + + @Test + public void testTrendlineOverwritesExistingField() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as" + + " age | fields age", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); + } +}