Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support AWS Bedrock Embedding Provider #1219

Merged
merged 19 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bom</artifactId>
<version>2.26.12</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
Expand Down Expand Up @@ -134,6 +141,23 @@
<groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bedrockruntime</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sts</artifactId>
</dependency>
<dependency>
<groupId>com.bpodgursky</groupId>
<artifactId>jbool_expressions</artifactId>
<version>1.23</version>
</dependency>
<dependency>
<groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId>
</dependency>
<dependency>
<groupId>com.datastax.oss</groupId>
<artifactId>java-driver-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
public class DataApiRequestInfo {
private final Optional<String> tenantId;
private final Optional<String> cassandraToken;
private final Optional<String> embeddingApiKey;
private final EmbeddingCredentials embeddingCredentials;

/**
* Constructor that will be useful in the offline library mode, where only the tenant will be set
Expand All @@ -28,7 +28,7 @@ public class DataApiRequestInfo {
public DataApiRequestInfo(Optional<String> tenantId) {
this.tenantId = tenantId;
this.cassandraToken = Optional.empty();
this.embeddingApiKey = Optional.empty();
this.embeddingCredentials = null;
}

@Inject
Expand All @@ -37,8 +37,8 @@ public DataApiRequestInfo(
SecurityContext securityContext,
Instance<DataApiTenantResolver> tenantResolver,
Instance<DataApiTokenResolver> tokenResolver,
Instance<EmbeddingApiKeyResolver> apiKeyResolver) {
this.embeddingApiKey = apiKeyResolver.get().resolveApiKey(routingContext);
Instance<EmbeddingCredentialsResolver> apiKeysResolver) {
this.embeddingCredentials = apiKeysResolver.get().resolveEmbeddingCredentials(routingContext);
this.tenantId = (tenantResolver.get()).resolve(routingContext, securityContext);
this.cassandraToken = (tokenResolver.get()).resolve(routingContext, securityContext);
}
Expand All @@ -51,7 +51,7 @@ public Optional<String> getCassandraToken() {
return this.cassandraToken;
}

public Optional<String> getEmbeddingApiKey() {
return this.embeddingApiKey;
public EmbeddingCredentials getEmbeddingCredentials() {
return this.embeddingCredentials;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package io.stargate.sgv2.jsonapi.api.request;

import java.util.Optional;

/**
 * EmbeddingCredentials is a record that holds the embedding provider credentials for the embedding
 * service passed as header.
 *
 * <p>A {@code null} component is normalized to {@link Optional#empty()} on construction, so the
 * accessors never return {@code null}.
 *
 * @param apiKey - API token for the embedding service
 * @param accessId - Access Id used for AWS Bedrock embedding service
 * @param secretId - Secret Id used for AWS Bedrock embedding service
 */
public record EmbeddingCredentials(
    Optional<String> apiKey, Optional<String> accessId, Optional<String> secretId) {

  /** Compact constructor: replaces any {@code null} component with an empty Optional. */
  public EmbeddingCredentials {
    // An Optional reference that is itself null defeats the purpose of Optional and would
    // surface as a NullPointerException at the first caller that chains on the accessor.
    apiKey = apiKey == null ? Optional.empty() : apiKey;
    accessId = accessId == null ? Optional.empty() : accessId;
    secretId = secretId == null ? Optional.empty() : secretId;
  }
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package io.stargate.sgv2.jsonapi.api.request;

import io.vertx.ext.web.RoutingContext;
import java.util.Optional;

/** Functional interface to resolve the embedding credentials from the request context. */
@FunctionalInterface
public interface EmbeddingApiKeyResolver {
Optional<String> resolveApiKey(RoutingContext context);
public interface EmbeddingCredentialsResolver {
EmbeddingCredentials resolveEmbeddingCredentials(RoutingContext context);
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.stargate.sgv2.jsonapi.api.request;

import io.vertx.core.http.HttpServerRequest;
import io.vertx.ext.web.RoutingContext;
import java.util.Objects;
import java.util.Optional;

/**
* Implementation to resolve the embedding api key, access id and secret id from the request header.
*/
public class HeaderBasedEmbeddingCredentialsResolver implements EmbeddingCredentialsResolver {
private final String tokenHeaderName;
private final String accessIdHeaderName;
private final String secretIdHeaderName;

public HeaderBasedEmbeddingCredentialsResolver(
String tokenHeaderName, String accessIdHeaderName, String secretIdHeaderName) {
this.tokenHeaderName =
Objects.requireNonNull(tokenHeaderName, "Token header name cannot be null");
this.accessIdHeaderName =
Objects.requireNonNull(accessIdHeaderName, "Access Id header name cannot be null");
this.secretIdHeaderName =
Objects.requireNonNull(secretIdHeaderName, "Secret Id header name cannot be null");
}

public EmbeddingCredentials resolveEmbeddingCredentials(RoutingContext context) {
HttpServerRequest request = context.request();
String headerValue = request.getHeader(this.tokenHeaderName);
String accessId = request.getHeader(this.accessIdHeaderName);
String secretId = request.getHeader(this.secretIdHeaderName);
return new EmbeddingCredentials(
Optional.ofNullable(headerValue),
Optional.ofNullable(accessId),
Optional.ofNullable(secretId));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,27 @@ public interface HttpConstants {
/** JSON API Embedding service Authentication token header name. */
String EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME = "x-embedding-api-key";

/** JSON API Embedding service access id header name. */
String EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME = "x-embedding-access-id";

/** JSON API Embedding service secret id header name. */
String EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME = "x-embedding-secret-id";

/**
* @return Embedding service header name <code>20</code>.
* @return Embedding service header name for token.
*/
@WithDefault(EMBEDDING_AUTHENTICATION_TOKEN_HEADER_NAME)
String embeddingApiKey();

/**
* @return Embedding service header name for access id.
*/
@WithDefault(EMBEDDING_AUTHENTICATION_ACCESS_ID_HEADER_NAME)
String embeddingAccessId();

/**
* @return Embedding service header name for secret id.
*/
@WithDefault(EMBEDDING_AUTHENTICATION_SECRET_ID_HEADER_NAME)
String embeddingSecretId();
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ public enum ErrorCode {
"The replace document and document resolved using filter have different _id"),

/** Embedding provider service error codes. */
EMBEDDING_REQUEST_ENCODING_ERROR("Unable to create embedding provider request message"),
EMBEDDING_RESPONSE_DECODING_ERROR("Unable to parse embedding provider response message"),
EMBEDDING_PROVIDER_AUTHENTICATION_KEYS_NOT_PROVIDED(
"The Embedding Provider authentication keys not provided"),
EMBEDDING_PROVIDER_CLIENT_ERROR("The Embedding Provider returned a HTTP client error"),
EMBEDDING_PROVIDER_SERVER_ERROR("The Embedding Provider returned a HTTP server error"),
EMBEDDING_PROVIDER_RATE_LIMITED("The Embedding Provider rate limited the request"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression;
import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateClause;
import io.stargate.sgv2.jsonapi.api.model.command.clause.update.UpdateOperator;
import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentials;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.exception.ErrorCode;
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
Expand All @@ -20,7 +21,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
* Utility class to execute embedding service to get vector embeddings for the text fields in the
Expand All @@ -30,7 +30,7 @@
public class DataVectorizer {
private final EmbeddingProvider embeddingProvider;
private final JsonNodeFactory nodeFactory;
private final Optional<String> embeddingApiKey;
private final EmbeddingCredentials embeddingCredentials;
private final CollectionSettings collectionSettings;

/**
Expand All @@ -39,17 +39,17 @@ public class DataVectorizer {
* @param embeddingProvider - Service client based on embedding service configuration set for the
* table
* @param nodeFactory - Jackson node factory to create json nodes added to the document
* @param embeddingApiKey - Optional override embedding api key came in request header
* @param embeddingCredentials - Credentials for the embedding service
* @param collectionSettings - The collection setting for vectorize call
*/
public DataVectorizer(
EmbeddingProvider embeddingProvider,
JsonNodeFactory nodeFactory,
Optional<String> embeddingApiKey,
EmbeddingCredentials embeddingCredentials,
CollectionSettings collectionSettings) {
this.embeddingProvider = embeddingProvider;
this.nodeFactory = nodeFactory;
this.embeddingApiKey = embeddingApiKey;
this.embeddingCredentials = embeddingCredentials;
this.collectionSettings = collectionSettings;
}

Expand Down Expand Up @@ -111,7 +111,7 @@ public Uni<Boolean> vectorize(List<JsonNode> documents) {
.vectorize(
1,
vectorizeTexts,
embeddingApiKey,
embeddingCredentials,
EmbeddingProvider.EmbeddingRequestType.INDEX)
.map(res -> res.embeddings());
return vectors
Expand Down Expand Up @@ -178,7 +178,7 @@ public Uni<Boolean> vectorize(SortClause sortClause) {
.vectorize(
1,
List.of(text),
embeddingApiKey,
embeddingCredentials,
EmbeddingProvider.EmbeddingRequestType.SEARCH)
.map(res -> res.embeddings());
return vectors
Expand Down Expand Up @@ -264,7 +264,7 @@ private Uni<Boolean> updateVectorize(ObjectNode node) {
.vectorize(
1,
List.of(text),
embeddingApiKey,
embeddingCredentials,
EmbeddingProvider.EmbeddingRequestType.INDEX)
.map(res -> res.embeddings());
return vectors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public Uni<Command> vectorize(
new DataVectorizer(
embeddingProvider,
objectMapper.getNodeFactory(),
dataApiRequestInfo.getEmbeddingApiKey(),
dataApiRequestInfo.getEmbeddingCredentials(),
commandContext.collectionSettings());
return vectorizeSortClause(dataVectorizer, commandContext, command)
.onItem()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
package io.stargate.sgv2.jsonapi.service.embedding;

import io.stargate.sgv2.jsonapi.api.request.EmbeddingApiKeyResolver;
import io.stargate.sgv2.jsonapi.api.request.HeaderBasedEmbeddingApiKeyResolver;
import io.stargate.sgv2.jsonapi.api.request.EmbeddingCredentialsResolver;
import io.stargate.sgv2.jsonapi.api.request.HeaderBasedEmbeddingCredentialsResolver;
import io.stargate.sgv2.jsonapi.config.constants.HttpConstants;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.ws.rs.Produces;

/**
* Simple CDI producer for the {@link EmbeddingApiKeyResolver} to be used in the embedding service
* Simple CDI producer for the {@link EmbeddingCredentialsResolver} to be used in the embedding
* service
*/
@Singleton
public class EmbeddingApiKeyResolverProvider {
@Inject HttpConstants httpConstants;

@Produces
@ApplicationScoped
EmbeddingApiKeyResolver headerTokenResolver() {
return new HeaderBasedEmbeddingApiKeyResolver(httpConstants.embeddingApiKey());
EmbeddingCredentialsResolver headerTokenResolver() {
return new HeaderBasedEmbeddingCredentialsResolver(
httpConstants.embeddingApiKey(),
httpConstants.embeddingAccessId(),
httpConstants.embeddingSecretId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ interface EmbeddingProviderConfig {
@JsonProperty
boolean enabled();

@Nullable
@JsonProperty
String url();
Optional<String> url();

/**
* A map of supported authentications. HEADER, SHARED_SECRET and NONE are the only techniques
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public record EmbeddingProvidersConfigImpl(
public record EmbeddingProviderConfigImpl(
String displayName,
boolean enabled,
String url,
Optional<String> url,
Map<AuthenticationType, AuthenticationConfig> supportedAuthentications,
List<ParameterConfig> parameters,
RequestProperties properties,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Produces;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;
import org.eclipse.microprofile.faulttolerance.Retry;
import org.slf4j.Logger;
Expand Down Expand Up @@ -144,7 +141,9 @@ private EmbeddingProvidersConfig grpcResponseToConfig(
new EmbeddingProvidersConfigImpl.EmbeddingProviderConfigImpl(
grpcProviderConfig.getDisplayName(),
grpcProviderConfig.getEnabled(),
grpcProviderConfig.getUrl(),
grpcProviderConfig.hasUrl()
? Optional.of(grpcProviderConfig.getUrl())
: Optional.empty(),
supportedAuthenticationsMap,
providerParameterList,
providerRequestProperties,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public EmbeddingProviderConfigStore.ServiceConfig getConfiguration(
return ServiceConfig.provider(
serviceName,
serviceName,
config.providers().get(serviceName).url().toString(),
config.providers().get(serviceName).url().orElse(null),
RequestProperties.of(
properties.atMostRetries(),
properties.initialBackOffMillis(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public final class ProviderConstants {
public static final String JINA_AI = "jinaAI";
public static final String CUSTOM = "custom";
public static final String MISTRAL = "mistral";
public static final String BEDROCK = "bedrock";

// Private constructor to prevent instantiation
private ProviderConstants() {}
Expand Down
Loading