Skip to content

Commit

Permalink
Merge branch 'main' into hazel/ann_limit_validation
Browse files Browse the repository at this point in the history
  • Loading branch information
tatu-at-datastax authored Jan 3, 2025
2 parents 5e13d28 + 902f9d0 commit 6b7e257
Show file tree
Hide file tree
Showing 20 changed files with 545 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
@JsonSubTypes.Type(value = InsertManyCommand.class),
@JsonSubTypes.Type(value = UpdateManyCommand.class),
@JsonSubTypes.Type(value = UpdateOneCommand.class),
// We have only collection resource that is used for api tables
// We have only collection resource that is used for API Tables
@JsonSubTypes.Type(value = AlterTableCommand.class),
@JsonSubTypes.Type(value = CreateIndexCommand.class),
@JsonSubTypes.Type(value = CreateVectorIndexCommand.class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import org.eclipse.microprofile.openapi.annotations.enums.SchemaType;
import org.eclipse.microprofile.openapi.annotations.media.Schema;

@Schema(description = "Command that creates an api table.")
@Schema(description = "Command that creates an API Table.")
@JsonTypeName(CommandName.Names.CREATE_TABLE)
public record CreateTableCommand(
@NotNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ public ProjectionException(ErrorInstance errorInstance) {
}

public enum Code implements ErrorCode<ProjectionException> {
UNSUPPORTED_COLUMN_TYPES;
UNSUPPORTED_COLUMN_TYPES,
UNKNOWN_TABLE_COLUMNS;

private final ErrorTemplate<ProjectionException> template;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io.stargate.sgv2.jsonapi.config.OperationsConfig;
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaCache;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector.SubtypeOnlyFloatVectorToArrayCodec;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -148,6 +149,9 @@ private CqlSession getNewSession(SessionCacheKey cacheKey) {
builder.addContactPoints(seeds);
}

// Add optimized CqlVector codec (see [data-api#1775])
builder = builder.addTypeCodecs(SubtypeOnlyFloatVectorToArrayCodec.instance());

// aaron - this used to have an if / else that threw an exception if the database type was not
// known but we test that when creating the credentials for the cache key so no need to do it
// here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ private Uni<SchemaObject> loadSchemaObject(
optionalTable.orElseThrow(
() -> ErrorCodeV1.COLLECTION_NOT_EXIST.toApiException("%s", collectionName));

// check if its a valid json api table
// check if its a valid json API Table
// TODO: re-use the table matcher this is on the request hot path
if (new CollectionTableMatcher().test(table)) {
return CollectionSchemaObject.getCollectionSettings(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector;

import com.datastax.oss.driver.api.core.ProtocolVersion;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.internal.core.type.codec.FloatCodec;
import com.datastax.oss.driver.shaded.guava.common.base.Splitter;
import com.datastax.oss.driver.shaded.guava.common.collect.Iterators;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;

/**
* Implementation of {@link TypeCodec} which translates CQL vectors into float arrays. Difference
* between this and {@link
* com.datastax.oss.driver.internal.core.type.codec.extras.vector.FloatVectorToArrayCodec} is that
* we don't concern ourselves with the dimensionality specified in the input CQL type. This codec
* just reads all the bytes, tries to deserislize them consecutively into subtypes and then returns
* the result. Serialiation is similar: we take the input array, serialize each element and return
* the result.
*/
public class SubtypeOnlyFloatVectorToArrayCodec implements TypeCodec<float[]> {

private static final int ELEMENT_SIZE = 4;

protected final VectorType cqlType;
protected final GenericType<float[]> javaType;

private final FloatCodec floatCodec = new FloatCodec();

private static final SubtypeOnlyFloatVectorToArrayCodec INSTANCE =
new SubtypeOnlyFloatVectorToArrayCodec(DataTypes.FLOAT);

private SubtypeOnlyFloatVectorToArrayCodec(DataType subType) {
cqlType = new SubtypeOnlyVectorType(Objects.requireNonNull(subType, "subType cannot be null"));
javaType = GenericType.of(float[].class);
}

public static TypeCodec<float[]> instance() {
return INSTANCE;
}

@Override
public GenericType<float[]> getJavaType() {
return javaType;
}

@Override
public DataType getCqlType() {
return cqlType;
}

@Override
public boolean accepts(Class<?> javaClass) {
return float[].class.equals(javaClass);
}

@Override
public boolean accepts(Object value) {
return value instanceof float[];
}

@Override
public boolean accepts(DataType value) {
if (!(value instanceof VectorType)) {
return false;
}
VectorType valueVectorType = (VectorType) value;
return this.cqlType.getElementType().equals(valueVectorType.getElementType());
}

@Override
public ByteBuffer encode(float[] array, ProtocolVersion protocolVersion) {
if (array == null) {
return null;
}
int length = array.length;
int totalSize = length * ELEMENT_SIZE;
ByteBuffer output = ByteBuffer.allocate(totalSize);
for (int i = 0; i < length; i++) {
serializeElement(output, array, i, protocolVersion);
}
output.flip();
return output;
}

@Override
public float[] decode(ByteBuffer bytes, ProtocolVersion protocolVersion) {
if (bytes == null || bytes.remaining() == 0) {
throw new IllegalArgumentException(
"Input ByteBuffer must not be null and must have non-zero remaining bytes");
}
// TODO: Do we want to treat this as an error? We could also just ignore any extraneous bytes
// if they appear.
if (bytes.remaining() % ELEMENT_SIZE != 0) {
throw new IllegalArgumentException(
String.format("Input ByteBuffer should have a multiple of %d bytes", ELEMENT_SIZE));
}
ByteBuffer input = bytes.duplicate();
int elementCount = input.remaining() / 4;
float[] array = new float[elementCount];
for (int i = 0; i < elementCount; i++) {
deserializeElement(input, array, i, protocolVersion);
}
return array;
}

/**
* Write the {@code index}th element of {@code array} to {@code output}.
*
* @param output The ByteBuffer to write to.
* @param array The array to read from.
* @param index The element index.
* @param protocolVersion The protocol version to use.
*/
protected void serializeElement(
ByteBuffer output, float[] array, int index, ProtocolVersion protocolVersion) {
output.putFloat(array[index]);
}

/**
* Read the {@code index}th element of {@code array} from {@code input}.
*
* @param input The ByteBuffer to read from.
* @param array The array to write to.
* @param index The element index.
* @param protocolVersion The protocol version to use.
*/
protected void deserializeElement(
ByteBuffer input, float[] array, int index, ProtocolVersion protocolVersion) {
array[index] = input.getFloat();
}

@Override
public String format(float[] value) {
return value == null ? "NULL" : Arrays.toString(value);
}

@Override
public float[] parse(String str) {
/* TODO: Logic below requires a double traversal through the input String but there's no other obvious way to
* get the size. It's still probably worth the initial pass through in order to avoid having to deal with
* resizing ops. Fortunately we're only dealing with the format/parse pair here so this shouldn't impact
* general performance much. */
if ((str == null) || str.isEmpty()) {
throw new IllegalArgumentException("Cannot create float array from null or empty string");
}
Iterable<String> strIterable =
Splitter.on(", ").trimResults().split(str.substring(1, str.length() - 1));
float[] rv = new float[Iterators.size(strIterable.iterator())];
Iterator<String> strIterator = strIterable.iterator();
for (int i = 0; i < rv.length; ++i) {
String strVal = strIterator.next();
if (strVal == null || strVal.isBlank()) {
throw new IllegalArgumentException("Null element observed in float array string");
}
rv[i] = floatCodec.parse(strVal).floatValue();
}
return rv;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor.optvector;

import com.datastax.oss.driver.api.core.detach.AttachmentPoint;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.datastax.oss.driver.internal.core.type.DefaultVectorType;
import java.util.Objects;

/**
* An implementation of {@link VectorType} which is only concerned with the subtype of the vector.
* Useful if you want to describe a call of vector types that do not differ by subtype but do differ
* by dimension.
*/
public class SubtypeOnlyVectorType extends DefaultVectorType {
private static final int NO_DIMENSION = -1;

public SubtypeOnlyVectorType(DataType subtype) {
super(subtype, NO_DIMENSION);
}

@Override
public int getDimensions() {
throw new UnsupportedOperationException("Subtype-only vectors do not support dimensions");
}

/* ============== General class implementation ============== */
@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
return (o instanceof VectorType that) && that.getElementType().equals(getElementType());
}

@Override
public int hashCode() {
return super.hashCode() ^ Objects.hashCode(getElementType());
}

@Override
public String toString() {
return String.format("(Subtype-only) Vector(%s)", getElementType());
}

@Override
public boolean isDetached() {
return false;
}

@Override
public void attach(AttachmentPoint attachmentPoint) {
// nothing to do
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ public <JavaT, CqlT> JSONCodec<JavaT, CqlT> codecToCQL(
throw new ToCQLCodecException(value, columnType, "only Vector<Float> supported");
}
if (value instanceof Collection<?>) {
return VectorCodecs.arrayToCQLFloatVectorCodec(vt);
return VectorCodecs.arrayToCQLFloatArrayCodec(vt);
}
if (value instanceof EJSONWrapper) {
return VectorCodecs.binaryToCQLFloatVectorCodec(vt);
return VectorCodecs.binaryToCQLFloatArrayCodec(vt);
}

throw new ToCQLCodecException(value, columnType, "no codec matching value type");
Expand Down
Loading

0 comments on commit 6b7e257

Please sign in to comment.