diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java index 52997d092696..13f77b2c5a3d 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java @@ -305,4 +305,48 @@ public static Object arrayValueConstructor(Object... arr) { } return arr; } + + @ScalarFunction + public static int[] generateIntArray(int start, int end, int inc) { + int size = (end - start) / inc + 1; + int[] arr = new int[size]; + + for (int i = 0, value = start; i < size; i++, value += inc) { + arr[i] = value; + } + return arr; + } + + @ScalarFunction + public static long[] generateLongArray(long start, long end, long inc) { + int size = (int) ((end - start) / inc + 1); + long[] arr = new long[size]; + + for (int i = 0; i < size; i++, start += inc) { + arr[i] = start; + } + return arr; + } + + @ScalarFunction + public static float[] generateFloatArray(float start, float end, float inc) { + int size = (int) ((end - start) / inc + 1); + float[] arr = new float[size]; + + for (int i = 0; i < size; i++, start += inc) { + arr[i] = start; + } + return arr; + } + + @ScalarFunction + public static double[] generateDoubleArray(double start, double end, double inc) { + int size = (int) ((end - start) / inc + 1); + double[] arr = new double[size]; + + for (int i = 0; i < size; i++, start += inc) { + arr[i] = start; + } + return arr; + } } diff --git a/pinot-common/src/test/java/org/apache/pinot/common/function/FunctionDefinitionRegistryTest.java b/pinot-common/src/test/java/org/apache/pinot/common/function/FunctionDefinitionRegistryTest.java index d2771bc626f1..600016fd8f98 100644 --- a/pinot-common/src/test/java/org/apache/pinot/common/function/FunctionDefinitionRegistryTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/common/function/FunctionDefinitionRegistryTest.java @@ -42,8 +42,8 @@ public class FunctionDefinitionRegistryTest { private static final List IGNORED_FUNCTION_NAMES = ImmutableList.of( // Geo functions are defined in pinot-core "geotoh3", - // ArrayToMV and ArrayValueConstructor are placeholder functions without implementation - "arraytomv", "arrayvalueconstructor", + // ArrayToMV, ArrayValueConstructor and GenerateArray are placeholder functions without implementation + "arraytomv", "arrayvalueconstructor", "generatearray", // Scalar function "scalar", // Functions without scalar function counterpart as of now diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/GenerateArrayTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/GenerateArrayTransformFunction.java new file mode 100644 index 000000000000..44632867ac37 --- /dev/null +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/GenerateArrayTransformFunction.java @@ -0,0 +1,410 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.operator.transform.function; + +import com.google.common.base.Preconditions; +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.operator.ColumnContext; +import org.apache.pinot.core.operator.blocks.ValueBlock; +import org.apache.pinot.core.operator.transform.TransformResultMetadata; +import org.apache.pinot.segment.spi.index.reader.Dictionary; +import org.apache.pinot.spi.data.FieldSpec.DataType; +import org.roaringbitmap.RoaringBitmap; + +public class GenerateArrayTransformFunction implements TransformFunction { + public static final String FUNCTION_NAME = "generateArray"; + + private final DataType _dataType; + + private final int[] _intArrayLiteral; + private final long[] _longArrayLiteral; + private final float[] _floatArrayLiteral; + private final double[] _doubleArrayLiteral; + private int[][] _intArrayResult; + private long[][] _longArrayResult; + private float[][] _floatArrayResult; + private double[][] _doubleArrayResult; + + public GenerateArrayTransformFunction(List literalContexts) { + Preconditions.checkNotNull(literalContexts); + if (literalContexts.isEmpty()) { + _dataType = DataType.UNKNOWN; + _intArrayLiteral = new int[0]; + _longArrayLiteral = new long[0]; + _floatArrayLiteral = new float[0]; + _doubleArrayLiteral = new double[0]; + return; + } + Preconditions.checkState(literalContexts.size() == 2 || literalContexts.size() == 3, + "GenerateArrayTransformFunction takes only 2 or 3 arguments, found: %s", literalContexts.size()); + for (ExpressionContext literalContext : literalContexts) { + Preconditions.checkState(literalContext.getType() == ExpressionContext.Type.LITERAL, + "GenerateArrayTransformFunction only takes literals as arguments, found: %s", literalContext); + } + // Get the type of the first member in the literalContext and generate an array + _dataType = literalContexts.get(0).getLiteral().getType(); + + switch (_dataType) { + case INT: + int startInt = literalContexts.get(0).getLiteral().getIntValue(); + int endInt = literalContexts.get(1).getLiteral().getIntValue(); + int incInt; + if (literalContexts.size() == 3) { + incInt = literalContexts.get(2).getLiteral().getIntValue(); + } else { + incInt = 1; + } + Preconditions.checkState((endInt > startInt && incInt > 0) || (startInt > endInt + && incInt < 0), "Incorrect Step value."); + int size = (endInt - startInt) / incInt + 1; + _intArrayLiteral = new int[size]; + for (int i = 0, value = startInt; i < size; i++, value += incInt) { + _intArrayLiteral[i] = value; + } + _longArrayLiteral = null; + _floatArrayLiteral = null; + _doubleArrayLiteral = null; + break; + case LONG: + long startLong = Long.parseLong(literalContexts.get(0).getLiteral().getStringValue()); + long endLong = Long.parseLong(literalContexts.get(1).getLiteral().getStringValue()); + long incLong; + if (literalContexts.size() == 3) { + incLong = Long.parseLong(literalContexts.get(2).getLiteral().getStringValue()); + } else { + incLong = 1L; + } + Preconditions.checkState((endLong > startLong && incLong > 0) || (startLong > endLong + && incLong < 0), "Incorrect Step value."); + size = (int) ((endLong - startLong) / incLong + 1); + _longArrayLiteral = new long[size]; + for (int i = 0; i < size; i++, startLong += incLong) { + _longArrayLiteral[i] = startLong; + } + _intArrayLiteral = null; + _floatArrayLiteral = null; + _doubleArrayLiteral = null; + break; + case FLOAT: + float startFloat = Float.parseFloat(literalContexts.get(0).getLiteral().getStringValue()); + float endFloat = Float.parseFloat(literalContexts.get(1).getLiteral().getStringValue()); + float incFloat; + if (literalContexts.size() == 3) { + incFloat = Float.parseFloat(literalContexts.get(2).getLiteral().getStringValue()); + } else { + incFloat = 1; + } + Preconditions.checkState((endFloat > startFloat && incFloat > 0) || (startFloat > endFloat + && incFloat < 0), "Incorrect Step value."); + size = (int) ((endFloat - startFloat) / incFloat + 1); + _floatArrayLiteral = new float[size]; + for (int i = 0; i < size; i++, startFloat += incFloat) { + _floatArrayLiteral[i] = startFloat; + } + _intArrayLiteral = null; + _longArrayLiteral = null; + _doubleArrayLiteral = null; + break; + case DOUBLE: + double startDouble = Double.parseDouble(literalContexts.get(0).getLiteral().getStringValue()); + double endDouble = Double.parseDouble(literalContexts.get(1).getLiteral().getStringValue()); + double incDouble; + if (literalContexts.size() == 3) { + incDouble = Double.parseDouble(literalContexts.get(2).getLiteral().getStringValue()); + } else { + incDouble = 1.0; + } + Preconditions.checkState((endDouble > startDouble && incDouble > 0) || (startDouble > endDouble + && incDouble < 0), "Incorrect Step value."); + size = (int) ((endDouble - startDouble) / incDouble + 1); + _doubleArrayLiteral = new double[size]; + for (int i = 0; i < size; i++, startDouble += incDouble) { + _doubleArrayLiteral[i] = startDouble; + } + _intArrayLiteral = null; + _longArrayLiteral = null; + _floatArrayLiteral = null; + break; + default: + throw new IllegalStateException( + "Illegal data type for GenerateArrayTransformFunction: " + _dataType + ", literal contexts: " + + Arrays.toString(literalContexts.toArray())); + } + } + + public int[] getIntArrayLiteral() { + return _intArrayLiteral; + } + + public long[] getLongArrayLiteral() { + return _longArrayLiteral; + } + + public float[] getFloatArrayLiteral() { + return _floatArrayLiteral; + } + + public double[] getDoubleArrayLiteral() { + return _doubleArrayLiteral; + } + + @Override + public String getName() { + return FUNCTION_NAME; + } + + @Override + public void init(List arguments, Map columnContextMap) { + if (arguments.size() < 2) { + throw new IllegalArgumentException("At least 2 arguments are required for generateArray function"); + } + } + + @Override + public TransformResultMetadata getResultMetadata() { + return new TransformResultMetadata(_dataType, false, false); + } + + @Nullable + @Override + public Dictionary getDictionary() { + return null; + } + + @Override + public int[] transformToDictIdsSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public int[][] transformToDictIdsMV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public int[] transformToIntValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public long[] transformToLongValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public float[] transformToFloatValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public double[] transformToDoubleValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public BigDecimal[] transformToBigDecimalValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public String[] transformToStringValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[][] transformToBytesValuesSV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public int[][] transformToIntValuesMV(ValueBlock valueBlock) { + int numDocs = valueBlock.getNumDocs(); + int[][] intArrayResult = _intArrayResult; + if (intArrayResult == null || intArrayResult.length < numDocs) { + intArrayResult = new int[numDocs][]; + int[] intArrayLiteral = _intArrayLiteral; + if (intArrayLiteral == null) { + switch (_dataType) { + case LONG: + intArrayLiteral = new int[_longArrayLiteral.length]; + for (int i = 0; i < _longArrayLiteral.length; i++) { + intArrayLiteral[i] = (int) _longArrayLiteral[i]; + } + break; + case FLOAT: + intArrayLiteral = new int[_floatArrayLiteral.length]; + for (int i = 0; i < _floatArrayLiteral.length; i++) { + intArrayLiteral[i] = (int) _floatArrayLiteral[i]; + } + break; + case DOUBLE: + intArrayLiteral = new int[_doubleArrayLiteral.length]; + for (int i = 0; i < _doubleArrayLiteral.length; i++) { + intArrayLiteral[i] = (int) _doubleArrayLiteral[i]; + } + break; + default: + throw new IllegalStateException("Unable to convert data type: " + _dataType + " to in array"); + } + } + Arrays.fill(intArrayResult, intArrayLiteral); + _intArrayResult = intArrayResult; + } + return intArrayResult; + } + + @Override + public long[][] transformToLongValuesMV(ValueBlock valueBlock) { + int numDocs = valueBlock.getNumDocs(); + long[][] longArrayResult = _longArrayResult; + if (longArrayResult == null || longArrayResult.length < numDocs) { + longArrayResult = new long[numDocs][]; + long[] longArrayLiteral = _longArrayLiteral; + if (longArrayLiteral == null) { + switch (_dataType) { + case INT: + longArrayLiteral = new long[_intArrayLiteral.length]; + for (int i = 0; i < _intArrayLiteral.length; i++) { + longArrayLiteral[i] = _intArrayLiteral[i]; + } + break; + case FLOAT: + longArrayLiteral = new long[_floatArrayLiteral.length]; + for (int i = 0; i < _floatArrayLiteral.length; i++) { + longArrayLiteral[i] = (long) _floatArrayLiteral[i]; + } + break; + case DOUBLE: + longArrayLiteral = new long[_doubleArrayLiteral.length]; + for (int i = 0; i < _doubleArrayLiteral.length; i++) { + longArrayLiteral[i] = (long) _doubleArrayLiteral[i]; + } + break; + default: + throw new IllegalStateException("Unable to convert data type: " + _dataType + " to long array"); + } + } + Arrays.fill(longArrayResult, longArrayLiteral); + _longArrayResult = longArrayResult; + } + return longArrayResult; + } + + @Override + public float[][] transformToFloatValuesMV(ValueBlock valueBlock) { + int numDocs = valueBlock.getNumDocs(); + float[][] floatArrayResult = _floatArrayResult; + if (floatArrayResult == null || floatArrayResult.length < numDocs) { + floatArrayResult = new float[numDocs][]; + float[] floatArrayLiteral = _floatArrayLiteral; + if (floatArrayLiteral == null) { + switch (_dataType) { + case INT: + floatArrayLiteral = new float[_intArrayLiteral.length]; + for (int i = 0; i < _intArrayLiteral.length; i++) { + floatArrayLiteral[i] = _intArrayLiteral[i]; + } + break; + case LONG: + floatArrayLiteral = new float[_longArrayLiteral.length]; + for (int i = 0; i < _longArrayLiteral.length; i++) { + floatArrayLiteral[i] = _longArrayLiteral[i]; + } + break; + case DOUBLE: + floatArrayLiteral = new float[_doubleArrayLiteral.length]; + for (int i = 0; i < _doubleArrayLiteral.length; i++) { + floatArrayLiteral[i] = (float) _doubleArrayLiteral[i]; + } + break; + default: + throw new IllegalStateException("Unable to convert data type: " + _dataType + " to float array"); + } + } + Arrays.fill(floatArrayResult, floatArrayLiteral); + _floatArrayResult = floatArrayResult; + } + return floatArrayResult; + } + + @Override + public double[][] transformToDoubleValuesMV(ValueBlock valueBlock) { + int numDocs = valueBlock.getNumDocs(); + double[][] doubleArrayResult = _doubleArrayResult; + if (doubleArrayResult == null || doubleArrayResult.length < numDocs) { + doubleArrayResult = new double[numDocs][]; + double[] doubleArrayLiteral = _doubleArrayLiteral; + if (doubleArrayLiteral == null) { + switch (_dataType) { + case INT: + doubleArrayLiteral = new double[_intArrayLiteral.length]; + for (int i = 0; i < _intArrayLiteral.length; i++) { + doubleArrayLiteral[i] = _intArrayLiteral[i]; + } + break; + case LONG: + doubleArrayLiteral = new double[_longArrayLiteral.length]; + for (int i = 0; i < _longArrayLiteral.length; i++) { + doubleArrayLiteral[i] = _longArrayLiteral[i]; + } + break; + case FLOAT: + doubleArrayLiteral = new double[_floatArrayLiteral.length]; + for (int i = 0; i < _floatArrayLiteral.length; i++) { + doubleArrayLiteral[i] = _floatArrayLiteral[i]; + } + break; + default: + throw new IllegalStateException("Unable to convert data type: " + _dataType + " to double array"); + } + } + Arrays.fill(doubleArrayResult, doubleArrayLiteral); + _doubleArrayResult = doubleArrayResult; + } + return doubleArrayResult; + } + + @Override + public String[][] transformToStringValuesMV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[][][] transformToBytesValuesMV(ValueBlock valueBlock) { + throw new UnsupportedOperationException(); + } + + @Nullable + @Override + public RoaringBitmap getNullBitmap(ValueBlock block) { + // Treat all unknown type values as null regardless of the value. + if (_dataType != DataType.UNKNOWN) { + return null; + } + int length = block.getNumDocs(); + RoaringBitmap bitmap = new RoaringBitmap(); + bitmap.add(0L, length); + return bitmap; + } +} diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java index d5e4d9d481d6..c2a25ad4e25d 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java @@ -299,6 +299,13 @@ public static TransformFunction get(ExpressionContext expression, Map transformFunctionClass = TRANSFORM_FUNCTION_MAP.get(functionName); if (transformFunctionClass != null) { diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/GenerateArrayTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/GenerateArrayTransformFunctionTest.java new file mode 100644 index 000000000000..23a0fd4bb77d --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/GenerateArrayTransformFunctionTest.java @@ -0,0 +1,159 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.operator.transform.function; + +import java.util.ArrayList; +import java.util.List; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.core.operator.blocks.ProjectionBlock; +import org.apache.pinot.spi.data.FieldSpec.DataType; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.testng.Assert; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import static org.mockito.Mockito.when; + +public class GenerateArrayTransformFunctionTest { + private static final int NUM_DOCS = 100; + private AutoCloseable _mocks; + + @Mock + private ProjectionBlock _projectionBlock; + + @BeforeMethod + public void setUp() { + _mocks = MockitoAnnotations.openMocks(this); + when(_projectionBlock.getNumDocs()).thenReturn(NUM_DOCS); + } + + @AfterMethod + public void tearDown() + throws Exception { + _mocks.close(); + } + @Test + public void testGenerateIntArrayTransformFunction() { + List arrayExpressions = new ArrayList<>(); + int[] inputArray = {0, 10, 1}; + for (int j : inputArray) { + arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.INT, j)); + } + + GenerateArrayTransformFunction intArray = new GenerateArrayTransformFunction(arrayExpressions); + Assert.assertEquals(intArray.getResultMetadata().getDataType(), DataType.INT); + Assert.assertEquals(intArray.getIntArrayLiteral(), new int[]{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + }); + } + + @Test + public void testGenerateLongArrayTransformFunction() { + List arrayExpressions = new ArrayList<>(); + int[] inputArray = {0, 10, 1}; + for (int j : inputArray) { + arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.LONG, (long) j)); + } + + GenerateArrayTransformFunction longArray = new GenerateArrayTransformFunction(arrayExpressions); + Assert.assertEquals(longArray.getResultMetadata().getDataType(), DataType.LONG); + Assert.assertEquals(longArray.getLongArrayLiteral(), new long[]{ + 0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L + }); + } + + @Test + public void testGenerateFloatArrayTransformFunction() { + List arrayExpressions = new ArrayList<>(); + int[] inputArray = {0, 10, 1}; + for (int j : inputArray) { + arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.FLOAT, (float) j)); + } + + GenerateArrayTransformFunction floatArray = new GenerateArrayTransformFunction(arrayExpressions); + Assert.assertEquals(floatArray.getResultMetadata().getDataType(), DataType.FLOAT); + Assert.assertEquals(floatArray.getFloatArrayLiteral(), new float[]{ + 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f + }); + } + + @Test + public void testGenerateDoubleArrayTransformFunction() { + List arrayExpressions = new ArrayList<>(); + int[] inputArray = {0, 10, 1}; + for (int j : inputArray) { + arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.DOUBLE, (double) j)); + } + + GenerateArrayTransformFunction doubleArray = new GenerateArrayTransformFunction(arrayExpressions); + Assert.assertEquals(doubleArray.getResultMetadata().getDataType(), DataType.DOUBLE); + Assert.assertEquals(doubleArray.getDoubleArrayLiteral(), new double[]{ + 0d, 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d + }); + } + @Test + public void testGenerateEmptyArrayTransformFunction() { + List arrayExpressions = new ArrayList<>(); + GenerateArrayTransformFunction emptyLiteral = new GenerateArrayTransformFunction(arrayExpressions); + Assert.assertEquals(emptyLiteral.getIntArrayLiteral(), new int[0]); + Assert.assertEquals(emptyLiteral.getLongArrayLiteral(), new long[0]); + Assert.assertEquals(emptyLiteral.getFloatArrayLiteral(), new float[0]); + Assert.assertEquals(emptyLiteral.getDoubleArrayLiteral(), new double[0]); + + int[][] ints = emptyLiteral.transformToIntValuesMV(_projectionBlock); + Assert.assertEquals(ints.length, NUM_DOCS); + for (int i = 0; i < NUM_DOCS; i++) { + Assert.assertEquals(ints[i].length, 0); + } + + long[][] longs = emptyLiteral.transformToLongValuesMV(_projectionBlock); + Assert.assertEquals(longs.length, NUM_DOCS); + for (int i = 0; i < NUM_DOCS; i++) { + Assert.assertEquals(longs[i].length, 0); + } + + float[][] floats = emptyLiteral.transformToFloatValuesMV(_projectionBlock); + Assert.assertEquals(floats.length, NUM_DOCS); + for (int i = 0; i < NUM_DOCS; i++) { + Assert.assertEquals(floats[i].length, 0); + } + + double[][] doubles = emptyLiteral.transformToDoubleValuesMV(_projectionBlock); + Assert.assertEquals(doubles.length, NUM_DOCS); + for (int i = 0; i < NUM_DOCS; i++) { + Assert.assertEquals(doubles[i].length, 0); + } + } + @Test + public void testGenerateIntArrayTransformFunctionWithIncorrectStepValue() { + List arrayExpressions = new ArrayList<>(); + int[] inputArray = {0, 10, -1}; + for (int j : inputArray) { + arrayExpressions.add(ExpressionContext.forLiteralContext(DataType.INT, j)); + } + + try { + GenerateArrayTransformFunction intArray = new GenerateArrayTransformFunction(arrayExpressions); + Assert.fail(); + } catch (IllegalStateException ignored) { + } + } +} diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java index ceeefa28d295..b1e525deb407 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java @@ -525,6 +525,205 @@ public void testStringArrayLiteral(boolean useMultiStageQueryEngine) } } + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateIntArray(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(1, 3, 1) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + assertEquals(row.size(), 1); + assertEquals(row.get(0).size(), 3); + assertEquals(row.get(0).get(0).asInt(), 1); + assertEquals(row.get(0).get(1).asInt(), 2); + assertEquals(row.get(0).get(2).asInt(), 3); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateIntArrayWithoutStepValue(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(1, 3) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + assertEquals(row.size(), 1); + assertEquals(row.get(0).size(), 3); + assertEquals(row.get(0).get(0).asInt(), 1); + assertEquals(row.get(0).get(1).asInt(), 2); + assertEquals(row.get(0).get(2).asInt(), 3); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateIntArrayWithIncorrectStepValue(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(1, 3, -1) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + assertEquals(jsonNode.get("exceptions").size(), 1); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateLongArray(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(2147483648, 2147483650, 2) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + assertEquals(row.size(), 1); + assertEquals(row.get(0).size(), 2); + assertEquals(row.get(0).get(0).asLong(), 2147483648L); + assertEquals(row.get(0).get(1).asLong(), 2147483650L); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateLongArrayWithoutStepValue(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(2147483648, 2147483650) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + assertEquals(row.size(), 1); + assertEquals(row.get(0).size(), 3); + assertEquals(row.get(0).get(0).asLong(), 2147483648L); + assertEquals(row.get(0).get(1).asLong(), 2147483649L); + assertEquals(row.get(0).get(2).asLong(), 2147483650L); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateLongArrayWithIncorrectStepValue(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(2147483648, 2147483650, -1) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + assertEquals(jsonNode.get("exceptions").size(), 1); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateFloatArray(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(0.1, 0.3, 0.1) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + assertEquals(row.size(), 1); + assertEquals(row.get(0).size(), 3); + assertEquals(row.get(0).get(0).asDouble(), 0.1); + assertEquals(row.get(0).get(1).asDouble(), 0.1 + 0.1 * 1); + assertEquals(row.get(0).get(2).asDouble(), 0.1 + 0.1 * 2); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateFloatArrayWithoutStepValue(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(0.3, 3.1) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + assertEquals(row.size(), 1); + assertEquals(row.get(0).size(), 3); + assertEquals(row.get(0).get(0).asDouble(), 0.3); + assertEquals(row.get(0).get(1).asDouble(), 1.3); + assertEquals(row.get(0).get(2).asDouble(), 2.3); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateFloatArrayWithIncorrectStepValue(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(0.3, 0.1, 1.1) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + assertEquals(jsonNode.get("exceptions").size(), 1); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateDoubleArray(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(CAST(0.1 AS DOUBLE), CAST(0.3 AS DOUBLE), CAST(0.1 AS DOUBLE)) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + assertEquals(row.size(), 1); + assertEquals(row.get(0).size(), 3); + assertEquals(row.get(0).get(0).asDouble(), 0.1); + assertEquals(row.get(0).get(1).asDouble(), 0.1 + 0.1 * 1); + assertEquals(row.get(0).get(2).asDouble(), 0.1 + 0.1 * 2); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateDoubleArrayWithoutStepValue(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(CAST(0.3 AS DOUBLE), CAST(3.1 AS DOUBLE)) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + JsonNode rows = jsonNode.get("resultTable").get("rows"); + assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + assertEquals(row.size(), 1); + assertEquals(row.get(0).size(), 3); + assertEquals(row.get(0).get(0).asDouble(), 0.3); + assertEquals(row.get(0).get(1).asDouble(), 1.3); + assertEquals(row.get(0).get(2).asDouble(), 2.3); + } + + @Test(dataProvider = "useV1QueryEngine") + public void testGenerateDoubleArrayWithIncorrectStepValue(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = + String.format("SELECT " + + "GENERATE_ARRAY(CAST(0.3 AS DOUBLE), CAST(0.1 AS DOUBLE), CAST(1.1 AS DOUBLE)) " + + "FROM %s LIMIT 1", getTableName()); + JsonNode jsonNode = postQuery(query); + assertEquals(jsonNode.get("exceptions").size(), 1); + } + @Override public String getTableName() { return DEFAULT_TABLE_NAME;