Skip to content

Commit

Permalink
Fix FilterOperator to cache next element and avoid repeated consumpti…
Browse files Browse the repository at this point in the history
…on on hasNext() calls (#3123) (#3199)

(cherry picked from commit 3e2cb1d)

Signed-off-by: Peng Huo <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent 4b2eb67 commit 48998e5
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class FilterOperator extends PhysicalPlan {
@Getter private final PhysicalPlan input;
@Getter private final Expression conditions;
@ToString.Exclude private ExprValue next = null;
@ToString.Exclude private boolean nextPrepared = false;

@Override
public <R, C> R accept(PhysicalPlanNodeVisitor<R, C> visitor, C context) {
Expand All @@ -41,19 +42,34 @@ public List<PhysicalPlan> getChild() {

@Override
public boolean hasNext() {
if (!nextPrepared) {
prepareNext();
}
return next != null;
}

@Override
public ExprValue next() {
if (!nextPrepared) {
prepareNext();
}
ExprValue result = next;
next = null;
nextPrepared = false;
return result;
}

private void prepareNext() {
while (input.hasNext()) {
ExprValue inputValue = input.next();
ExprValue exprValue = conditions.valueOf(inputValue.bindingTuples());
if (!(exprValue.isNull() || exprValue.isMissing()) && (exprValue.booleanValue())) {
if (!(exprValue.isNull() || exprValue.isMissing()) && exprValue.booleanValue()) {
next = inputValue;
return true;
nextPrepared = true;
return;
}
}
return false;
}

@Override
public ExprValue next() {
return next;
next = null;
nextPrepared = true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,24 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_FALSE;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_MISSING;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_NULL;
import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.filter;

import com.google.common.collect.ImmutableMap;
import java.util.LinkedHashMap;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
Expand All @@ -26,12 +36,22 @@
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;

@ExtendWith(MockitoExtension.class)
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
class FilterOperatorTest extends PhysicalPlanTestBase {
@Mock private PhysicalPlan inputPlan;

@Mock private Expression condition;

private FilterOperator filterOperator;

@BeforeEach
public void setup() {
filterOperator = filter(inputPlan, condition);
}

@Test
public void filter_test() {
FilterOperator plan =
Expand Down Expand Up @@ -82,4 +102,68 @@ public void missing_value_should_been_ignored() {
List<ExprValue> result = execute(plan);
assertEquals(0, result.size());
}

@Test
public void testHasNextWhenInputHasNoElements() {
when(inputPlan.hasNext()).thenReturn(false);

assertFalse(
filterOperator.hasNext(), "hasNext() should return false when input has no elements");
}

@Test
public void testHasNextWithMatchingCondition() {
ExprValue inputValue = mock(ExprValue.class);
when(inputPlan.hasNext()).thenReturn(true).thenReturn(false);
when(inputPlan.next()).thenReturn(inputValue);
when(condition.valueOf(any())).thenReturn(LITERAL_TRUE);

assertTrue(filterOperator.hasNext(), "hasNext() should return true when condition matches");
assertEquals(
inputValue, filterOperator.next(), "next() should return the matching input value");
}

@Test
public void testHasNextWithNonMatchingCondition() {
ExprValue inputValue = mock(ExprValue.class);
when(inputPlan.hasNext()).thenReturn(true, false);
when(inputPlan.next()).thenReturn(inputValue);
when(condition.valueOf(any())).thenReturn(LITERAL_FALSE);

assertFalse(
filterOperator.hasNext(), "hasNext() should return false if no values match the condition");
}

@Test
public void testMultipleCallsToHasNextDoNotConsumeInput() {
ExprValue inputValue = mock(ExprValue.class);
when(inputPlan.hasNext()).thenReturn(true);
when(inputPlan.next()).thenReturn(inputValue);
when(condition.valueOf(any())).thenReturn(LITERAL_TRUE);

assertTrue(
filterOperator.hasNext(),
"First hasNext() call should return true if there is a matching value");
verify(inputPlan, times(1)).next();
assertTrue(
filterOperator.hasNext(),
"Subsequent hasNext() calls should still return true without advancing the input");
verify(inputPlan, times(1)).next();
assertEquals(
inputValue, filterOperator.next(), "next() should return the matching input value");
verify(inputPlan, times(1)).next();
}

@Test
public void testNextWithoutCallingHasNext() {
ExprValue inputValue = mock(ExprValue.class);
when(inputPlan.hasNext()).thenReturn(true, false);
when(inputPlan.next()).thenReturn(inputValue);
when(condition.valueOf(any())).thenReturn(LITERAL_TRUE);

assertEquals(
inputValue,
filterOperator.next(),
"next() should return the matching input value even if hasNext() was not called");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT Origin, Dest FROM (SELECT * FROM opensearch_dashboards_sample_data_flights WHERE AvgTicketPrice > 100 GROUP BY Origin, Dest, AvgTicketPrice) AS flights WHERE AvgTicketPrice < 1000 ORDER BY AvgTicketPrice LIMIT 30

0 comments on commit 48998e5

Please sign in to comment.