Skip to content
This repository has been archived by the owner on Feb 12, 2024. It is now read-only.

Commit

Permalink
Add chroma
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Sep 21, 2023
1 parent b431117 commit a2ad18d
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 8 deletions.
6 changes: 5 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
OPENAI_API_KEY=your_openai_key
ANTHROPIC_API_KEY=your_anthropic_key
HF_TOKEN=your_huggingface_token
HF_TOKEN=your_huggingface_token
CHROMA_REMOTE_ADDR=your_chroma_db_addr
CHROMA_REMOTE_PORT="8000" # default
CHROMA_TOKEN=your_chroma_db_token
CHROMA_AUTH_PROVIDER="chromadb.auth.token.TokenAuthClientProvider"
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ git clone https://github.com/emrgnt-cmplxty/sciphi.git
cd sciphi
# Install dependencies
# pip3 install poetry (if you don't have it)
poetry install -E openai_support
# Add other optional dependencies
# poetry install -E openai_support -E anthropic_support -E hf_support ...
poetry install -E all
# Setup your environment
cp .env.example .env && vim .env
```
Expand Down
21 changes: 18 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,36 @@ python = ">=3.10,<3.12"
python-dotenv = "^1.0.0"
pandas = "^2.1.0"
# Begin optional dependencies
accelerate = { version = "^0.23.0", optional = true }
# anthropic
anthropic = { version = "^0.3.10", optional = true }
# hf
accelerate = { version = "^0.23.0", optional = true }
datasets = { version = "^2.14.5", optional = true }
llama-index = { version = "^0.8.29.post1", optional = true }
openai = { version = "0.27.8", optional = true }
torch = { version = "^2.0.1", optional = true }
transformers = { version = "^4.33.1", optional = true }
# openai
openai = { version = "0.27.8", optional = true }
matplotlib = {version = "^3.8.0", optional = true}
plotly = {version = "^5.17.0", optional = true}
scipy = {version = "^1.11.2", optional = true}
scikit-learn = {version = "^1.3.1", optional = true}
# vllm
vllm = { version = "0.1.7", optional = true }
# llama-index
llama-index = { version = "^0.8.29.post1", optional = true }
# chroma
chromadb = { version = "^0.4.12", optional = true }

[tool.poetry.extras]
anthropic_support = ["anthropic"]
hf_support = ["accelerate", "datasets", "torch", "transformers"]
openai_support = ["openai"]
openai_utils_support = ["openai"]
vllm_support = ["accelerate", "torch", "vllm"]
llama_index_support = ["llama-index"]
chroma_support = ["chromadb"]
# omit "vllm", since it can cause failures w/out cuda
all = ["anthropic", "accelerate", "datasets", "torch", "transformers", "openai", "matplotlib", "plotly", "scipy", "scikit-learn", "llama-index", "chromadb"]

[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
Expand Down
2 changes: 1 addition & 1 deletion sciphi/examples/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def parse_arguments() -> argparse.Namespace:
parser.add_argument(
"--max_tokens_to_sample",
type=int,
default=None,
default=1_024,
help="Max tokens to sample for each completion from the provided model.",
)
parser.add_argument(
Expand Down
Empty file.
133 changes: 133 additions & 0 deletions sciphi/examples/populate_chroma/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import os

import chromadb
import dotenv
import openai
from chromadb.config import Settings
from datasets import load_dataset
from openai.embeddings_utils import get_embeddings

from sciphi.core.utils import get_configured_logger

dotenv.load_dotenv()


def chunk_text(text: str, chunk_size: int) -> list[str]:
return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]


if __name__ == "__main__":
chroma_addr = os.environ["CHROMA_REMOTE_ADDR"]
chroma_port = os.environ["CHROMA_REMOTE_PORT"]
chroma_token = os.environ["CHROMA_TOKEN"]
chroma_auth_provider = os.environ["CHROMA_AUTH_PROVIDER"]
openai_api_key = os.environ["OPENAI_API_KEY"]

openai.api_key = openai_api_key
dataset_name = "vikp/pypi_clean"
chunk_size = 2048
batch_size = 64
sample_log_interval = 10
collection_name = f"{dataset_name.replace('/', '_')}_chunk_size_eq_2048"
log_level = "INFO"

logger = get_configured_logger("populate_chroma_db", log_level)
logger.info("Starting to populate ChromaDB")

if not chroma_token or not chroma_addr or not chroma_port:
raise ValueError(
f"ChromaDB environment variables not set correctly, found: chroma_token={chroma_token}, chroma_addr={chroma_addr}, chroma_port={chroma_port}"
)

if not openai_api_key:
raise ValueError(
"OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
)

client = chromadb.HttpClient(
host=chroma_addr,
port=chroma_port,
settings=Settings(
chroma_client_auth_provider=chroma_auth_provider,
chroma_client_auth_credentials=chroma_token,
),
)

try:
collection = client.create_collection(name=collection_name)
except Exception as e:
logger.info(
f"Collection {collection_name} likely already exists, skipping creation. For completeness, here is the exception: {e}"
)
collection = client.get_collection(name=collection_name)

parsed_ids = collection.get(include=[])["ids"]

dataset = load_dataset(dataset_name, streaming=True)
n_samples_iter = 0

# Count the number of chunks we have already parsed
n_entries = len(parsed_ids)
if n_entries > 0:
logger.info(f"Loaded {n_entries} entries from ChromaDB")

buffer: dict[str, list] = {
"documents": [],
"embeddings": [],
"metadatas": [],
"ids": [],
}

for entry in dataset["train"]:
chunks = chunk_text(entry["code"], chunk_size)
raw_ids = [
f"id_{i}" for i in range(n_entries, n_entries + len(chunks))
]
n_entries += len(chunks)
n_samples_iter += 1
if n_samples_iter % sample_log_interval == 0:
logger.info(
f"Processed {n_samples_iter} samples, total chunks = {n_entries}"
)
logger.info(f"Current max id = {raw_ids[-1]}")
logger.info(
"Logging buffer info:\n"
+ "\n".join(
[
f"Sanity check -- There are {len(buffer[key])} entries in {key}"
for key in buffer
]
)
)
if set(raw_ids).issubset(set(parsed_ids)):
logger.debug(f"Skipping ids = {raw_ids} as they already exist")
continue

buffer["documents"].extend(chunks)
buffer["embeddings"].extend(get_embeddings(chunks))
buffer["metadatas"].extend(
[
{
"package": entry["package"],
"path": entry["path"],
"filename": entry["filename"],
}
]
* len(chunks)
)
buffer["ids"].extend(raw_ids)

if len(buffer["documents"]) >= batch_size:
logger.debug(f"Inserting ids = {buffer['ids']}")
collection.add(
embeddings=buffer["embeddings"],
documents=buffer["documents"],
metadatas=buffer["metadatas"],
ids=buffer["ids"],
)
buffer = {
"documents": [],
"embeddings": [],
"metadatas": [],
"ids": [],
}

0 comments on commit a2ad18d

Please sign in to comment.