Skip to content

Commit

Permalink
[NOID] Fixes #3634: Updated ML procs for Azure OpenAI services (#3850) (
Browse files Browse the repository at this point in the history
#3863) (#3885)

* Fixes #3634: Updated ML procs for Azure OpenAI services

* Code clean

* added endpoint env vars

* Code clean part 2

* removed unused imports

---------

Co-authored-by: Andrea Santurbano <[email protected]>
  • Loading branch information
vga91 and conker84 committed Nov 28, 2024
1 parent 83b0e50 commit 094b2bf
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 44 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ plugins {

ext {
// Move to gradle.properties (but versions.json generation build depends on this right now)
neo4jVersion = '4.4.40'
neo4jVersion = '4.4.39'
publicDir = "${project.rootDir}"
neo4jVersionEffective = project.hasProperty("neo4jVersionOverride") ? project.getProperty("neo4jVersionOverride") : neo4jVersion
neo4jDockerVersion = project.hasProperty("neo4jDockerVersionOverride") ? project.getProperty("neo4jDockerVersionOverride") : neo4jVersion
Expand Down
9 changes: 9 additions & 0 deletions full/src/main/java/apoc/ml/MLUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package apoc.ml;

/**
 * Names of the keys that the APOC ML procedures look up in their {@code configuration} map.
 *
 * <p>This class only holds constants: it is {@code final} and has a private constructor so it
 * can be neither instantiated nor subclassed.
 */
public final class MLUtil {
    /** Configuration key under which the caller supplies the service endpoint. */
    public static final String ENDPOINT_CONF_KEY = "endpoint";

    /** Configuration key for the API version (presumably the Azure OpenAI API version; verify against the request handler). */
    public static final String API_VERSION_CONF_KEY = "apiVersion";

    /** Configuration key under which the caller supplies the model identifier. */
    public static final String MODEL_CONF_KEY = "model";

    /** Configuration key naming the API flavour; its value is resolved to a request-handler type. */
    public static final String API_TYPE_CONF_KEY = "apiType";

    /** Configuration key for the API key (presumably an alternative to passing it as an argument; verify against callers). */
    public static final String APIKEY_CONF_KEY = "apiKey";

    /** Not instantiable: this class is a namespace for constants only. */
    private MLUtil() {
        throw new AssertionError("MLUtil is a constants holder and must not be instantiated");
    }
}
47 changes: 26 additions & 21 deletions full/src/main/java/apoc/ml/MixedbreadAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,48 @@

import apoc.ApocConfig;
import apoc.result.ObjectResult;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;

import static apoc.ApocConfig.APOC_ML_OPENAI_URL;
import static apoc.ml.MLUtil.API_TYPE_CONF_KEY;
import static apoc.ml.MLUtil.ENDPOINT_CONF_KEY;
import static apoc.ml.MLUtil.MODEL_CONF_KEY;
import static apoc.ml.OpenAI.API_TYPE_CONF_KEY;
import static apoc.ml.OpenAI.APOC_ML_OPENAI_URL;

public class MixedbreadAI {

public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String DEFAULT_MODEL_ID = "mxbai-embed-large-v1";
public static final String MIXEDBREAD_BASE_URL = "https://api.mixedbread.ai/v1";
public static final String ERROR_MSG_MISSING_ENDPOINT = "The endpoint must be defined via config `%s` or via apoc.conf `%s`"
.formatted(ENDPOINT_CONF_KEY, APOC_ML_OPENAI_URL);
public static final String ERROR_MSG_MISSING_ENDPOINT = String.format("The endpoint must be defined via config `%s` or via apoc.conf `%s`",
ENDPOINT_CONF_KEY, APOC_ML_OPENAI_URL);

public static final String ERROR_MSG_MISSING_MODELID = "The model must be defined via config `%s`"
.formatted(MODEL_CONF_KEY);
public static final String ERROR_MSG_MISSING_MODELID = String.format("The model must be defined via config `%s`",
MODEL_CONF_KEY);


/**
* embedding is an Object instead of List<Double>, as with a Mixedbread request having `"encoding_format": [<multipleFormat>]`,
* the result can be e.g. {... "embedding": { "float": [<floatEmbedding>], "base": <base64Embedding>, } ...}
* instead of e.g. {... "embedding": [<floatEmbedding>] ...}
*/
public record EmbeddingResult(long index, String text, Object embedding) { }

@Context
public URLAccessChecker urlAccessChecker;
* embedding is an Object instead of List<Double>, as with a Mixedbread request having `"encoding_format": [<multipleFormat>]`,
* the result can be e.g. {... "embedding": { "float": [<floatEmbedding>], "base": <base64Embedding>, } ...}
* instead of e.g. {... "embedding": [<floatEmbedding>] ...}
*/
public static final class EmbeddingResult {
// Position of this embedding; the visible caller reads it from the "index" entry of each response item.
public final long index;
// Input text this embedding belongs to; the visible caller looks it up in the submitted texts by index.
public final String text;
// Raw "embedding" value of the response item. Kept as Object because, depending on the requested
// encoding format, it may be a plain list of floats or a map keyed by format name.
public final Object embedding;

// Plain value constructor: stores the three arguments as-is, with no validation or copying.
public EmbeddingResult(long index, String text, Object embedding) {
this.index = index;
this.text = text;
this.embedding = embedding;
}
}

@Context
public ApocConfig apocConfig;
Expand All @@ -53,8 +60,7 @@ public Stream<ObjectResult> custom(@Name("api_key") String apiKey, @Name(value =

return OpenAI.executeRequest(apiKey, configuration,
null, null, null, null, null,
apocConfig,
urlAccessChecker)
apocConfig)
.map(ObjectResult::new);
}

Expand All @@ -67,12 +73,11 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts,
configuration.putIfAbsent(MODEL_CONF_KEY, DEFAULT_MODEL_ID);

configuration.put(API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.MIXEDBREAD_EMBEDDING.name());
return OpenAI.getEmbeddingResult(texts, apiKey, configuration, apocConfig, urlAccessChecker,
return OpenAI.getEmbeddingResult(texts, apiKey, configuration, apocConfig,
(map, text) -> {
Long index = (Long) map.get("index");
return new EmbeddingResult(index, text, map.get("embedding"));
},
m -> new EmbeddingResult(-1, m, List.of())
}
);

}
Expand Down
54 changes: 38 additions & 16 deletions full/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

import static apoc.ApocConfig.APOC_ML_OPENAI_TYPE;
import static apoc.ApocConfig.APOC_OPENAI_KEY;
import static apoc.ml.MLUtil.APIKEY_CONF_KEY;
import static apoc.ml.MLUtil.API_TYPE_CONF_KEY;
import static apoc.ml.MLUtil.API_VERSION_CONF_KEY;
import static apoc.ml.MLUtil.ENDPOINT_CONF_KEY;
import static apoc.ml.MLUtil.MODEL_CONF_KEY;

import apoc.ApocConfig;
import apoc.Extended;
Expand All @@ -13,7 +18,10 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Stream;

import org.jetbrains.annotations.NotNull;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
Expand All @@ -24,12 +32,6 @@ public class OpenAI {
@Context
public ApocConfig apocConfig;

public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url";
public static final String API_TYPE_CONF_KEY = "apiType";
public static final String APIKEY_CONF_KEY = "apiKey";
public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String API_VERSION_CONF_KEY = "apiVersion";

public static class EmbeddingResult {
public final long index;
public final String text;
Expand All @@ -56,20 +58,27 @@ static Stream<Object> executeRequest(
if (apiKey == null || apiKey.isBlank()) throw new IllegalArgumentException("API Key must not be empty");
String apiTypeString = (String) configuration.getOrDefault(
API_TYPE_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_TYPE, OpenAIRequestHandler.Type.OPENAI.name()));
OpenAIRequestHandler apiType = OpenAIRequestHandler.Type.valueOf(apiTypeString.toUpperCase(Locale.ENGLISH))
.get();

final Map<String, Object> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
apiType.addApiKey(headers, apiKey);
OpenAIRequestHandler.Type type = OpenAIRequestHandler.Type.valueOf(apiTypeString.toUpperCase(Locale.ENGLISH));

var config = new HashMap<>(configuration);
// we remove these keys from config, since the json payload is calculated starting from the config map
Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY, APIKEY_CONF_KEY)
.forEach(config::remove);
config.putIfAbsent("model", model);
config.put(key, inputs);

switch (type) {
case MIXEDBREAD_CUSTOM:
// no payload manipulation, taken from the configuration as-is
break;
default:
config.putIfAbsent(MODEL_CONF_KEY, model);
config.put(key, inputs);
}
OpenAIRequestHandler apiType = type.get();

final Map<String, Object> headers = new HashMap<>();
headers.put("Content-Type", "application/json");

apiType.addApiKey(headers, apiKey);

String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(config);

// new URL(endpoint), path) can produce a wrong path, since endpoint can have for example embedding,
Expand Down Expand Up @@ -99,13 +108,26 @@ public Stream<EmbeddingResult> getEmbedding(
"model": "text-embedding-ada-002",
"usage": { "prompt_tokens": 8, "total_tokens": 8 } }
*/

return getEmbeddingResult(texts, apiKey, configuration, apocConfig,
(map, text) -> {
Long index = (Long) map.get("index");
return new EmbeddingResult(index, text, (List<Double>) map.get("embedding"));
}
);
}

public static <T> Stream<T> getEmbeddingResult(List<String> texts, String apiKey, Map<String, Object> configuration, ApocConfig apocConfig, BiFunction<Map, String, T> embeddingMapping)
throws JsonProcessingException, MalformedURLException {
Stream<Object> resultStream = executeRequest(
apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", texts, "$.data", apocConfig);

return resultStream
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
.map(m -> {
Long index = (Long) m.get("index");
return new EmbeddingResult(index, texts.get(index.intValue()), (List<Double>) m.get("embedding"));
String text = texts.get(index.intValue());
return embeddingMapping.apply(m, text);
});
}

Expand Down
19 changes: 17 additions & 2 deletions full/src/main/java/apoc/ml/OpenAIRequestHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import static apoc.ApocConfig.APOC_ML_OPENAI_AZURE_VERSION;
import static apoc.ApocConfig.APOC_ML_OPENAI_URL;
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY;
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;
import static apoc.ml.MLUtil.API_VERSION_CONF_KEY;
import static apoc.ml.MLUtil.ENDPOINT_CONF_KEY;
import static apoc.ml.MixedbreadAI.ERROR_MSG_MISSING_ENDPOINT;
import static apoc.ml.MixedbreadAI.MIXEDBREAD_BASE_URL;

import apoc.ApocConfig;
import java.util.Map;
Expand Down Expand Up @@ -41,6 +43,8 @@ public String getFullUrl(String method, Map<String, Object> procConfig, ApocConf

enum Type {
AZURE(new Azure(null)),
MIXEDBREAD_EMBEDDING(new OpenAi(MIXEDBREAD_BASE_URL)),
MIXEDBREAD_CUSTOM(new Custom()),
OPENAI(new OpenAi("https://api.openai.com/v1"));

private final OpenAIRequestHandler handler;
Expand Down Expand Up @@ -89,4 +93,15 @@ public void addApiKey(Map<String, Object> headers, String apiKey) {
headers.put("Authorization", "Bearer " + apiKey);
}
}

/**
 * Request handler registered for {@code Type.MIXEDBREAD_CUSTOM}.
 *
 * <p>Unlike the other {@code OpenAi}-based handlers, it has no built-in URL: asking it for a
 * default URL fails, so the endpoint has to be provided by the caller.
 */
static class Custom extends OpenAi {
// Passes null to the parent constructor, where other handlers pass a base URL
// (NOTE(review): parent constructor is not visible here; confirm the argument is the default URL).
public Custom() {
super(null);
}
/**
 * Always fails, because this handler has no default URL.
 *
 * @throws RuntimeException always, carrying {@code ERROR_MSG_MISSING_ENDPOINT}
 */
@Override
public String getDefaultUrl() {
throw new RuntimeException(ERROR_MSG_MISSING_ENDPOINT);

}
}
}
2 changes: 1 addition & 1 deletion full/src/test/java/apoc/ml/MixedbreadAIIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import java.util.Map;
import java.util.Set;

import static apoc.ml.MLUtil.MODEL_CONF_KEY;
import static apoc.ml.MLUtil.*;
import static apoc.ml.MixedbreadAI.*;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
Expand Down
3 changes: 0 additions & 3 deletions full/src/test/java/apoc/ml/OpenAIAzureIT.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package apoc.ml;

import static apoc.ml.OpenAI.API_TYPE_CONF_KEY;
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY;
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;
import static apoc.ml.OpenAITestResultUtils.assertChatCompletion;
import static apoc.ml.OpenAITestResultUtils.assertCompletion;
import static apoc.util.TestUtil.testCall;
Expand Down

0 comments on commit 094b2bf

Please sign in to comment.