Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add embedding with ollama #70

Merged
merged 3 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ AZURE_OPENAI_CHAT_MODEL=gpt-35-turbo
AZURE_OPENAI_EMBED_DEPLOYMENT=text-embedding-ada-002
AZURE_OPENAI_EMBED_MODEL=text-embedding-ada-002
AZURE_OPENAI_EMBED_MODEL_DIMENSIONS=1536
AZURE_OPENAI_EMBEDDING_COLUMN=embedding_ada002
# Only needed when using key-based Azure authentication:
AZURE_OPENAI_KEY=
# Needed for OpenAI.com:
OPENAICOM_KEY=YOUR-OPENAI-API-KEY
OPENAICOM_CHAT_MODEL=gpt-3.5-turbo
OPENAICOM_EMBED_MODEL=text-embedding-ada-002
OPENAICOM_EMBED_MODEL_DIMENSIONS=1536
OPENAICOM_EMBEDDING_COLUMN=embedding_ada002
# Needed for Ollama:
OLLAMA_ENDPOINT=http://host.docker.internal:11434/v1
OLLAMA_CHAT_MODEL=phi3:3.8b
OLLAMA_CHAT_MODEL=llama3.1
OLLAMA_EMBED_MODEL=nomic-embed-text
OLLAMA_EMBEDDING_COLUMN=embedding_nomic
7 changes: 6 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,10 @@
"ssl": true
}
}
]
],
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
14 changes: 12 additions & 2 deletions src/backend/fastapi_app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ class FastAPIAppContext(BaseModel):

openai_chat_model: str
openai_embed_model: str
openai_embed_dimensions: int
openai_embed_dimensions: int | None
openai_chat_deployment: str | None
openai_embed_deployment: str | None
embedding_column: str


async def common_parameters():
Expand All @@ -43,16 +44,24 @@ async def common_parameters():
openai_embed_deployment = os.getenv("AZURE_OPENAI_EMBED_DEPLOYMENT", "text-embedding-ada-002")
openai_embed_model = os.getenv("AZURE_OPENAI_EMBED_MODEL", "text-embedding-ada-002")
openai_embed_dimensions = int(os.getenv("AZURE_OPENAI_EMBED_DIMENSIONS", 1536))
embedding_column = os.getenv("AZURE_OPENAI_EMBEDDING_COLUMN", "embedding_ada002")
elif OPENAI_EMBED_HOST == "ollama":
openai_embed_deployment = None
openai_embed_model = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")
openai_embed_dimensions = None
embedding_column = os.getenv("OLLAMA_EMBEDDING_COLUMN", "embedding_nomic")
else:
openai_embed_deployment = "text-embedding-ada-002"
openai_embed_deployment = None
openai_embed_model = os.getenv("OPENAICOM_EMBED_MODEL", "text-embedding-ada-002")
openai_embed_dimensions = int(os.getenv("OPENAICOM_EMBED_DIMENSIONS", 1536))
embedding_column = os.getenv("OPENAICOM_EMBEDDING_COLUMN", "embedding_ada002")
if OPENAI_CHAT_HOST == "azure":
openai_chat_deployment = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT", "gpt-35-turbo")
openai_chat_model = os.getenv("AZURE_OPENAI_CHAT_MODEL", "gpt-35-turbo")
elif OPENAI_CHAT_HOST == "ollama":
openai_chat_deployment = None
openai_chat_model = os.getenv("OLLAMA_CHAT_MODEL", "phi3:3.8b")
openai_embed_model = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")
else:
openai_chat_deployment = None
openai_chat_model = os.getenv("OPENAICOM_CHAT_MODEL", "gpt-3.5-turbo")
Expand All @@ -62,6 +71,7 @@ async def common_parameters():
openai_embed_dimensions=openai_embed_dimensions,
openai_chat_deployment=openai_chat_deployment,
openai_embed_deployment=openai_embed_deployment,
embedding_column=embedding_column,
)


