diff --git a/pyproject.toml b/pyproject.toml
index b69fe24..3e4fb0d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,8 +45,9 @@ openai_utils_support = ["openai"]
 vllm_support = ["accelerate", "torch", "vllm"]
 llama_index_support = ["llama-index"]
 chroma_support = ["chromadb"]
-# omit "vllm", since it can cause failures w/out cuda
-all = ["anthropic", "accelerate", "datasets", "torch", "transformers", "openai", "matplotlib", "plotly", "scipy", "scikit-learn", "llama-index", "chromadb"]
+# omit "accelerate" and "vllm", since they can cause failures w/out cuda
+all = ["anthropic", "datasets", "torch", "transformers", "openai", "matplotlib", "plotly", "scipy", "scikit-learn", "llama-index", "chromadb"]
+all_with_cuda = ["anthropic", "accelerate", "datasets", "torch", "transformers", "openai", "matplotlib", "plotly", "scipy", "scikit-learn", "llama-index", "chromadb", "vllm"]
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
diff --git a/sciphi/examples/populate_chroma/runner.py b/sciphi/examples/populate_chroma/runner.py
index 259b5e2..30f6f3f 100644
--- a/sciphi/examples/populate_chroma/runner.py
+++ b/sciphi/examples/populate_chroma/runner.py
@@ -1,5 +1,7 @@
 """A module for populating ChromaDB with a dataset."""
 import os
+from concurrent.futures import ThreadPoolExecutor
+from threading import current_thread
 
 import chromadb
 import dotenv
@@ -14,9 +16,91 @@ def chunk_text(text: str, chunk_size: int) -> list[str]:
+    """Chunk a text into a list of strings of size chunk_size."""
     return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]
 
 
+def worker(worker_args: tuple) -> None:
+    """Worker function to populate ChromaDB with a batch of entries."""
+    thread_name = current_thread().name
+    entries_batch, parsed_ids, logger, logger_interval = worker_args
+    logger.info(f"Starting worker thread: {thread_name}")
+
+    local_buffer: dict[str, list] = {
+        "documents": [],
+        "embeddings": [],
+        "metadatas": [],
+        "ids": [],
+    }
+    n_entries_local = len(parsed_ids)
+    n_samples_iter_local = 0
+
+    for entry in entries_batch:
+        chunks = chunk_text(entry["code"], chunk_size)
+        raw_ids = [
+            f"id_{i}"
+            for i in range(n_entries_local, n_entries_local + len(chunks))
+        ]
+        n_entries_local += len(chunks)
+        n_samples_iter_local += 1
+        if n_samples_iter_local % logger_interval == 0:
+            logger.info(
+                f"Thread {thread_name} processed {n_samples_iter_local} samples"
+            )
+
+        if set(raw_ids).issubset(set(parsed_ids)):
+            logger.debug(f"Skipping ids = {raw_ids} as they already exist")
+            continue
+
+        local_buffer["documents"].extend(chunks)
+        local_buffer["embeddings"].extend(
+            get_embeddings(chunks, engine=embedding_engine)
+        )
+        local_buffer["metadatas"].extend(
+            [
+                {
+                    "package": entry["package"],
+                    "path": entry["path"],
+                    "filename": entry["filename"],
+                }
+            ]
+            * len(chunks)
+        )
+        local_buffer["ids"].extend(raw_ids)
+
+        # Write to the database in batches
+        if len(local_buffer["documents"]) >= batch_size:
+            logger.debug(f"Inserting ids = {local_buffer['ids']}")
+            try:
+                collection.add(
+                    embeddings=local_buffer["embeddings"],
+                    documents=local_buffer["documents"],
+                    metadatas=local_buffer["metadatas"],
+                    ids=local_buffer["ids"],
+                )
+            except Exception as e:
+                logger.error(
+                    f"Failed to insert ids = {local_buffer['ids']} with error {e}, skipping."
+                )
+            local_buffer = {
+                "documents": [],
+                "embeddings": [],
+                "metadatas": [],
+                "ids": [],
+            }
+
+
+def batch_dataset(dataset, batch_size):
+    batch = []
+    for entry in dataset:
+        batch.append(entry)
+        if len(batch) == batch_size:
+            yield batch
+            batch = []
+    if batch:
+        yield batch
+
+
 if __name__ == "__main__":
     # TODO - Move to proper CLI based approach, like Fire.
     # For now, we are getting it running quick and dirty.
@@ -34,15 +118,24 @@ def chunk_text(text: str, chunk_size: int) -> list[str]:
 
     # HF dataset
     dataset_name = "vikp/pypi_clean"
-
+    streaming = False
     # Script variables
+    # For chunking the code into smaller pieces
     chunk_size = 2048
+    # For batching the embedding calls & inserts into ChromaDB
     batch_size = 64
-    sample_log_interval = 10
+    batches_per_split = 8
+    # Process dataset in multiple threads
+    num_threads = 1
+    # For logging
+    # TODO - Modify to ensure we are logging by-process
+    log_level = "INFO"
+    sample_log_interval = 100
+
+    # Output collection name
     collection_name = (
         f"{dataset_name.replace('/', '_')}_chunk_size_eq__{chunk_size}"
     )
-    log_level = "INFO"
 
     logger = get_configured_logger("populate_chroma_db", log_level)
     logger.info("Starting to populate ChromaDB")
@@ -75,8 +168,11 @@ def chunk_text(text: str, chunk_size: int) -> list[str]:
     collection = client.get_collection(name=collection_name)
     parsed_ids = collection.get(include=[])["ids"]
 
+    logger.info("Loading the HF dataset now...")
+    dataset = load_dataset(dataset_name, streaming=streaming)
+    if not streaming:
+        dataset = dataset["train"].shuffle(seed=42)
-    dataset = load_dataset(dataset_name, streaming=True)
 
     n_samples_iter = 0
 
     # Count the number of chunks we have already parsed
@@ -84,65 +180,11 @@
     if n_entries > 0:
         logger.info(f"Loaded {n_entries} entries from ChromaDB")
 
-    buffer: dict[str, list] = {
-        "documents": [],
-        "embeddings": [],
-        "metadatas": [],
-        "ids": [],
-    }
-
-    for entry in dataset["train"]:
-        chunks = chunk_text(entry["code"], chunk_size)
-        raw_ids = [
-            f"id_{i}" for i in range(n_entries, n_entries + len(chunks))
-        ]
-        n_entries += len(chunks)
-        n_samples_iter += 1
-        if n_samples_iter % sample_log_interval == 0:
-            logger.info(
-                f"Processed {n_samples_iter} samples, total chunks = {n_entries}"
-            )
-            logger.info(f"Current max id = {raw_ids[-1]}")
-            logger.info(
-                "Logging buffer info:\n"
-                + "\n".join(
-                    [
-                        f"Sanity check -- There are {len(buffer[key])} entries in {key}"
-                        for key in buffer
-                    ]
-                )
-            )
-        if set(raw_ids).issubset(set(parsed_ids)):
-            logger.debug(f"Skipping ids = {raw_ids} as they already exist")
-            continue
-
-        buffer["documents"].extend(chunks)
-        buffer["embeddings"].extend(
-            get_embeddings(chunks, engine=embedding_engine)
+    logger.info("Creating the dataset batches...")
+    with ThreadPoolExecutor(max_workers=num_threads) as executor:
+        args_for_workers = (
+            (batch, parsed_ids, logger, sample_log_interval)
+            for batch in batch_dataset(dataset, batches_per_split * batch_size)
         )
-        buffer["metadatas"].extend(
-            [
-                {
-                    "package": entry["package"],
-                    "path": entry["path"],
-                    "filename": entry["filename"],
-                }
-            ]
-            * len(chunks)
-        )
-        buffer["ids"].extend(raw_ids)
-
-        if len(buffer["documents"]) >= batch_size:
-            logger.debug(f"Inserting ids = {buffer['ids']}")
-            collection.add(
-                embeddings=buffer["embeddings"],
-                documents=buffer["documents"],
-                metadatas=buffer["metadatas"],
-                ids=buffer["ids"],
-            )
-            buffer = {
-                "documents": [],
-                "embeddings": [],
-                "metadatas": [],
-                "ids": [],
-            }
+        # The map method blocks until all results are returned
+        list(executor.map(worker, args_for_workers))
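Reviewer note (not part of the patch): the sketch below illustrates the batching-plus-thread-pool pattern that the new `batch_dataset` generator and the `list(executor.map(worker, ...))` call rely on. It is a minimal, self-contained example; `process_batch` and `fake_dataset` are hypothetical stand-ins for this repo's `worker` and the HF dataset, and the ChromaDB and embedding calls are omitted.

```python
from concurrent.futures import ThreadPoolExecutor
from threading import current_thread


def batch_dataset(dataset, batch_size):
    """Yield successive lists of at most batch_size entries."""
    batch = []
    for entry in dataset:
        batch.append(entry)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # flush the final, possibly smaller, batch
        yield batch


def process_batch(batch):
    """Hypothetical stand-in for worker(): handle one batch of entries."""
    return f"{current_thread().name} processed {len(batch)} entries"


if __name__ == "__main__":
    fake_dataset = range(10)  # stand-in for the streamed HF dataset
    with ThreadPoolExecutor(max_workers=2) as executor:
        batches = batch_dataset(fake_dataset, batch_size=3)
        # map() is lazy; consuming it blocks until every batch is processed
        for result in executor.map(process_batch, batches):
            print(result)
```

Wrapping `executor.map(...)` in `list(...)`, as the patch does, forces the lazy iterator to run to completion and re-raises any exception thrown inside a worker.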