Skip to content
This repository has been archived by the owner on Feb 12, 2024. It is now read-only.

Commit

Permalink
Feature/fix embedding engine (#22)
Browse files Browse the repository at this point in the history
* fix the embedding engine

* fix db name conv
  • Loading branch information
emrgnt-cmplxty authored Sep 21, 2023
1 parent 98fcf94 commit 8033f52
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ poetry install -E <extra_name>
- `openai_support`: For running with OpenAI models.
- `vllm_support`: For VLLM support, useful for fast inference.
- `llama_index_support`: For LlamaIndex, useful for grounded synthesis.
- `chroma_support`: For Chroma support, used for large vector databases.
- `all`: For all dependencies (ex-vllm, which requires a separate install).

## Usage

Expand Down
21 changes: 18 additions & 3 deletions sciphi/examples/populate_chroma/runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""A module for populating ChromaDB with a dataset."""
import os

import chromadb
Expand All @@ -17,18 +18,30 @@ def chunk_text(text: str, chunk_size: int) -> list[str]:


if __name__ == "__main__":
# TODO - Move to proper CLI based approach, like Fire.
# For now, we are getting it running quick and dirty.

# Chroma environment variables
chroma_addr = os.environ["CHROMA_REMOTE_ADDR"]
chroma_port = os.environ["CHROMA_REMOTE_PORT"]
chroma_token = os.environ["CHROMA_TOKEN"]
chroma_auth_provider = os.environ["CHROMA_AUTH_PROVIDER"]
openai_api_key = os.environ["OPENAI_API_KEY"]

# OpenAI environment variables
openai_api_key = os.environ["OPENAI_API_KEY"]
openai.api_key = openai_api_key
embedding_engine = "text-embedding-ada-002"

# HF dataset
dataset_name = "vikp/pypi_clean"

# Script variables
chunk_size = 2048
batch_size = 64
sample_log_interval = 10
collection_name = f"{dataset_name.replace('/', '_')}_chunk_size_eq_2048"
collection_name = (
f"{dataset_name.replace('/', '_')}_chunk_size_eq__{chunk_size}"
)
log_level = "INFO"

logger = get_configured_logger("populate_chroma_db", log_level)
Expand Down Expand Up @@ -104,7 +117,9 @@ def chunk_text(text: str, chunk_size: int) -> list[str]:
continue

buffer["documents"].extend(chunks)
buffer["embeddings"].extend(get_embeddings(chunks))
buffer["embeddings"].extend(
get_embeddings(chunks, engine=embedding_engine)
)
buffer["metadatas"].extend(
[
{
Expand Down

0 comments on commit 8033f52

Please sign in to comment.