From 8033f526fdb6c27f6c5bdd48d476d87dcb94bf3a Mon Sep 17 00:00:00 2001
From: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com>
Date: Thu, 21 Sep 2023 11:33:46 -0400
Subject: [PATCH] Feature/fix embedding engine (#22)

* fix the embedding engine

* fix db name conv
---
 README.md                                 |  2 ++
 sciphi/examples/populate_chroma/runner.py | 21 ++++++++++++++++++---
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 37bde6c..a95c8f2 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,8 @@ poetry install -E
 - `openai_support`: For running with OpenAI models.
 - `vllm_support`: For with VLLM, useful for fast inference.
 - `llama_index_support`: For LlamaIndex, useful for grounded synthesis.
+- `chroma_support`: For Chroma support, used for large vector databases.
+- `all`: For all dependencies (ex-vllm, which requires a separate install).
 
 ## Usage
 
diff --git a/sciphi/examples/populate_chroma/runner.py b/sciphi/examples/populate_chroma/runner.py
index 21d7bcb..259b5e2 100644
--- a/sciphi/examples/populate_chroma/runner.py
+++ b/sciphi/examples/populate_chroma/runner.py
@@ -1,3 +1,4 @@
+"""A module for populating ChromaDB with a dataset."""
 import os
 
 import chromadb
@@ -17,18 +18,30 @@ def chunk_text(text: str, chunk_size: int) -> list[str]:
 
 
 if __name__ == "__main__":
+    # TODO - Move to proper CLI based approach, like Fire.
+    # For now, we are getting it running quick and dirty.
+
+    # Chroma environment variables
     chroma_addr = os.environ["CHROMA_REMOTE_ADDR"]
     chroma_port = os.environ["CHROMA_REMOTE_PORT"]
     chroma_token = os.environ["CHROMA_TOKEN"]
     chroma_auth_provider = os.environ["CHROMA_AUTH_PROVIDER"]
-    openai_api_key = os.environ["OPENAI_API_KEY"]
 
+    # OpenAI environment variables
+    openai_api_key = os.environ["OPENAI_API_KEY"]
     openai.api_key = openai_api_key
+    embedding_engine = "text-embedding-ada-002"
+
+    # HF dataset
    dataset_name = "vikp/pypi_clean"
+
+    # Script variables
     chunk_size = 2048
     batch_size = 64
     sample_log_interval = 10
-    collection_name = f"{dataset_name.replace('/', '_')}_chunk_size_eq_2048"
+    collection_name = (
+        f"{dataset_name.replace('/', '_')}_chunk_size_eq__{chunk_size}"
+    )
     log_level = "INFO"
 
     logger = get_configured_logger("populate_chroma_db", log_level)
@@ -104,7 +117,9 @@ def chunk_text(text: str, chunk_size: int) -> list[str]:
                 continue
 
             buffer["documents"].extend(chunks)
-            buffer["embeddings"].extend(get_embeddings(chunks))
+            buffer["embeddings"].extend(
+                get_embeddings(chunks, engine=embedding_engine)
+            )
             buffer["metadatas"].extend(
                 [
                     {
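
Note on the core change: the patch threads an explicit embedding engine name (text-embedding-ada-002) through to the embedding call instead of relying on a default. The patch only shows the call site, so the following is a minimal sketch of what a `get_embeddings` helper that accepts an `engine` argument might look like, assuming the pre-1.0 `openai` Python client (consistent with the `openai.api_key = ...` usage above); the helper body is illustrative, not taken from the repository:

    import openai

    def get_embeddings(
        texts: list[str], engine: str = "text-embedding-ada-002"
    ) -> list[list[float]]:
        # Illustrative sketch, not the repo's actual helper.
        # One request per batch; the API returns one embedding per input, in order.
        response = openai.Embedding.create(input=texts, engine=engine)
        return [record["embedding"] for record in response["data"]]

Passing the engine explicitly at the call site keeps every batch in the collection embedded with the same model, rather than depending on whatever default the helper would otherwise fall back to.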