From e8e54368c7adbdb947fadbf45ecd6aed25841185 Mon Sep 17 00:00:00 2001 From: Xiang Fu Date: Sun, 4 Feb 2024 15:11:42 -0800 Subject: [PATCH] Fixing array literal usage for vector --- .../common/function/FunctionRegistry.java | 2 +- .../function/TransformFunctionType.java | 4 ++ .../request/context/RequestContextUtils.java | 31 ++++++++++ .../pinot/sql/parsers/CalciteSqlParser.java | 20 +++++-- .../rewriter/PredicateComparisonRewriter.java | 18 ++++-- .../integration/tests/custom/VectorTest.java | 57 +++++++++++++++++++ 6 files changed, 120 insertions(+), 12 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index deb1673d8bac..00df9498ddac 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -195,7 +195,7 @@ public static boolean jsonMatch(String text, String pattern) { } @ScalarFunction(names = {"vectorSimilarity", "vector_similarity"}, isPlaceholder = true) - public static double vectorSimilarity(float[] vector1, float[] vector2) { + public static boolean vectorSimilarity(float[] vector1, float[] vector2, int topk) { throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java index 7753260192bb..20bc26854cfa 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java @@ -273,6 +273,10 @@ public enum TransformFunctionType { VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE), OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_norm"), + VECTOR_SIMILARITY("vectorSimilarity", ReturnTypes.BOOLEAN_NOT_NULL, + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC), + ordinal -> ordinal > 1 && ordinal < 4), "vector_similarity"), + ARRAY_VALUE_CONSTRUCTOR("arrayValueConstructor", "array_value_constructor"), // Trigonometry diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java index 28f3037b252e..958a20da685d 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/RequestContextUtils.java @@ -470,6 +470,37 @@ private static float[] getVectorValue(ExpressionContext expressionContext) { } private static float[] getVectorValue(Expression thriftExpression) { + if (thriftExpression.getType() == ExpressionType.LITERAL) { + Literal literalExpression = thriftExpression.getLiteral(); + if (literalExpression.isSetIntArrayValue()) { + float[] vector = new float[literalExpression.getIntArrayValue().size()]; + for (int i = 0; i < literalExpression.getIntArrayValue().size(); i++) { + vector[i] = literalExpression.getIntArrayValue().get(i).floatValue(); + } + return vector; + } + if (literalExpression.isSetLongArrayValue()) { + float[] vector = new float[literalExpression.getLongArrayValue().size()]; + for (int i = 0; i < literalExpression.getLongArrayValue().size(); i++) { + vector[i] = literalExpression.getLongArrayValue().get(i).floatValue(); + } + return vector; + } + if (literalExpression.isSetFloatArrayValue()) { + float[] vector = new float[literalExpression.getFloatArrayValue().size()]; + for (int i = 0; i < literalExpression.getFloatArrayValue().size(); i++) { + vector[i] = literalExpression.getFloatArrayValue().get(i); + } + return vector; + } + if (literalExpression.isSetDoubleArrayValue()) { + float[] vector = new float[literalExpression.getDoubleArrayValue().size()]; + for (int i = 0; i < literalExpression.getDoubleArrayValue().size(); i++) { + vector[i] = literalExpression.getDoubleArrayValue().get(i).floatValue(); + } + return vector; + } + } if (thriftExpression.getType() != ExpressionType.FUNCTION) { throw new BadQueryRequestException( "Pinot does not support column or function on the right-hand side of the predicate"); diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java index 63058fd8c333..3d216bd64364 100644 --- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/CalciteSqlParser.java @@ -296,17 +296,25 @@ private static void validateFilter(Expression filterExpression) { + "the signature is VECTOR_SIMILARITY(float[], float[], int)."); } Expression vectorLiteral = filterExpression.getFunctionCall().getOperands().get(1); - // Array Literal is a function of type 'ARRAYVALUECONSTRUCTOR' with operands of Float/Double Literals - if (!vectorLiteral.isSetFunctionCall() || !vectorLiteral.getFunctionCall().getOperator().equalsIgnoreCase( - "arrayvalueconstructor")) { - throw new IllegalStateException("The second argument of VECTOR_SIMILARITY must be a float array literal, " - + "the signature is VECTOR_SIMILARITY(float[], float[], int)."); + /* + * Array Literal could be either: + * 1. a function of type 'ARRAYVALUECONSTRUCTOR' with operands of float/double + * 2. a float/double array literals + * Also check in + * {@link org.apache.pinot.sql.parsers.rewriter.PredicateComparisonRewriter#updateFunctionExpression(Expression)} + */ + if ((vectorLiteral.isSetFunctionCall() && !vectorLiteral.getFunctionCall().getOperator().equalsIgnoreCase( + "arrayvalueconstructor")) + || (vectorLiteral.isSetLiteral() && !vectorLiteral.getLiteral().isSetFloatArrayValue() + && !vectorLiteral.getLiteral().isSetDoubleArrayValue())) { + throw new IllegalStateException("The second argument of VECTOR_SIMILARITY must be a float/double array " + + "literal, the signature is VECTOR_SIMILARITY(float[], float[], int)"); } if (filterExpression.getFunctionCall().getOperands().size() == 3) { Expression topK = filterExpression.getFunctionCall().getOperands().get(2); if (!topK.isSetLiteral()) { throw new IllegalStateException("The third argument of VECTOR_SIMILARITY must be an integer literal, " - + "the signature is VECTOR_SIMILARITY(float[], float[], int)."); + + "the signature is VECTOR_SIMILARITY(float[], float[], int)"); } } } else { diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java index 1917e37abc5e..c59b5126ec4a 100644 --- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/PredicateComparisonRewriter.java @@ -132,11 +132,19 @@ private static Expression updateFunctionExpression(Expression expression) { case VECTOR_SIMILARITY: { Preconditions.checkArgument(operands.size() >= 2 && operands.size() <= 3, "For %s predicate, the number of operands must be at either 2 or 3, got: %s", filterKind, expression); - // Array Literal is a function of type 'ARRAYVALUECONSTRUCTOR' with operands of Float/Double Literals - if (operands.get(1).getFunctionCall() == null || !operands.get(1).getFunctionCall().getOperator() - .equalsIgnoreCase("arrayvalueconstructor")) { + /* + * Array Literal could be either: + * 1. a function of type 'ARRAYVALUECONSTRUCTOR' with operands of float/double + * 2. a float/double array literals + * Also check in {@link org.apache.pinot.sql.parsers.CalciteSqlParser#validateFilter(Expression)}} + */ + if ((operands.get(1).getFunctionCall() != null && !operands.get(1).getFunctionCall().getOperator() + .equalsIgnoreCase("arrayvalueconstructor")) + || (operands.get(1).getLiteral() != null && !operands.get(1).getLiteral().isSetFloatArrayValue() + && !operands.get(1).getLiteral().isSetDoubleArrayValue())) { throw new SqlCompilationException( - String.format("For %s predicate, the second operand must be a float array literal, got: %s", filterKind, + String.format("For %s predicate, the second operand must be a float/double array literal, got: %s", + filterKind, expression)); } if (operands.size() == 3 && operands.get(2).getLiteral() == null) { @@ -165,7 +173,7 @@ private static Expression updateFunctionExpression(Expression expression) { /** * Rewrite predicates to boolean expressions with EQUALS operator * Example1: "select * from table where col1" converts to - * "select * from table where col1 = true" + * "select * from table where col1 = true" * Example2: "select * from table where startsWith(col1, 'str')" converts to * "select * from table where startsWith(col1, 'str') = true" * diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java index 245cd31d05ae..24902150816f 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/VectorTest.java @@ -23,6 +23,8 @@ import java.io.File; import java.util.ArrayList; import java.util.Collection; +import java.util.List; +import java.util.Map; import java.util.stream.IntStream; import org.apache.avro.file.DataFileWriter; import org.apache.avro.generic.GenericData; @@ -30,8 +32,12 @@ import org.apache.commons.lang3.RandomUtils; import org.apache.commons.lang3.StringUtils; import org.apache.pinot.common.function.scalar.VectorFunctions; +import org.apache.pinot.spi.config.table.FieldConfig; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.config.table.TableType; import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.utils.builder.TableConfigBuilder; import org.testng.annotations.Test; import static org.testng.Assert.assertEquals; @@ -172,11 +178,62 @@ public void testQueriesWithLiterals(boolean useMultiStageQueryEngine) assertEquals(l2Distance, 22.627416997969522); } + @Test(dataProvider = "useBothQueryEngines") + public void testVectorSimilarity(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + int topK = 5; + String oneVectorStringLiteral = "ARRAY[1.1" + + StringUtils.repeat(", 1.1", VECTOR_DIM_SIZE - 1) + + "]"; + String query1 = + String.format("SELECT " + + "cosineDistance(%s, %s) AS dist " + + "FROM %s " + + "WHERE vectorSimilarity(%s, %s, %d) " + + "ORDER BY dist ASC " + + "LIMIT %d", + VECTOR_1, oneVectorStringLiteral, getTableName(), VECTOR_1, oneVectorStringLiteral, topK * 10, topK); + String query2 = + String.format("SELECT " + + "cosineDistance(%s, %s) as dist " + + "FROM %s " + + "ORDER BY dist ASC " + + "LIMIT %d", + VECTOR_1, oneVectorStringLiteral, getTableName(), topK); + + JsonNode jsonNode1 = postQuery(query1); + JsonNode jsonNode2 = postQuery(query2); + for (int i = 0; i < topK; i++) { + double dist1 = jsonNode1.get("resultTable").get("rows").get(i).get(0).asDouble(); + double dist2 = jsonNode2.get("resultTable").get("rows").get(i).get(0).asDouble(); + assertEquals(dist1, dist2); + } + } + @Override public String getTableName() { return DEFAULT_TABLE_NAME; } + @Override + public TableConfig createOfflineTableConfig() { + return new TableConfigBuilder(TableType.OFFLINE) + .setTableName(getTableName()) + .setFieldConfigList(List.of( + new FieldConfig.Builder(VECTOR_1) + .withIndexTypes(List.of(FieldConfig.IndexType.VECTOR)) + .withEncodingType(FieldConfig.EncodingType.RAW) + .withProperties(Map.of( + "vectorIndexType", "HNSW", + "vectorDimension", String.valueOf(VECTOR_DIM_SIZE), + "vectorDistanceFunction", "COSINE", + "version", "1")) + .build() + )) + .build(); + } + @Override public Schema createSchema() { return new Schema.SchemaBuilder().setSchemaName(getTableName())