Skip to content

Commit

Permalink
feat: add huggingface embedding options
Browse files Browse the repository at this point in the history
  • Loading branch information
danny-avila committed Mar 19, 2024
1 parent f1ab1ab commit b79f768
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
Binary file modified README.md
Binary file not shown.
34 changes: 32 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import logging
from dotenv import find_dotenv, load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import OpenAIEmbeddings
from store_factory import get_vector_store

Expand Down Expand Up @@ -36,6 +37,8 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:
CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}"
DSN = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}"

## Logging

logger = logging.getLogger()

debug_mode = get_env_variable("DEBUG_RAG_API", "False").lower() == "true"
Expand All @@ -49,8 +52,35 @@ def get_env_variable(var_name: str, default_value: str = None) -> str:
handler.setFormatter(formatter)
logger.addHandler(handler)

OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY")
embeddings = OpenAIEmbeddings()
## Credentials

OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY", "")
HF_TOKEN = get_env_variable("HF_TOKEN", "")

## Embeddings

def init_embeddings(provider, model):
if provider == "openai":
return OpenAIEmbeddings(model=model)
elif provider == "huggingface":
return HuggingFaceEmbeddings(model_name=model, encode_kwargs={'normalize_embeddings': True})
else:
raise ValueError(f"Unsupported embeddings provider: {provider}")

EMBEDDINGS_PROVIDER = get_env_variable("EMBEDDINGS_PROVIDER", "openai").lower()

if EMBEDDINGS_PROVIDER == "openai":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small")
elif EMBEDDINGS_PROVIDER == "huggingface":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "sangmini/msmarco-cotmae-MiniLM-L12_en-ko-ja")
else:
raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}")

embeddings = init_embeddings(EMBEDDINGS_PROVIDER, EMBEDDINGS_MODEL)

logger.info(f"Initialized embeddings of type: {type(embeddings)}")

## Vector store

vector_store = get_vector_store(
connection_string=CONNECTION_STRING,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ pypandoc==1.13
python-jose==3.3.0
asyncpg==0.29.0
python-multipart==0.0.9
sentence_transformers==2.5.1

0 comments on commit b79f768

Please sign in to comment.