diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CQLSessionCache.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CQLSessionCache.java index acc9d9b53..23f362238 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CQLSessionCache.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/CQLSessionCache.java @@ -14,6 +14,7 @@ import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1; import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaCache; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector.SubtypeOnlyFloatVectorToArrayCodec; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; import java.net.InetSocketAddress; @@ -148,6 +149,9 @@ private CqlSession getNewSession(SessionCacheKey cacheKey) { builder.addContactPoints(seeds); } + // Add optimized CqlVector codec (see [data-api#1775]) + builder = builder.addTypeCodecs(SubtypeOnlyFloatVectorToArrayCodec.instance()); + // aaron - this used to have an if / else that threw an exception if the database type was not // known but we test that when creating the credentials for the cache key so no need to do it // here. 
diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/optvector/SubtypeOnlyFloatVectorToArrayCodec.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/optvector/SubtypeOnlyFloatVectorToArrayCodec.java new file mode 100644 index 000000000..be054ebcb --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/optvector/SubtypeOnlyFloatVectorToArrayCodec.java @@ -0,0 +1,165 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector; + +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.type.DataType; +import com.datastax.oss.driver.api.core.type.DataTypes; +import com.datastax.oss.driver.api.core.type.VectorType; +import com.datastax.oss.driver.api.core.type.codec.TypeCodec; +import com.datastax.oss.driver.api.core.type.reflect.GenericType; +import com.datastax.oss.driver.internal.core.type.codec.FloatCodec; +import com.datastax.oss.driver.shaded.guava.common.base.Splitter; +import com.datastax.oss.driver.shaded.guava.common.collect.Iterators; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Objects; + +/** + * Implementation of {@link TypeCodec} which translates CQL vectors into float arrays. Difference + * between this and {@link + * com.datastax.oss.driver.internal.core.type.codec.extras.vector.FloatVectorToArrayCodec} is that + * we don't concern ourselves with the dimensionality specified in the input CQL type. This codec + * just reads all the bytes, tries to deserialize them consecutively into subtypes and then returns + * the result. Serialization is similar: we take the input array, serialize each element and return + * the result. 
+ */ +public class SubtypeOnlyFloatVectorToArrayCodec implements TypeCodec { + + private static final int ELEMENT_SIZE = 4; + + protected final VectorType cqlType; + protected final GenericType javaType; + + private final FloatCodec floatCodec = new FloatCodec(); + + private static final SubtypeOnlyFloatVectorToArrayCodec INSTANCE = + new SubtypeOnlyFloatVectorToArrayCodec(DataTypes.FLOAT); + + private SubtypeOnlyFloatVectorToArrayCodec(DataType subType) { + cqlType = new SubtypeOnlyVectorType(Objects.requireNonNull(subType, "subType cannot be null")); + javaType = GenericType.of(float[].class); + } + + public static TypeCodec instance() { + return INSTANCE; + } + + @Override + public GenericType getJavaType() { + return javaType; + } + + @Override + public DataType getCqlType() { + return cqlType; + } + + @Override + public boolean accepts(Class javaClass) { + return float[].class.equals(javaClass); + } + + @Override + public boolean accepts(Object value) { + return value instanceof float[]; + } + + @Override + public boolean accepts(DataType value) { + if (!(value instanceof VectorType)) { + return false; + } + VectorType valueVectorType = (VectorType) value; + return this.cqlType.getElementType().equals(valueVectorType.getElementType()); + } + + @Override + public ByteBuffer encode(float[] array, ProtocolVersion protocolVersion) { + if (array == null) { + return null; + } + int length = array.length; + int totalSize = length * ELEMENT_SIZE; + ByteBuffer output = ByteBuffer.allocate(totalSize); + for (int i = 0; i < length; i++) { + serializeElement(output, array, i, protocolVersion); + } + output.flip(); + return output; + } + + @Override + public float[] decode(ByteBuffer bytes, ProtocolVersion protocolVersion) { + if (bytes == null || bytes.remaining() == 0) { + throw new IllegalArgumentException( + "Input ByteBuffer must not be null and must have non-zero remaining bytes"); + } + // TODO: Do we want to treat this as an error? 
We could also just ignore any extraneous bytes + // if they appear. + if (bytes.remaining() % ELEMENT_SIZE != 0) { + throw new IllegalArgumentException( + String.format("Input ByteBuffer should have a multiple of %d bytes", ELEMENT_SIZE)); + } + ByteBuffer input = bytes.duplicate(); + int elementCount = input.remaining() / ELEMENT_SIZE; + float[] array = new float[elementCount]; + for (int i = 0; i < elementCount; i++) { + deserializeElement(input, array, i, protocolVersion); + } + return array; + } + + /** + * Write the {@code index}th element of {@code array} to {@code output}. + * + * @param output The ByteBuffer to write to. + * @param array The array to read from. + * @param index The element index. + * @param protocolVersion The protocol version to use. + */ + protected void serializeElement( + ByteBuffer output, float[] array, int index, ProtocolVersion protocolVersion) { + output.putFloat(array[index]); + } + + /** + * Read the {@code index}th element of {@code array} from {@code input}. + * + * @param input The ByteBuffer to read from. + * @param array The array to write to. + * @param index The element index. + * @param protocolVersion The protocol version to use. + */ + protected void deserializeElement( + ByteBuffer input, float[] array, int index, ProtocolVersion protocolVersion) { + array[index] = input.getFloat(); + } + + @Override + public String format(float[] value) { + return value == null ? "NULL" : Arrays.toString(value); + } + + @Override + public float[] parse(String str) { + /* TODO: Logic below requires a double traversal through the input String but there's no other obvious way to + * get the size. It's still probably worth the initial pass through in order to avoid having to deal with + * resizing ops. Fortunately we're only dealing with the format/parse pair here so this shouldn't impact + * general performance much. 
*/ + if ((str == null) || str.isEmpty()) { + throw new IllegalArgumentException("Cannot create float array from null or empty string"); + } + Iterable strIterable = + Splitter.on(", ").trimResults().split(str.substring(1, str.length() - 1)); + float[] rv = new float[Iterators.size(strIterable.iterator())]; + Iterator strIterator = strIterable.iterator(); + for (int i = 0; i < rv.length; ++i) { + String strVal = strIterator.next(); + if (strVal == null || strVal.isBlank()) { + throw new IllegalArgumentException("Null element observed in float array string"); + } + rv[i] = floatCodec.parse(strVal).floatValue(); + } + return rv; + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/optvector/SubtypeOnlyVectorType.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/optvector/SubtypeOnlyVectorType.java new file mode 100644 index 000000000..8a40f3c28 --- /dev/null +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/optvector/SubtypeOnlyVectorType.java @@ -0,0 +1,54 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector; + +import com.datastax.oss.driver.api.core.detach.AttachmentPoint; +import com.datastax.oss.driver.api.core.type.DataType; +import com.datastax.oss.driver.api.core.type.VectorType; +import com.datastax.oss.driver.internal.core.type.DefaultVectorType; +import java.util.Objects; + +/** + * An implementation of {@link VectorType} which is only concerned with the subtype of the vector. + * Useful if you want to describe a class of vector types that do not differ by subtype but do differ + * by dimension. 
+ */ +public class SubtypeOnlyVectorType extends DefaultVectorType { + private static final int NO_DIMENSION = -1; + + public SubtypeOnlyVectorType(DataType subtype) { + super(subtype, NO_DIMENSION); + } + + @Override + public int getDimensions() { + throw new UnsupportedOperationException("Subtype-only vectors do not support dimensions"); + } + + /* ============== General class implementation ============== */ + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + return (o instanceof VectorType that) && that.getElementType().equals(getElementType()); + } + + @Override + public int hashCode() { + return super.hashCode() ^ Objects.hashCode(getElementType()); + } + + @Override + public String toString() { + return String.format("(Subtype-only) Vector(%s)", getElementType()); + } + + @Override + public boolean isDetached() { + return false; + } + + @Override + public void attach(AttachmentPoint attachmentPoint) { + // nothing to do + } +} diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistry.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistry.java index ec39e2911..444565b07 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistry.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistry.java @@ -142,10 +142,10 @@ public JSONCodec codecToCQL( throw new ToCQLCodecException(value, columnType, "only Vector supported"); } if (value instanceof Collection) { - return VectorCodecs.arrayToCQLFloatVectorCodec(vt); + return VectorCodecs.arrayToCQLFloatArrayCodec(vt); } if (value instanceof EJSONWrapper) { - return VectorCodecs.binaryToCQLFloatVectorCodec(vt); + return VectorCodecs.binaryToCQLFloatArrayCodec(vt); } throw new ToCQLCodecException(value, columnType, "no codec matching value type"); diff --git 
a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/VectorCodecs.java b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/VectorCodecs.java index ed57d4472..2dbf1d9ee 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/VectorCodecs.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/VectorCodecs.java @@ -10,7 +10,6 @@ import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.JsonLiteral; import io.stargate.sgv2.jsonapi.exception.checked.ToCQLCodecException; import io.stargate.sgv2.jsonapi.util.CqlVectorUtil; -import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -22,7 +21,7 @@ public abstract class VectorCodecs { private static final GenericType> FLOAT_LIST_TYPE = GenericType.listOf(Float.class); private static final GenericType EJSON_TYPE = GenericType.of(EJSONWrapper.class); - public static JSONCodec arrayToCQLFloatVectorCodec( + public static JSONCodec arrayToCQLFloatArrayCodec( VectorType vectorType) { // Unfortunately we cannot simply construct and return a single Codec instance here // because ApiVectorType's dimensions vary, and we need to know the expected dimensions @@ -31,18 +30,18 @@ public static JSONCodec arrayToCQLFloatVectorCodec( new JSONCodec<>( FLOAT_LIST_TYPE, vectorType, - (cqlType, value) -> listToCQLFloatVector(vectorType, value), + (cqlType, value) -> listToCQLFloatArray(vectorType, value), // This codec only for to-cql case, not to-json, so we don't need this null); } - public static JSONCodec binaryToCQLFloatVectorCodec( + public static JSONCodec binaryToCQLFloatArrayCodec( VectorType vectorType) { return (JSONCodec) new JSONCodec<>( EJSON_TYPE, vectorType, - (cqlType, value) -> binaryToCQLFloatVector(vectorType, value), + (cqlType, value) -> binaryToCQLFloatArray(vectorType, value), null); } @@ -53,20 +52,24 @@ public static JSONCodec toJSONFloatVectorCodec(Vector vectorType, 
// This codec only for to-json case, not to-cql, so we don't need this null, - (objectMapper, cqlType, value) -> toJsonNode(objectMapper, (CqlVector) value)); + (objectMapper, cqlType, value) -> toJsonNode(objectMapper, value)); } - /** Method for actual conversion from JSON Number Array into CQL Float Vector. */ - static CqlVector listToCQLFloatVector(VectorType vectorType, Collection listValue) + /** + * Method for actual conversion from JSON Number Array into float array for Codec to use as Vector + * value. + */ + static float[] listToCQLFloatArray(VectorType vectorType, Collection listValue) throws ToCQLCodecException { Collection> vectorIn = (Collection>) listValue; validateVectorLength(vectorType, vectorIn, vectorIn.size()); - List floats = new ArrayList<>(vectorIn.size()); + float[] floats = new float[vectorIn.size()]; + int ix = 0; for (JsonLiteral literalElement : vectorIn) { Object element = literalElement.value(); if (element instanceof Number num) { - floats.add(num.floatValue()); + floats[ix++] = num.floatValue(); continue; } throw new ToCQLCodecException( @@ -74,28 +77,29 @@ static CqlVector listToCQLFloatVector(VectorType vectorType, Collection binaryToCQLFloatVector(VectorType vectorType, EJSONWrapper binaryValue) + static float[] binaryToCQLFloatArray(VectorType vectorType, EJSONWrapper binaryValue) throws ToCQLCodecException { byte[] binary = JSONCodec.ToCQL.byteArrayFromEJSON(vectorType, binaryValue); - CqlVector vector; + float[] floats; try { - vector = CqlVectorUtil.bytesToCqlVector(binary); + floats = CqlVectorUtil.bytesToFloats(binary); } catch (IllegalArgumentException e) { throw new ToCQLCodecException( binaryValue, vectorType, String.format("failed to decode Base64-encoded packed Vector value: %s", e.getMessage())); } - validateVectorLength(vectorType, binaryValue, vector.size()); - return vector; + validateVectorLength(vectorType, binaryValue, floats.length); + return floats; } private static void validateVectorLength(VectorType 
vectorType, Object value, int actualLen) @@ -110,8 +114,22 @@ private static void validateVectorLength(VectorType vectorType, Object value, in } } - static JsonNode toJsonNode(ObjectMapper objectMapper, CqlVector vectorValue) { - final ArrayNode result = objectMapper.createArrayNode(); + static JsonNode toJsonNode(ObjectMapper objectMapper, Object vectorValue) { + // 18-Dec-2024, tatu: [data-api#1775] Support for more efficient but still + // allow old binding to work too; test type here, use appropriate logic + if (vectorValue instanceof float[] floats) { + return toJsonNodeFromFloats(objectMapper, floats); + } + if (vectorValue instanceof CqlVector vector) { + return toJsonNodeFromCqlVector(objectMapper, (CqlVector) vector); + } + throw new IllegalArgumentException( + "Unrecognized type for CQL Vector value: " + vectorValue.getClass().getCanonicalName()); + } + + static JsonNode toJsonNodeFromCqlVector( + ObjectMapper objectMapper, CqlVector vectorValue) { + final ArrayNode result = objectMapper.getNodeFactory().arrayNode(vectorValue.size()); for (Number element : vectorValue) { if (element == null) { // is this even legal? 
result.addNull(); @@ -121,4 +139,13 @@ static JsonNode toJsonNode(ObjectMapper objectMapper, CqlVector vectorVa } return result; } + + static JsonNode toJsonNodeFromFloats(ObjectMapper objectMapper, float[] vectorValue) { + // For now, output still as array of floats; in future maybe as Base64-encoded packed binary + final ArrayNode result = objectMapper.getNodeFactory().arrayNode(vectorValue.length); + for (float f : vectorValue) { + result.add(f); + } + return result; + } } diff --git a/src/main/java/io/stargate/sgv2/jsonapi/util/CqlVectorUtil.java b/src/main/java/io/stargate/sgv2/jsonapi/util/CqlVectorUtil.java index 9b97f6166..3cd3c4a81 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/util/CqlVectorUtil.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/util/CqlVectorUtil.java @@ -21,17 +21,6 @@ * */ public interface CqlVectorUtil { - /** - * Method for converting binary-packed representation of a CQL {@code float} vector into a {@link - * CqlVector} instance. - * - * @param packedBytes binary-packed representation of the vector - * @return {@link CqlVector} instance representing the vector - */ - static CqlVector bytesToCqlVector(byte[] packedBytes) { - return floatsToCqlVector(bytesToFloats(packedBytes)); - } - /** * Method for converting binary-packed representation of a CQL {@code float} vector into a raw * {@code float[]} array. 
diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/optvector/CodecTestBase.java b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/optvector/CodecTestBase.java new file mode 100644 index 000000000..c6de44fb5 --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/optvector/CodecTestBase.java @@ -0,0 +1,44 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.optvector; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.type.codec.TypeCodec; +import com.datastax.oss.protocol.internal.util.Bytes; +import java.nio.ByteBuffer; + +public class CodecTestBase { + protected TypeCodec codec; + + protected String encode(T t, ProtocolVersion protocolVersion) { + assertThat(codec).as("Must set codec before calling this method").isNotNull(); + ByteBuffer bytes = codec.encode(t, protocolVersion); + return (bytes == null) ? null : Bytes.toHexString(bytes); + } + + protected String encode(T t) { + return encode(t, ProtocolVersion.DEFAULT); + } + + protected T decode(String hexString, ProtocolVersion protocolVersion) { + assertThat(codec).as("Must set codec before calling this method").isNotNull(); + ByteBuffer bytes = (hexString == null) ? 
null : Bytes.fromHexString(hexString); + // Decode twice, to assert that decode leaves the input buffer in its original state + codec.decode(bytes, protocolVersion); + return codec.decode(bytes, protocolVersion); + } + + protected T decode(String hexString) { + return decode(hexString, ProtocolVersion.DEFAULT); + } + + protected String format(T t) { + assertThat(codec).as("Must set codec before calling this method").isNotNull(); + return codec.format(t); + } + + protected T parse(String s) { + assertThat(codec).as("Must set codec before calling this method").isNotNull(); + return codec.parse(s); + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/optvector/SubtypeOnlyFloatVectorTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/optvector/SubtypeOnlyFloatVectorTest.java new file mode 100644 index 000000000..ea319864e --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/optvector/SubtypeOnlyFloatVectorTest.java @@ -0,0 +1,76 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.optvector; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.type.DataTypes; +import com.datastax.oss.driver.api.core.type.codec.TypeCodec; +import com.datastax.oss.driver.api.core.type.codec.registry.MutableCodecRegistry; +import com.datastax.oss.driver.internal.core.type.codec.registry.DefaultCodecRegistry; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector.SubtypeOnlyFloatVectorToArrayCodec; +import java.util.Random; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Test; + +/** + * Test of the full suite of "subtype only" functionality. Goal here is to confirm two distinct + * questions: + * + *

* If we use the "subtype only" type with a {@link DefaultCodecRegistry} do we get the same + * codec regardless of vector dimension? * Can we use the codec we get back from the default codec + * registry to encode and decode vectors of different sizes? + * + *

Note that all of this works only because of an implementation detail in DefaultCodecRegistry. + * The use of Objects.equals() in the code referenced below means that we effectively use the + * equals() method of the DataType impl to determine whether keys in the codec cache match. We + * leverage this behaviour to make SubtypeOnlyVectorType match all vectors with an equivalent + * subtype. This behaviour is thus not guaranteed for other codec registry impls. + * + *

Codec + * registry code + */ +public class SubtypeOnlyFloatVectorTest { + + @Test + public void shouldFindSubtypeOnlyCodecRegardlessOfSize() { + + MutableCodecRegistry registry = new DefaultCodecRegistry("subtype_only"); + registry.register(SubtypeOnlyFloatVectorToArrayCodec.instance()); + + AtomicReference> codecRef = new AtomicReference<>(); + for (int i = 1; i <= 2000; ++i) { + + TypeCodec codec = registry.codecFor(DataTypes.vectorOf(DataTypes.FLOAT, i)); + codecRef.compareAndSet(null, codec); + assertThat(codec).isInstanceOf(SubtypeOnlyFloatVectorToArrayCodec.class); + assertThat(codec).isEqualTo(codecRef.get()); + } + } + + @Test + public void shouldEncodeAndDecodeVectorsOfArbitrarySize() { + + MutableCodecRegistry registry = new DefaultCodecRegistry("subtype_only"); + registry.register(SubtypeOnlyFloatVectorToArrayCodec.instance()); + + for (int i = 1; i <= 2000; ++i) { + + TypeCodec codec = registry.codecFor(DataTypes.vectorOf(DataTypes.FLOAT, i)); + float[] comparison = randomFloatArray(i); + float[] result = + codec.decode(codec.encode(comparison, ProtocolVersion.V4), ProtocolVersion.V4); + assertThat(result).isEqualTo(comparison); + } + } + + private float[] randomFloatArray(int size) { + // Use fixed seed for reproducibility + Random random = new Random(size); + float[] rv = new float[size]; + for (int i = 0; i < size; ++i) { + rv[i] = random.nextFloat(); + } + return rv; + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/optvector/SubtypeOnlyFloatVectorToArrayCodecTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/optvector/SubtypeOnlyFloatVectorToArrayCodecTest.java new file mode 100644 index 000000000..a376d6079 --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/optvector/SubtypeOnlyFloatVectorToArrayCodecTest.java @@ -0,0 +1,101 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.optvector; + +import static org.assertj.core.api.Assertions.assertThat; +import static 
org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.datastax.oss.driver.api.core.type.DataTypes; +import com.datastax.oss.driver.api.core.type.reflect.GenericType; +import com.datastax.oss.driver.internal.core.type.DefaultVectorType; +import io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector.SubtypeOnlyFloatVectorToArrayCodec; +import org.junit.Test; + +/** + * Basic sanity checks to make sure {@link SubtypeOnlyFloatVectorToArrayCodec} is a well-behaved + * type codec + */ +public class SubtypeOnlyFloatVectorToArrayCodecTest extends CodecTestBase { + + private static final float[] VECTOR = new float[] {1.0f, 2.5f}; + + private static final String VECTOR_HEX_STRING = "0x" + "3f800000" + "40200000"; + + private static final String FORMATTED_VECTOR = "[1.0, 2.5]"; + + public SubtypeOnlyFloatVectorToArrayCodecTest() { + codec = SubtypeOnlyFloatVectorToArrayCodec.instance(); + } + + @Test + public void shouldEncode() { + assertThat(encode(VECTOR)).isEqualTo(VECTOR_HEX_STRING); + assertThat(encode(null)).isNull(); + } + + @Test + public void shouldDecode() { + assertThat(decode(VECTOR_HEX_STRING)).isEqualTo(VECTOR); + assertThatThrownBy(() -> decode("0x")).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> decode(null)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void shouldThrowOnDecodeIfTooFewBytes() { + // Dropping 4 bytes would knock off exactly 1 float, anything less than that would be something + // we couldn't parse a float out of + for (int i = 1; i <= 3; ++i) { + // 2 chars of hex encoded string = 1 byte + int lastIndex = VECTOR_HEX_STRING.length() - (2 * i); + assertThatThrownBy(() -> decode(VECTOR_HEX_STRING.substring(0, lastIndex))) + .isInstanceOf(IllegalArgumentException.class); + } + } + + @Test + public void shouldFormat() { + assertThat(format(VECTOR)).isEqualTo(FORMATTED_VECTOR); + assertThat(format(null)).isEqualTo("NULL"); + } + + @Test + public void shouldParse() { + 
assertThat(parse(FORMATTED_VECTOR)).isEqualTo(VECTOR); + assertThatThrownBy(() -> parse("NULL")).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> parse("null")).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> parse("")).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> parse(null)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void shouldAcceptDataType() { + assertThat(codec.accepts(DataTypes.vectorOf(DataTypes.FLOAT, 2))).isTrue(); + assertThat(codec.accepts(DataTypes.INT)).isFalse(); + } + + @Test + public void shouldAcceptVectorTypeAllDimensionOnly() { + for (int i = 0; i < 1000; ++i) { + assertThat(codec.accepts(new DefaultVectorType(DataTypes.FLOAT, i))).isTrue(); + } + } + + @Test + public void shouldAcceptGenericType() { + assertThat(codec.accepts(GenericType.of(float[].class))).isTrue(); + assertThat(codec.accepts(GenericType.arrayOf(Float.class))).isFalse(); + assertThat(codec.accepts(GenericType.arrayOf(Integer.class))).isFalse(); + assertThat(codec.accepts(GenericType.of(Float.class))).isFalse(); + assertThat(codec.accepts(GenericType.of(Integer.class))).isFalse(); + } + + @Test + public void shouldAcceptRawType() { + assertThat(codec.accepts(float[].class)).isTrue(); + assertThat(codec.accepts(Integer.class)).isFalse(); + } + + @Test + public void shouldAcceptObject() { + assertThat(codec.accepts(VECTOR)).isTrue(); + assertThat(codec.accepts(Integer.MIN_VALUE)).isFalse(); + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistryTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistryTest.java index 33e53e755..40763c6b9 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistryTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/operation/filters/table/codecs/JSONCodecRegistryTest.java @@ 
-477,14 +477,10 @@ private static Stream validCodecToCQLTestCasesVectors() { numberLiteral(0L), numberLiteral(new BigDecimal(-0.5)), numberLiteral(new BigDecimal(0.25))), - CqlVectorUtil.floatsToCqlVector(rawFloats3)), + rawFloats3), // Second: Base64-encoded representation (Base64 of 4-byte "packed" float values) - Arguments.of( - vector3Type, binaryWrapper(packedFloats3), CqlVectorUtil.floatsToCqlVector(rawFloats3)), - Arguments.of( - vector4Type, - binaryWrapper(packedFloats4), - CqlVectorUtil.floatsToCqlVector(rawFloats4))); + Arguments.of(vector3Type, binaryWrapper(packedFloats3), rawFloats3), + Arguments.of(vector4Type, binaryWrapper(packedFloats4), rawFloats4)); } private static JsonLiteral numberLiteral(Number value) {