Skip to content

Commit

Permalink
Tweak math to evaluate lazily and reduce division
Browse files Browse the repository at this point in the history
Signed-off-by: James Duong <[email protected]>
  • Loading branch information
jduo committed Oct 28, 2024
1 parent a36cd95 commit 68f2d0a
Showing 1 changed file with 17 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;

import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.LiteralExpression;

/** Trendline command implementation */
@ToString
Expand Down Expand Up @@ -119,12 +118,12 @@ private interface TrendlineAccumulator {

// TODO: Make the actual math polymorphic based on types to deal with datetimes.
private static class SimpleMovingAverageAccumulator implements TrendlineAccumulator {
private final ExprValue dataPointsNeeded;
private final LiteralExpression dataPointsNeeded;
private final EvictingQueue<ExprValue> receivedValues;
private ExprValue runningAverage = null;
private Expression runningTotal = null;

public SimpleMovingAverageAccumulator(Trendline.TrendlineComputation computation) {
dataPointsNeeded = new ExprIntegerValue(computation.getNumberOfDataPoints());
dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue());
receivedValues = EvictingQueue.create(computation.getNumberOfDataPoints());
}

Expand All @@ -135,51 +134,44 @@ public void accumulate(ExprValue value) {
return;
}

if (dataPointsNeeded.integerValue() == 1) {
runningAverage = value;
if (dataPointsNeeded.valueOf().integerValue() == 1) {
runningTotal = DSL.literal(value);
receivedValues.add(value);
return;
}

final ExprValue valueToRemove;
if (receivedValues.size() == dataPointsNeeded.integerValue()) {
if (receivedValues.size() == dataPointsNeeded.valueOf().integerValue()) {
valueToRemove = receivedValues.remove();
} else {
valueToRemove = null;
}
receivedValues.add(value);

if (receivedValues.size() == dataPointsNeeded.integerValue()) {
if (runningAverage != null) {
// We can use the previous average calculation.
// Subtract the evicted value / period and add the new value / period.
// Refactored, that would be previous + (newValue - oldValue) / period
runningAverage =
DSL.add(
DSL.literal(runningAverage),
DSL.divide(
DSL.subtract(DSL.literal(value), DSL.literal(valueToRemove)),
DSL.literal(dataPointsNeeded.doubleValue())))
.valueOf();
if (receivedValues.size() == dataPointsNeeded.valueOf().integerValue()) {
if (runningTotal != null) {
// We can use the previous calculation.
// Subtract the evicted value and add the new value.
// Refactored, that would be previous + (newValue - oldValue).
runningTotal =
DSL.add(runningTotal, DSL.subtract(DSL.literal(value), DSL.literal(valueToRemove)));
} else {
// This is the first average calculation so sum the entire receivedValues dataset.
final List<ExprValue> data = receivedValues.stream().toList();
Expression runningTotal = DSL.literal(0.0D);
runningTotal = DSL.literal(0.0D);
for (ExprValue entry : data) {
runningTotal = DSL.add(runningTotal, DSL.literal(entry));
}
runningAverage =
DSL.divide(runningTotal, DSL.literal(dataPointsNeeded.doubleValue())).valueOf();
}
}
}

@Override
public ExprValue calculate() {
if (receivedValues.size() < dataPointsNeeded.integerValue()) {
if (receivedValues.size() < dataPointsNeeded.valueOf().integerValue()) {
return null;
}
return runningAverage;
return DSL.divide(runningTotal, dataPointsNeeded).valueOf();
}
}
}

0 comments on commit 68f2d0a

Please sign in to comment.