Expand Down
9 changes: 7 additions & 2 deletions src/backend/fastapi_app/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ async def compute_text_embedding(
openai_client: AsyncOpenAI | AsyncAzureOpenAI,
embed_model: str,
embed_deployment: str | None = None,
embedding_dimensions: int = 1536,
embedding_dimensions: int | None = None,
) -> list[float]:
SUPPORTED_DIMENSIONS_MODEL = {
"text-embedding-ada-002": False,
Expand All @@ -21,7 +21,12 @@ async def compute_text_embedding(
class ExtraArgs(TypedDict, total=False):
dimensions: int

dimensions_args: ExtraArgs = {"dimensions": embedding_dimensions} if SUPPORTED_DIMENSIONS_MODEL[embed_model] else {}
dimensions_args: ExtraArgs = {}
if SUPPORTED_DIMENSIONS_MODEL.get(embed_model):
if embedding_dimensions is None:
raise ValueError(f"Model {embed_model} requires embedding dimensions")
else:
dimensions_args = {"dimensions": embedding_dimensions}

embedding = await openai_client.embeddings.create(
# Azure OpenAI takes the deployment name as the model name
Expand Down
8 changes: 7 additions & 1 deletion src/backend/fastapi_app/openai_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,13 @@ async def create_openai_embed_client(
azure_deployment=azure_deployment,
azure_ad_token_provider=token_provider,
)

elif OPENAI_EMBED_HOST == "ollama":
logger.info("Authenticating to OpenAI using Ollama...")
openai_embed_client = openai.AsyncOpenAI(
base_url=os.getenv("OLLAMA_ENDPOINT"),
api_key="nokeyneeded",
)
else:
logger.info("Authenticating to OpenAI using OpenAI.com API key...")
openai_embed_client = openai.AsyncOpenAI(api_key=os.getenv("OPENAICOM_KEY"))
return openai_embed_client
2 changes: 1 addition & 1 deletion src/backend/fastapi_app/postgres_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_password_from_azure_credential():

engine = create_async_engine(
DATABASE_URI,
echo=False,
echo=True,
)

@event.listens_for(engine.sync_engine, "do_connect")
Expand Down
25 changes: 18 additions & 7 deletions src/backend/fastapi_app/postgres_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@ class Item(Base):
name: Mapped[str] = mapped_column()
description: Mapped[str] = mapped_column()
price: Mapped[float] = mapped_column()
embedding: Mapped[Vector] = mapped_column(Vector(1536)) # ada-002
embedding_ada002: Mapped[Vector] = mapped_column(Vector(1536)) # ada-002
embedding_nomic: Mapped[Vector] = mapped_column(Vector(768)) # nomic-embed-text

def to_dict(self, include_embedding: bool = False):
model_dict = asdict(self)
if include_embedding:
model_dict["embedding"] = model_dict["embedding"].tolist()
model_dict["embedding_ada002"] = model_dict.get("embedding_ada002", [])
model_dict["embedding_nomic"] = model_dict.get("embedding_nomic", [])
else:
del model_dict["embedding"]
del model_dict["embedding_ada002"]
del model_dict["embedding_nomic"]
return model_dict

def to_str_for_rag(self):
Expand All @@ -38,10 +41,18 @@ def to_str_for_embedding(self):


# Define HNSW index to support vector similarity search through the vector_cosine_ops access method (cosine distance).
index = Index(
"hnsw_index_for_innerproduct_item_embedding",
Item.embedding,
index_ada002 = Index(
"hnsw_index_for_innerproduct_item_embedding_ada002",
Item.embedding_ada002,
postgresql_using="hnsw",
postgresql_with={"m": 16, "ef_construction": 64},
postgresql_ops={"embedding": "vector_ip_ops"},
postgresql_ops={"embedding_ada002": "vector_ip_ops"},
)

index_nomic = Index(
"hnsw_index_for_innerproduct_item_embedding_nomic",
Item.embedding_nomic,
postgresql_using="hnsw",
postgresql_with={"m": 16, "ef_construction": 64},
postgresql_ops={"embedding_nomic": "vector_ip_ops"},
)
14 changes: 6 additions & 8 deletions src/backend/fastapi_app/postgres_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ def __init__(
openai_embed_client: AsyncOpenAI | AsyncAzureOpenAI,
embed_deployment: str | None, # Not needed for non-Azure OpenAI or for retrieval_mode="text"
embed_model: str,
embed_dimensions: int,
embed_dimensions: int | None,
embedding_column: str,
):
self.db_session = db_session
self.openai_embed_client = openai_embed_client
self.embed_model = embed_model
self.embed_deployment = embed_deployment
self.embed_dimensions = embed_dimensions
self.embedding_column = embedding_column

def build_filter_clause(self, filters) -> tuple[str, str]:
if filters is None:
Expand All @@ -36,19 +38,15 @@ def build_filter_clause(self, filters) -> tuple[str, str]:
return "", ""

async def search(
self,
query_text: str | None,
query_vector: list[float] | list,
top: int = 5,
filters: list[dict] | None = None,
self, query_text: str | None, query_vector: list[float] | list, top: int = 5, filters: list[dict] | None = None
):
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)

vector_query = f"""
SELECT id, RANK () OVER (ORDER BY embedding <=> :embedding) AS rank
SELECT id, RANK () OVER (ORDER BY {self.embedding_column} <=> :embedding) AS rank
FROM items
{filter_clause_where}
ORDER BY embedding <=> :embedding
ORDER BY {self.embedding_column} <=> :embedding
LIMIT 20
"""

Expand Down
12 changes: 9 additions & 3 deletions src/backend/fastapi_app/routes/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,18 @@ async def item_handler(database_session: DBSession, id: int) -> ItemPublic:


@router.get("/similar", response_model=list[ItemWithDistance])
async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> list[ItemWithDistance]:
async def similar_handler(
context: CommonDeps, database_session: DBSession, id: int, n: int = 5
) -> list[ItemWithDistance]:
"""A similarity API to find items similar to items with given ID."""
item = (await database_session.scalars(select(Item).where(Item.id == id))).first()
if not item:
raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404)

closest = await database_session.execute(
select(Item, Item.embedding.l2_distance(item.embedding))
select(Item, Item.embedding_ada002.l2_distance(item.embedding_ada002))
.filter(Item.id != id)
.order_by(Item.embedding.l2_distance(item.embedding))
.order_by(Item.embedding_ada002.l2_distance(item.embedding_ada002))
.limit(n)
)
return [
Expand All @@ -78,6 +81,7 @@ async def search_handler(
embed_deployment=context.openai_embed_deployment,
embed_model=context.openai_embed_model,
embed_dimensions=context.openai_embed_dimensions,
embedding_column=context.embedding_column,
)
results = await searcher.search_and_embed(
query, top=top, enable_vector_search=enable_vector_search, enable_text_search=enable_text_search
Expand All @@ -99,6 +103,7 @@ async def chat_handler(
embed_deployment=context.openai_embed_deployment,
embed_model=context.openai_embed_model,
embed_dimensions=context.openai_embed_dimensions,
embedding_column=context.embedding_column,
)
rag_flow: SimpleRAGChat | AdvancedRAGChat
if chat_request.context.overrides.use_advanced_flow:
Expand Down Expand Up @@ -139,6 +144,7 @@ async def chat_stream_handler(
embed_deployment=context.openai_embed_deployment,
embed_model=context.openai_embed_model,
embed_dimensions=context.openai_embed_dimensions,
embedding_column=context.embedding_column,
)

rag_flow: SimpleRAGChat | AdvancedRAGChat
Expand Down
234,823 changes: 233,915 additions & 908 deletions src/backend/fastapi_app/seed_data.json

Large diffs are not rendered by default.

17 changes: 9 additions & 8 deletions src/backend/fastapi_app/setup_postgres_seeddata.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,18 @@ async def seed_data(engine):
with open(os.path.join(current_dir, "seed_data.json")) as f:
catalog_items = json.load(f)
for catalog_item in catalog_items:
db_item = await session.execute(select(Item).filter(Item.id == catalog_item["Id"]))
db_item = await session.execute(select(Item).filter(Item.id == catalog_item["id"]))
if db_item.scalars().first():
continue
item = Item(
id=catalog_item["Id"],
type=catalog_item["Type"],
brand=catalog_item["Brand"],
name=catalog_item["Name"],
description=catalog_item["Description"],
price=catalog_item["Price"],
embedding=catalog_item["Embedding"],
id=catalog_item["id"],
type=catalog_item["type"],
brand=catalog_item["brand"],
name=catalog_item["name"],
description=catalog_item["description"],
price=catalog_item["price"],
embedding_ada002=catalog_item["embedding_ada002"],
embedding_nomic=catalog_item.get("embedding_nomic"),
)
session.add(item)
try:
Expand Down
64 changes: 57 additions & 7 deletions src/backend/fastapi_app/update_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import asyncio
import json
import logging
import os

from dotenv import load_dotenv
from sqlalchemy import select
Expand All @@ -10,28 +13,75 @@
from fastapi_app.postgres_engine import create_postgres_engine_from_env
from fastapi_app.postgres_models import Item

logger = logging.getLogger("ragapp")

async def update_embeddings():

async def update_embeddings(in_seed_data=False):
azure_credential = await get_azure_credentials()
engine = await create_postgres_engine_from_env(azure_credential)
openai_embed_client = await create_openai_embed_client(azure_credential)
common_params = await common_parameters()

async with async_sessionmaker(engine, expire_on_commit=False)() as session:
async with session.begin():
items = (await session.scalars(select(Item))).all()

for item in items:
item.embedding = await compute_text_embedding(
embedding_column = ""
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
if OPENAI_EMBED_HOST == "azure":
embedding_column = os.getenv("AZURE_OPENAI_EMBEDDING_COLUMN", "embedding_ada002")
elif OPENAI_EMBED_HOST == "ollama":
embedding_column = os.getenv("OLLAMA_EMBEDDING_COLUMN", "embedding_nomic")
else:
embedding_column = os.getenv("OPENAICOM_EMBEDDING_COLUMN", "embedding_ada002")
logger.info(f"Updating embeddings in column: {embedding_column}")
if in_seed_data:
current_dir = os.path.dirname(os.path.realpath(__file__))
items = []
with open(os.path.join(current_dir, "seed_data.json")) as f:
catalog_items = json.load(f)
for catalog_item in catalog_items:
item = Item(
id=catalog_item["id"],
type=catalog_item["type"],
brand=catalog_item["brand"],
name=catalog_item["name"],
description=catalog_item["description"],
price=catalog_item["price"],
embedding_ada002=catalog_item["embedding_ada002"],
embedding_nomic=catalog_item.get("embedding_nomic"),
)
embedding = await compute_text_embedding(
item.to_str_for_embedding(),
openai_client=openai_embed_client,
embed_model=common_params.openai_embed_model,
embed_deployment=common_params.openai_embed_deployment,
embedding_dimensions=common_params.openai_embed_dimensions,
)
setattr(item, embedding_column, embedding)
items.append(item)
# write to the file
with open(os.path.join(current_dir, "seed_data.json"), "w") as f:
json.dump([item.to_dict(include_embedding=True) for item in items], f, indent=4)
return

async with async_sessionmaker(engine, expire_on_commit=False)() as session:
async with session.begin():
items_to_update = (await session.scalars(select(Item))).all()

for item in items_to_update:
setattr(
item,
embedding_column,
await compute_text_embedding(
item.to_str_for_embedding(),
openai_client=openai_embed_client,
embed_model=common_params.openai_embed_model,
embed_deployment=common_params.openai_embed_deployment,
embedding_dimensions=common_params.openai_embed_dimensions,
),
)
await session.commit()


if __name__ == "__main__":
logging.basicConfig(level=logging.WARNING)
logger.setLevel(logging.INFO)
load_dotenv(override=True)
asyncio.run(update_embeddings())
Loading
Loading