Skip to content

Commit

Permalink
Fixing array literal usage for vector
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangfu0 committed Feb 5, 2024
1 parent 3cefed5 commit e8e5436
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ public static boolean jsonMatch(String text, String pattern) {
}

@ScalarFunction(names = {"vectorSimilarity", "vector_similarity"}, isPlaceholder = true)
public static double vectorSimilarity(float[] vector1, float[] vector2) {
public static boolean vectorSimilarity(float[] vector1, float[] vector2, int topk) {
throw new UnsupportedOperationException("Placeholder scalar function, should not reach here");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ public enum TransformFunctionType {
VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE),
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_norm"),

VECTOR_SIMILARITY("vectorSimilarity", ReturnTypes.BOOLEAN_NOT_NULL,
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC),
ordinal -> ordinal > 1 && ordinal < 4), "vector_similarity"),

ARRAY_VALUE_CONSTRUCTOR("arrayValueConstructor", "array_value_constructor"),

// Trigonometry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,37 @@ private static float[] getVectorValue(ExpressionContext expressionContext) {
}

private static float[] getVectorValue(Expression thriftExpression) {
if (thriftExpression.getType() == ExpressionType.LITERAL) {
Literal literalExpression = thriftExpression.getLiteral();
if (literalExpression.isSetIntArrayValue()) {
float[] vector = new float[literalExpression.getIntArrayValue().size()];
for (int i = 0; i < literalExpression.getIntArrayValue().size(); i++) {
vector[i] = literalExpression.getIntArrayValue().get(i).floatValue();
}
return vector;
}
if (literalExpression.isSetLongArrayValue()) {
float[] vector = new float[literalExpression.getLongArrayValue().size()];
for (int i = 0; i < literalExpression.getLongArrayValue().size(); i++) {
vector[i] = literalExpression.getLongArrayValue().get(i).floatValue();
}
return vector;
}
if (literalExpression.isSetFloatArrayValue()) {
float[] vector = new float[literalExpression.getFloatArrayValue().size()];
for (int i = 0; i < literalExpression.getFloatArrayValue().size(); i++) {
vector[i] = literalExpression.getFloatArrayValue().get(i);
}
return vector;
}
if (literalExpression.isSetDoubleArrayValue()) {
float[] vector = new float[literalExpression.getDoubleArrayValue().size()];
for (int i = 0; i < literalExpression.getDoubleArrayValue().size(); i++) {
vector[i] = literalExpression.getDoubleArrayValue().get(i).floatValue();
}
return vector;
}
}
if (thriftExpression.getType() != ExpressionType.FUNCTION) {
throw new BadQueryRequestException(
"Pinot does not support column or function on the right-hand side of the predicate");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,17 +296,25 @@ private static void validateFilter(Expression filterExpression) {
+ "the signature is VECTOR_SIMILARITY(float[], float[], int).");
}
Expression vectorLiteral = filterExpression.getFunctionCall().getOperands().get(1);
// Array Literal is a function of type 'ARRAYVALUECONSTRUCTOR' with operands of Float/Double Literals
if (!vectorLiteral.isSetFunctionCall() || !vectorLiteral.getFunctionCall().getOperator().equalsIgnoreCase(
"arrayvalueconstructor")) {
throw new IllegalStateException("The second argument of VECTOR_SIMILARITY must be a float array literal, "
+ "the signature is VECTOR_SIMILARITY(float[], float[], int).");
/*
* Array Literal could be either:
* 1. a function of type 'ARRAYVALUECONSTRUCTOR' with operands of float/double
* 2. a float/double array literals
* Also check in
* {@link org.apache.pinot.sql.parsers.rewriter.PredicateComparisonRewriter#updateFunctionExpression(Expression)}
*/
if ((vectorLiteral.isSetFunctionCall() && !vectorLiteral.getFunctionCall().getOperator().equalsIgnoreCase(
"arrayvalueconstructor"))
|| (vectorLiteral.isSetLiteral() && !vectorLiteral.getLiteral().isSetFloatArrayValue()
&& !vectorLiteral.getLiteral().isSetDoubleArrayValue())) {
throw new IllegalStateException("The second argument of VECTOR_SIMILARITY must be a float/double array "
+ "literal, the signature is VECTOR_SIMILARITY(float[], float[], int)");
}
if (filterExpression.getFunctionCall().getOperands().size() == 3) {
Expression topK = filterExpression.getFunctionCall().getOperands().get(2);
if (!topK.isSetLiteral()) {
throw new IllegalStateException("The third argument of VECTOR_SIMILARITY must be an integer literal, "
+ "the signature is VECTOR_SIMILARITY(float[], float[], int).");
+ "the signature is VECTOR_SIMILARITY(float[], float[], int)");
}
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,19 @@ private static Expression updateFunctionExpression(Expression expression) {
case VECTOR_SIMILARITY: {
Preconditions.checkArgument(operands.size() >= 2 && operands.size() <= 3,
"For %s predicate, the number of operands must be at either 2 or 3, got: %s", filterKind, expression);
// Array Literal is a function of type 'ARRAYVALUECONSTRUCTOR' with operands of Float/Double Literals
if (operands.get(1).getFunctionCall() == null || !operands.get(1).getFunctionCall().getOperator()
.equalsIgnoreCase("arrayvalueconstructor")) {
/*
* Array Literal could be either:
* 1. a function of type 'ARRAYVALUECONSTRUCTOR' with operands of float/double
* 2. a float/double array literals
* Also check in {@link org.apache.pinot.sql.parsers.CalciteSqlParser#validateFilter(Expression)}}
*/
if ((operands.get(1).getFunctionCall() != null && !operands.get(1).getFunctionCall().getOperator()
.equalsIgnoreCase("arrayvalueconstructor"))
|| (operands.get(1).getLiteral() != null && !operands.get(1).getLiteral().isSetFloatArrayValue()
&& !operands.get(1).getLiteral().isSetDoubleArrayValue())) {
throw new SqlCompilationException(
String.format("For %s predicate, the second operand must be a float array literal, got: %s", filterKind,
String.format("For %s predicate, the second operand must be a float/double array literal, got: %s",
filterKind,
expression));
}
if (operands.size() == 3 && operands.get(2).getLiteral() == null) {
Expand Down Expand Up @@ -165,7 +173,7 @@ private static Expression updateFunctionExpression(Expression expression) {
/**
* Rewrite predicates to boolean expressions with EQUALS operator
* Example1: "select * from table where col1" converts to
* "select * from table where col1 = true"
* "select * from table where col1 = true"
* Example2: "select * from table where startsWith(col1, 'str')" converts to
* "select * from table where startsWith(col1, 'str') = true"
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,21 @@
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.commons.lang3.RandomUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.function.scalar.VectorFunctions;
import org.apache.pinot.spi.config.table.FieldConfig;
import org.apache.pinot.spi.config.table.TableConfig;
import org.apache.pinot.spi.config.table.TableType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;
Expand Down Expand Up @@ -172,11 +178,62 @@ public void testQueriesWithLiterals(boolean useMultiStageQueryEngine)
assertEquals(l2Distance, 22.627416997969522);
}

@Test(dataProvider = "useBothQueryEngines")
public void testVectorSimilarity(boolean useMultiStageQueryEngine)
throws Exception {
setUseMultiStageQueryEngine(useMultiStageQueryEngine);
int topK = 5;
String oneVectorStringLiteral = "ARRAY[1.1"
+ StringUtils.repeat(", 1.1", VECTOR_DIM_SIZE - 1)
+ "]";
String query1 =
String.format("SELECT "
+ "cosineDistance(%s, %s) AS dist "
+ "FROM %s "
+ "WHERE vectorSimilarity(%s, %s, %d) "
+ "ORDER BY dist ASC "
+ "LIMIT %d",
VECTOR_1, oneVectorStringLiteral, getTableName(), VECTOR_1, oneVectorStringLiteral, topK * 10, topK);
String query2 =
String.format("SELECT "
+ "cosineDistance(%s, %s) as dist "
+ "FROM %s "
+ "ORDER BY dist ASC "
+ "LIMIT %d",
VECTOR_1, oneVectorStringLiteral, getTableName(), topK);

JsonNode jsonNode1 = postQuery(query1);
JsonNode jsonNode2 = postQuery(query2);
for (int i = 0; i < topK; i++) {
double dist1 = jsonNode1.get("resultTable").get("rows").get(i).get(0).asDouble();
double dist2 = jsonNode2.get("resultTable").get("rows").get(i).get(0).asDouble();
assertEquals(dist1, dist2);
}
}

@Override
public String getTableName() {
return DEFAULT_TABLE_NAME;
}

@Override
public TableConfig createOfflineTableConfig() {
return new TableConfigBuilder(TableType.OFFLINE)
.setTableName(getTableName())
.setFieldConfigList(List.of(
new FieldConfig.Builder(VECTOR_1)
.withIndexTypes(List.of(FieldConfig.IndexType.VECTOR))
.withEncodingType(FieldConfig.EncodingType.RAW)
.withProperties(Map.of(
"vectorIndexType", "HNSW",
"vectorDimension", String.valueOf(VECTOR_DIM_SIZE),
"vectorDistanceFunction", "COSINE",
"version", "1"))
.build()
))
.build();
}

@Override
public Schema createSchema() {
return new Schema.SchemaBuilder().setSchemaName(getTableName())
Expand Down

0 comments on commit e8e5436

Please sign in to comment.