Skip to content

Commit

Permalink
🅰️ feat: Add support for Azure OpenAI Embeddings (#4)
Browse files Browse the repository at this point in the history
Co-authored-by: fgoiriz <[email protected]>
  • Loading branch information
Fakamoto and fgoiriz authored Mar 21, 2024
1 parent 03cd625 commit 7287d20
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
Binary file modified README.md
Binary file not shown.
12 changes: 10 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from dotenv import find_dotenv, load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from store_factory import get_vector_store

load_dotenv(find_dotenv())
Expand Down Expand Up @@ -55,13 +55,17 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:
## Credentials

OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY", "")
AZURE_OPENAI_API_KEY = get_env_variable("AZURE_OPENAI_API_KEY", "")
AZURE_OPENAI_ENDPOINT = get_env_variable("AZURE_OPENAI_ENDPOINT", "")
HF_TOKEN = get_env_variable("HF_TOKEN", "")

## Embeddings

def init_embeddings(provider, model):
if provider == "openai":
return OpenAIEmbeddings(model=model)
return OpenAIEmbeddings(model=model, api_key=OPENAI_API_KEY)
elif provider == "azure":
return AzureOpenAIEmbeddings(model=model, api_key=AZURE_OPENAI_API_KEY) # AZURE_OPENAI_ENDPOINT is being grabbed from the environment
elif provider == "huggingface":
return HuggingFaceEmbeddings(model_name=model, encode_kwargs={'normalize_embeddings': True})
else:
Expand All @@ -71,6 +75,10 @@ def init_embeddings(provider, model):

if EMBEDDINGS_PROVIDER == "openai":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small")

elif EMBEDDINGS_PROVIDER == "azure":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small")

elif EMBEDDINGS_PROVIDER == "huggingface":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "sangmini/msmarco-cotmae-MiniLM-L12_en-ko-ja")
else:
Expand Down

0 comments on commit 7287d20

Please sign in to comment.