Skip to content

Commit

Permalink
Add integration tests
Browse files Browse the repository at this point in the history
Signed-off-by: James Duong <[email protected]>
  • Loading branch information
jduo committed Oct 26, 2024
1 parent f5ca19d commit 79796be
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.google.common.collect.ImmutableMap.Builder;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import lombok.EqualsAndHashCode;
Expand All @@ -30,15 +31,20 @@ public class TrendlineOperator extends PhysicalPlan {
@Getter private final List<Trendline.TrendlineComputation> computations;
@EqualsAndHashCode.Exclude private final List<TrendlineAccumulator> accumulators;
@EqualsAndHashCode.Exclude private final Map<String, Integer> fieldToIndexMap;
@EqualsAndHashCode.Exclude private final HashSet<String> aliases;

public TrendlineOperator(PhysicalPlan input, List<Trendline.TrendlineComputation> computations) {
this.input = input;
this.computations = computations;
this.accumulators = computations.stream().map(TrendlineOperator::createAccumulator).toList();
fieldToIndexMap = new HashMap<>(computations.size());
aliases = new HashSet<>(computations.size());
for (int i = 0; i < computations.size(); ++i) {

fieldToIndexMap.put(computations.get(i).getDataField().getChild().get(0).toString(), i);
final Trendline.TrendlineComputation computation = computations.get(i);
fieldToIndexMap.put(computation.getDataField().getChild().get(0).toString(), i);
if (computation.getAlias() != null) {
aliases.add(computation.getAlias());
}
}
}

Expand All @@ -61,36 +67,36 @@ public boolean hasNext() {
public ExprValue next() {
final ExprValue result;
final ExprValue next = input.next();
consumeInputTuple(next);
final Map<String, ExprValue> inputStruct = ExprValueUtils.getTupleValue(next);
final Map<String, ExprValue> inputStruct = consumeInputTuple(next);
final Builder<String, ExprValue> mapBuilder = new Builder<>();
mapBuilder.putAll(inputStruct);

// Add calculated trendline values, which might overwrite existing fields from the input.
for (int i = 0; i < accumulators.size(); ++i) {
final ExprValue calculateResult = accumulators.get(i).calculate();
if (null != calculateResult) {
if (null != computations.get(i).getAlias()) {
mapBuilder.put(computations.get(i).getAlias(), calculateResult);
} else {
mapBuilder.put(
computations.get(i).getDataField().getChild().get(0).toString(), calculateResult);
}
final String field =
null != computations.get(i).getAlias()
? computations.get(i).getAlias()
: computations.get(i).getDataField().getChild().get(0).toString();
if (calculateResult != null) {
mapBuilder.put(field, calculateResult);
}
}

result = ExprTupleValue.fromExprValueMap(mapBuilder.buildKeepingLast());
return result;
}

private void consumeInputTuple(ExprValue inputValue) {
private Map<String, ExprValue> consumeInputTuple(ExprValue inputValue) {
final Map<String, ExprValue> tupleValue = ExprValueUtils.getTupleValue(inputValue);
for (String bindName : tupleValue.keySet()) {
final Integer index = fieldToIndexMap.get(bindName);
if (index == null) {
continue;
if (index != null) {
accumulators.get(index).accumulate(tupleValue.get(bindName));
}
accumulators.get(index).accumulate(tupleValue.get(bindName));
}
tupleValue.keySet().removeAll(aliases);
return tupleValue;
}

private static TrendlineAccumulator createAccumulator(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl;

import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK;
import static org.opensearch.sql.util.MatcherUtils.rows;
import static org.opensearch.sql.util.MatcherUtils.verifyDataRows;

import java.io.IOException;
import org.json.JSONObject;
import org.junit.jupiter.api.Test;

public class TrendlineCommandIT extends PPLIntegTestCase {

@Override
public void init() throws IOException {
loadIndex(Index.BANK);
}

@Test
public void testTrendline() throws IOException {
final JSONObject result =
executeQuery(
String.format(
"source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as"
+ " balance_trend | fields balance_trend",
TEST_INDEX_BANK));
verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5));
}

@Test
public void testTrendlineMultipleFields() throws IOException {
final JSONObject result =
executeQuery(
String.format(
"source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as"
+ " balance_trend sma(2, account_number) as account_number_trend | fields"
+ " balance_trend, account_number_trend",
TEST_INDEX_BANK));
verifyDataRows(result, rows(null, null), rows(44313.0, 28.5), rows(39882.5, 13.0));
}

@Test
public void testTrendlineOverwritesExistingField() throws IOException {
final JSONObject result =
executeQuery(
String.format(
"source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as"
+ " age | fields age",
TEST_INDEX_BANK));
verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5));
}
}

0 comments on commit 79796be

Please sign in to comment.