Feature/finish chroma runner (#24)
* multithread chroma workflow

* Rebase
emrgnt-cmplxty authored Sep 21, 2023
1 parent 8033f52 commit d2d2427
Showing 2 changed files with 110 additions and 67 deletions.
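
In short, this commit replaces the runner's single-threaded ingestion loop with batched, multithreaded population of ChromaDB. As a rough orientation before the diff, a condensed, self-contained sketch of that pattern is shown below; the dataset, worker body, and thread count here are illustrative stand-ins, not the repository's actual objects.

from concurrent.futures import ThreadPoolExecutor


def batch_dataset(dataset, batch_size):
    """Yield successive lists of at most batch_size entries."""
    batch = []
    for entry in dataset:
        batch.append(entry)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def worker(batch):
    """Stand-in for the real worker, which chunks, embeds, and writes to ChromaDB."""
    return len(batch)


if __name__ == "__main__":
    synthetic_dataset = ({"code": f"print({i})"} for i in range(1000))
    with ThreadPoolExecutor(max_workers=4) as executor:
        # map blocks until every batch has been processed
        results = list(executor.map(worker, batch_dataset(synthetic_dataset, 64)))
    print(f"processed {len(results)} batches")
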
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -45,8 +45,9 @@ openai_utils_support = ["openai"]
vllm_support = ["accelerate", "torch", "vllm"]
llama_index_support = ["llama-index"]
chroma_support = ["chromadb"]
# omit "vllm", since it can cause failures w/out cuda
all = ["anthropic", "accelerate", "datasets", "torch", "transformers", "openai", "matplotlib", "plotly", "scipy", "scikit-learn", "llama-index", "chromadb"]
# omit "accelerate" and "vllm", since they can cause failures w/out cuda
all = ["anthropic", "datasets", "torch", "transformers", "openai", "matplotlib", "plotly", "scipy", "scikit-learn", "llama-index", "chromadb"]
all_with_cuda = ["anthropic", "accelerate" , "datasets", "torch", "transformers", "openai", "matplotlib", "plotly", "scipy", "scikit-learn", "llama-index", "chromadb", "vllm"]

[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
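
A note on the extras comment above: code that relies on the CUDA-only packages generally has to tolerate them being absent. The guard below is a minimal, hypothetical sketch (not taken from this repository) of how that is typically handled.

# Hypothetical guard, not from this repository: degrade gracefully when the
# CUDA-only extras ("vllm", "accelerate") are not installed.
try:
    import vllm  # noqa: F401  -- available only with the all_with_cuda extra

    HAS_VLLM = True
except ImportError:
    HAS_VLLM = False


def pick_generation_backend() -> str:
    """Choose a backend name based on which optional extras are importable."""
    return "vllm" if HAS_VLLM else "openai"
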
172 changes: 107 additions & 65 deletions sciphi/examples/populate_chroma/runner.py
@@ -1,5 +1,7 @@
"""A module for populating ChromaDB with a dataset."""
import os
from concurrent.futures import ThreadPoolExecutor
from threading import current_thread

import chromadb
import dotenv
@@ -14,9 +16,91 @@


def chunk_text(text: str, chunk_size: int) -> list[str]:
    """Chunk a text into a list of strings of size chunk_size."""
    return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]


def worker(worker_args: tuple) -> None:
    """Worker function to populate ChromaDB with a batch of entries."""
    # Note: chunk_size, embedding_engine, batch_size, and collection are
    # module-level globals defined in the __main__ block below.
    thread_name = current_thread().name
    entries_batch, parsed_ids, logger, logger_interval = worker_args
    logger.info(f"Starting worker thread: {thread_name}")

    local_buffer: dict[str, list] = {
        "documents": [],
        "embeddings": [],
        "metadatas": [],
        "ids": [],
    }
    n_entries_local = len(parsed_ids)
    n_samples_iter_local = 0

    for entry in entries_batch:
        chunks = chunk_text(entry["code"], chunk_size)
        raw_ids = [
            f"id_{i}"
            for i in range(n_entries_local, n_entries_local + len(chunks))
        ]
        n_entries_local += len(chunks)
        n_samples_iter_local += 1
        if n_samples_iter_local % logger_interval == 0:
            logger.info(
                f"Thread {thread_name} processed {n_samples_iter_local} samples"
            )

        if set(raw_ids).issubset(set(parsed_ids)):
            logger.debug(f"Skipping ids = {raw_ids} as they already exist")
            continue

        local_buffer["documents"].extend(chunks)
        local_buffer["embeddings"].extend(
            get_embeddings(chunks, engine=embedding_engine)
        )
        local_buffer["metadatas"].extend(
            [
                {
                    "package": entry["package"],
                    "path": entry["path"],
                    "filename": entry["filename"],
                }
            ]
            * len(chunks)
        )
        local_buffer["ids"].extend(raw_ids)

        # Write to database in chunks
        if len(local_buffer["documents"]) >= batch_size:
            logger.debug(f"Inserting ids = {local_buffer['ids']}")
            try:
                collection.add(
                    embeddings=local_buffer["embeddings"],
                    documents=local_buffer["documents"],
                    metadatas=local_buffer["metadatas"],
                    ids=local_buffer["ids"],
                )
            except Exception as e:
                logger.error(
                    f"Failed to insert ids = {local_buffer['ids']} with error {e}; skipping."
                )
            local_buffer = {
                "documents": [],
                "embeddings": [],
                "metadatas": [],
                "ids": [],
            }


def batch_dataset(dataset, batch_size):
    """Yield successive batches of batch_size entries from the dataset."""
    batch = []
    for entry in dataset:
        batch.append(entry)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


if __name__ == "__main__":
    # TODO - Move to proper CLI based approach, like Fire.
    # For now, we are getting it running quick and dirty.
Expand All @@ -34,15 +118,24 @@ def chunk_text(text: str, chunk_size: int) -> list[str]:

    # HF dataset
    dataset_name = "vikp/pypi_clean"

    streaming = False
    # Script variables
    # For chunking the code into smaller pieces
    chunk_size = 2048
    # For batching the embedding calls & inserts into ChromaDB
    batch_size = 64
    sample_log_interval = 10
    batches_per_split = 8
    # Process dataset in multiple threads
    num_threads = 1
    # For logging
    # TODO - Modify to ensure we are logging per-process
    log_level = "INFO"
    sample_log_interval = 100

    # Output collection name
    collection_name = (
        f"{dataset_name.replace('/', '_')}_chunk_size_eq__{chunk_size}"
    )
    log_level = "INFO"

    logger = get_configured_logger("populate_chroma_db", log_level)
    logger.info("Starting to populate ChromaDB")
@@ -75,74 +168,23 @@ def chunk_text(text: str, chunk_size: int) -> list[str]:
    collection = client.get_collection(name=collection_name)

    parsed_ids = collection.get(include=[])["ids"]
    logger.info("Loading the HF dataset now...")
    dataset = load_dataset(dataset_name, streaming=streaming)
    if not streaming:
        dataset = dataset["train"].shuffle(seed=42)

    dataset = load_dataset(dataset_name, streaming=True)
    n_samples_iter = 0

    # Count the number of chunks we have already parsed
    n_entries = len(parsed_ids)
    if n_entries > 0:
        logger.info(f"Loaded {n_entries} entries from ChromaDB")

    buffer: dict[str, list] = {
        "documents": [],
        "embeddings": [],
        "metadatas": [],
        "ids": [],
    }

    for entry in dataset["train"]:
        chunks = chunk_text(entry["code"], chunk_size)
        raw_ids = [
            f"id_{i}" for i in range(n_entries, n_entries + len(chunks))
        ]
        n_entries += len(chunks)
        n_samples_iter += 1
        if n_samples_iter % sample_log_interval == 0:
            logger.info(
                f"Processed {n_samples_iter} samples, total chunks = {n_entries}"
            )
            logger.info(f"Current max id = {raw_ids[-1]}")
            logger.info(
                "Logging buffer info:\n"
                + "\n".join(
                    [
                        f"Sanity check -- There are {len(buffer[key])} entries in {key}"
                        for key in buffer
                    ]
                )
            )
        if set(raw_ids).issubset(set(parsed_ids)):
            logger.debug(f"Skipping ids = {raw_ids} as they already exist")
            continue

        buffer["documents"].extend(chunks)
        buffer["embeddings"].extend(
            get_embeddings(chunks, engine=embedding_engine)
    logger.info("Creating the dataset batches...")
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        args_for_workers = (
            (batch, parsed_ids, logger, sample_log_interval)
            for batch in batch_dataset(dataset, batches_per_split * batch_size)
        )
        buffer["metadatas"].extend(
            [
                {
                    "package": entry["package"],
                    "path": entry["path"],
                    "filename": entry["filename"],
                }
            ]
            * len(chunks)
        )
        buffer["ids"].extend(raw_ids)

        if len(buffer["documents"]) >= batch_size:
            logger.debug(f"Inserting ids = {buffer['ids']}")
            collection.add(
                embeddings=buffer["embeddings"],
                documents=buffer["documents"],
                metadatas=buffer["metadatas"],
                ids=buffer["ids"],
            )
            buffer = {
                "documents": [],
                "embeddings": [],
                "metadatas": [],
                "ids": [],
            }
        # The map method blocks until all results are returned
        list(executor.map(worker, args_for_workers))
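
Once the runner has finished, the populated collection can be sanity-checked from a separate session. The snippet below is a hedged sketch: the HttpClient host and port are placeholders, and the script's own client construction sits in a part of the file this diff does not show.

# Hedged verification sketch; host/port are placeholders, and the collection
# name follows the naming scheme used in runner.py above.
import chromadb

client = chromadb.HttpClient(host="localhost", port=8000)
collection = client.get_collection(name="vikp_pypi_clean_chunk_size_eq__2048")

# Total number of chunks inserted so far.
print(f"chunks stored: {collection.count()}")

# IDs only, mirroring the parsed_ids bookkeeping in the runner.
ids = collection.get(include=[])["ids"]
print(f"first few ids: {ids[:5]}")
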
