Skip to content

Commit

Permalink
Fix #1775: use optimized CqlVector<Float> codec to improve performance (
Browse files Browse the repository at this point in the history
  • Loading branch information
tatu-at-datastax authored Jan 3, 2025
1 parent 54d0f87 commit 902f9d0
Show file tree
Hide file tree
Showing 10 changed files with 496 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io.stargate.sgv2.jsonapi.config.OperationsConfig;
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaCache;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector.SubtypeOnlyFloatVectorToArrayCodec;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -148,6 +149,9 @@ private CqlSession getNewSession(SessionCacheKey cacheKey) {
builder.addContactPoints(seeds);
}

// Add optimized CqlVector codec (see [data-api#1775])
builder = builder.addTypeCodecs(SubtypeOnlyFloatVectorToArrayCodec.instance());

// aaron - this used to have an if / else that threw an exception if the database type was not
// known but we test that when creating the credentials for the cache key so no need to do it
// here.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector;

import com.datastax.oss.driver.api.core.ProtocolVersion;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.internal.core.type.codec.FloatCodec;
import com.datastax.oss.driver.shaded.guava.common.base.Splitter;
import com.datastax.oss.driver.shaded.guava.common.collect.Iterators;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;

/**
* Implementation of {@link TypeCodec} which translates CQL vectors into float arrays. Difference
* between this and {@link
* com.datastax.oss.driver.internal.core.type.codec.extras.vector.FloatVectorToArrayCodec} is that
* we don't concern ourselves with the dimensionality specified in the input CQL type. This codec
* just reads all the bytes, tries to deserislize them consecutively into subtypes and then returns
* the result. Serialiation is similar: we take the input array, serialize each element and return
* the result.
*/
public class SubtypeOnlyFloatVectorToArrayCodec implements TypeCodec<float[]> {

private static final int ELEMENT_SIZE = 4;

protected final VectorType cqlType;
protected final GenericType<float[]> javaType;

private final FloatCodec floatCodec = new FloatCodec();

private static final SubtypeOnlyFloatVectorToArrayCodec INSTANCE =
new SubtypeOnlyFloatVectorToArrayCodec(DataTypes.FLOAT);

private SubtypeOnlyFloatVectorToArrayCodec(DataType subType) {
cqlType = new SubtypeOnlyVectorType(Objects.requireNonNull(subType, "subType cannot be null"));
javaType = GenericType.of(float[].class);
}

public static TypeCodec<float[]> instance() {
return INSTANCE;
}

@Override
public GenericType<float[]> getJavaType() {
return javaType;
}

@Override
public DataType getCqlType() {
return cqlType;
}

@Override
public boolean accepts(Class<?> javaClass) {
return float[].class.equals(javaClass);
}

@Override
public boolean accepts(Object value) {
return value instanceof float[];
}

@Override
public boolean accepts(DataType value) {
if (!(value instanceof VectorType)) {
return false;
}
VectorType valueVectorType = (VectorType) value;
return this.cqlType.getElementType().equals(valueVectorType.getElementType());
}

@Override
public ByteBuffer encode(float[] array, ProtocolVersion protocolVersion) {
if (array == null) {
return null;
}
int length = array.length;
int totalSize = length * ELEMENT_SIZE;
ByteBuffer output = ByteBuffer.allocate(totalSize);
for (int i = 0; i < length; i++) {
serializeElement(output, array, i, protocolVersion);
}
output.flip();
return output;
}

@Override
public float[] decode(ByteBuffer bytes, ProtocolVersion protocolVersion) {
if (bytes == null || bytes.remaining() == 0) {
throw new IllegalArgumentException(
"Input ByteBuffer must not be null and must have non-zero remaining bytes");
}
// TODO: Do we want to treat this as an error? We could also just ignore any extraneous bytes
// if they appear.
if (bytes.remaining() % ELEMENT_SIZE != 0) {
throw new IllegalArgumentException(
String.format("Input ByteBuffer should have a multiple of %d bytes", ELEMENT_SIZE));
}
ByteBuffer input = bytes.duplicate();
int elementCount = input.remaining() / 4;
float[] array = new float[elementCount];
for (int i = 0; i < elementCount; i++) {
deserializeElement(input, array, i, protocolVersion);
}
return array;
}

/**
* Write the {@code index}th element of {@code array} to {@code output}.
*
* @param output The ByteBuffer to write to.
* @param array The array to read from.
* @param index The element index.
* @param protocolVersion The protocol version to use.
*/
protected void serializeElement(
ByteBuffer output, float[] array, int index, ProtocolVersion protocolVersion) {
output.putFloat(array[index]);
}

/**
* Read the {@code index}th element of {@code array} from {@code input}.
*
* @param input The ByteBuffer to read from.
* @param array The array to write to.
* @param index The element index.
* @param protocolVersion The protocol version to use.
*/
protected void deserializeElement(
ByteBuffer input, float[] array, int index, ProtocolVersion protocolVersion) {
array[index] = input.getFloat();
}

@Override
public String format(float[] value) {
return value == null ? "NULL" : Arrays.toString(value);
}

@Override
public float[] parse(String str) {
/* TODO: Logic below requires a double traversal through the input String but there's no other obvious way to
* get the size. It's still probably worth the initial pass through in order to avoid having to deal with
* resizing ops. Fortunately we're only dealing with the format/parse pair here so this shouldn't impact
* general performance much. */
if ((str == null) || str.isEmpty()) {
throw new IllegalArgumentException("Cannot create float array from null or empty string");
}
Iterable<String> strIterable =
Splitter.on(", ").trimResults().split(str.substring(1, str.length() - 1));
float[] rv = new float[Iterators.size(strIterable.iterator())];
Iterator<String> strIterator = strIterable.iterator();
for (int i = 0; i < rv.length; ++i) {
String strVal = strIterator.next();
if (strVal == null || strVal.isBlank()) {
throw new IllegalArgumentException("Null element observed in float array string");
}
rv[i] = floatCodec.parse(strVal).floatValue();
}
return rv;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector;

import com.datastax.oss.driver.api.core.detach.AttachmentPoint;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.datastax.oss.driver.internal.core.type.DefaultVectorType;
import java.util.Objects;

/**
* An implementation of {@link VectorType} which is only concerned with the subtype of the vector.
* Useful if you want to describe a call of vector types that do not differ by subtype but do differ
* by dimension.
*/
public class SubtypeOnlyVectorType extends DefaultVectorType {
private static final int NO_DIMENSION = -1;

public SubtypeOnlyVectorType(DataType subtype) {
super(subtype, NO_DIMENSION);
}

@Override
public int getDimensions() {
throw new UnsupportedOperationException("Subtype-only vectors do not support dimensions");
}

/* ============== General class implementation ============== */
@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
return (o instanceof VectorType that) && that.getElementType().equals(getElementType());
}

@Override
public int hashCode() {
return super.hashCode() ^ Objects.hashCode(getElementType());
}

@Override
public String toString() {
return String.format("(Subtype-only) Vector(%s)", getElementType());
}

@Override
public boolean isDetached() {
return false;
}

@Override
public void attach(AttachmentPoint attachmentPoint) {
// nothing to do
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ public <JavaT, CqlT> JSONCodec<JavaT, CqlT> codecToCQL(
throw new ToCQLCodecException(value, columnType, "only Vector<Float> supported");
}
if (value instanceof Collection<?>) {
return VectorCodecs.arrayToCQLFloatVectorCodec(vt);
return VectorCodecs.arrayToCQLFloatArrayCodec(vt);
}
if (value instanceof EJSONWrapper) {
return VectorCodecs.binaryToCQLFloatVectorCodec(vt);
return VectorCodecs.binaryToCQLFloatArrayCodec(vt);
}

throw new ToCQLCodecException(value, columnType, "no codec matching value type");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.JsonLiteral;
import io.stargate.sgv2.jsonapi.exception.checked.ToCQLCodecException;
import io.stargate.sgv2.jsonapi.util.CqlVectorUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

Expand All @@ -22,7 +21,7 @@ public abstract class VectorCodecs {
private static final GenericType<List<Float>> FLOAT_LIST_TYPE = GenericType.listOf(Float.class);
private static final GenericType<EJSONWrapper> EJSON_TYPE = GenericType.of(EJSONWrapper.class);

public static <JavaT, CqlT> JSONCodec<JavaT, CqlT> arrayToCQLFloatVectorCodec(
public static <JavaT, CqlT> JSONCodec<JavaT, CqlT> arrayToCQLFloatArrayCodec(
VectorType vectorType) {
// Unfortunately we cannot simply construct and return a single Codec instance here
// because ApiVectorType's dimensions vary, and we need to know the expected dimensions
Expand All @@ -31,18 +30,18 @@ public static <JavaT, CqlT> JSONCodec<JavaT, CqlT> arrayToCQLFloatVectorCodec(
new JSONCodec<>(
FLOAT_LIST_TYPE,
vectorType,
(cqlType, value) -> listToCQLFloatVector(vectorType, value),
(cqlType, value) -> listToCQLFloatArray(vectorType, value),
// This codec only for to-cql case, not to-json, so we don't need this
null);
}

public static <JavaT, CqlT> JSONCodec<JavaT, CqlT> binaryToCQLFloatVectorCodec(
public static <JavaT, CqlT> JSONCodec<JavaT, CqlT> binaryToCQLFloatArrayCodec(
VectorType vectorType) {
return (JSONCodec<JavaT, CqlT>)
new JSONCodec<>(
EJSON_TYPE,
vectorType,
(cqlType, value) -> binaryToCQLFloatVector(vectorType, value),
(cqlType, value) -> binaryToCQLFloatArray(vectorType, value),
null);
}

Expand All @@ -53,49 +52,54 @@ public static <JavaT, CqlT> JSONCodec<JavaT, CqlT> toJSONFloatVectorCodec(Vector
vectorType,
// This codec only for to-json case, not to-cql, so we don't need this
null,
(objectMapper, cqlType, value) -> toJsonNode(objectMapper, (CqlVector<Number>) value));
(objectMapper, cqlType, value) -> toJsonNode(objectMapper, value));
}

/** Method for actual conversion from JSON Number Array into CQL Float Vector. */
static CqlVector<Float> listToCQLFloatVector(VectorType vectorType, Collection<?> listValue)
/**
* Method for actual conversion from JSON Number Array into float array for Codec to use as Vector
* value.
*/
static float[] listToCQLFloatArray(VectorType vectorType, Collection<?> listValue)
throws ToCQLCodecException {
Collection<JsonLiteral<?>> vectorIn = (Collection<JsonLiteral<?>>) listValue;
validateVectorLength(vectorType, vectorIn, vectorIn.size());

List<Float> floats = new ArrayList<>(vectorIn.size());
float[] floats = new float[vectorIn.size()];
int ix = 0;
for (JsonLiteral<?> literalElement : vectorIn) {
Object element = literalElement.value();
if (element instanceof Number num) {
floats.add(num.floatValue());
floats[ix++] = num.floatValue();
continue;
}
throw new ToCQLCodecException(
vectorIn,
vectorType,
String.format(
"expected JSON Number value as Vector element at position #%d (of %d), instead have: %s",
floats.size(), vectorIn.size(), literalElement));
ix, vectorIn.size(), literalElement));
}
return CqlVector.newInstance(floats);
return floats;
}

/**
* Method for actual conversion from EJSON-wrapped Base64-encoded String into CQL Float Vector.
* Method for actual conversion from EJSON-wrapped Base64-encoded String into float array for
* Codec to use as Vector value.
*/
static CqlVector<Float> binaryToCQLFloatVector(VectorType vectorType, EJSONWrapper binaryValue)
static float[] binaryToCQLFloatArray(VectorType vectorType, EJSONWrapper binaryValue)
throws ToCQLCodecException {
byte[] binary = JSONCodec.ToCQL.byteArrayFromEJSON(vectorType, binaryValue);
CqlVector<Float> vector;
float[] floats;
try {
vector = CqlVectorUtil.bytesToCqlVector(binary);
floats = CqlVectorUtil.bytesToFloats(binary);
} catch (IllegalArgumentException e) {
throw new ToCQLCodecException(
binaryValue,
vectorType,
String.format("failed to decode Base64-encoded packed Vector value: %s", e.getMessage()));
}
validateVectorLength(vectorType, binaryValue, vector.size());
return vector;
validateVectorLength(vectorType, binaryValue, floats.length);
return floats;
}

private static void validateVectorLength(VectorType vectorType, Object value, int actualLen)
Expand All @@ -110,8 +114,22 @@ private static void validateVectorLength(VectorType vectorType, Object value, in
}
}

static JsonNode toJsonNode(ObjectMapper objectMapper, CqlVector<Number> vectorValue) {
final ArrayNode result = objectMapper.createArrayNode();
static JsonNode toJsonNode(ObjectMapper objectMapper, Object vectorValue) {
// 18-Dec-2024, tatu: [data-api#1775] Support for more efficient but still
// allow old binding to work too; test type here, use appropriate logic
if (vectorValue instanceof float[] floats) {
return toJsonNodeFromFloats(objectMapper, floats);
}
if (vectorValue instanceof CqlVector<?> vector) {
return toJsonNodeFromCqlVector(objectMapper, (CqlVector<Number>) vector);
}
throw new IllegalArgumentException(
"Unrecognized type for CQL Vector value: " + vectorValue.getClass().getCanonicalName());
}

static JsonNode toJsonNodeFromCqlVector(
ObjectMapper objectMapper, CqlVector<Number> vectorValue) {
final ArrayNode result = objectMapper.getNodeFactory().arrayNode(vectorValue.size());
for (Number element : vectorValue) {
if (element == null) { // is this even legal?
result.addNull();
Expand All @@ -121,4 +139,13 @@ static JsonNode toJsonNode(ObjectMapper objectMapper, CqlVector<Number> vectorVa
}
return result;
}

static JsonNode toJsonNodeFromFloats(ObjectMapper objectMapper, float[] vectorValue) {
// For now, output still as array of floats; in future maybe as Base64-encoded packed binary
final ArrayNode result = objectMapper.getNodeFactory().arrayNode(vectorValue.length);
for (float f : vectorValue) {
result.add(f);
}
return result;
}
}
Loading

0 comments on commit 902f9d0

Please sign in to comment.