diff --git a/.env.example b/.env.example
index 2360116..b99260c 100644
--- a/.env.example
+++ b/.env.example
@@ -1,3 +1,7 @@
 OPENAI_API_KEY=your_openai_key
 ANTHROPIC_API_KEY=your_anthropic_key
-HF_TOKEN=your_huggingface_token
\ No newline at end of file
+HF_TOKEN=your_huggingface_token
+CHROMA_REMOTE_ADDR=your_chroma_db_addr
+CHROMA_REMOTE_PORT="8000" # default
+CHROMA_TOKEN=your_chroma_db_token
+CHROMA_AUTH_PROVIDER="chromadb.auth.token.TokenAuthClientProvider"
\ No newline at end of file
diff --git a/README.md b/README.md
index aa2f252..e0c9d1c 100644
--- a/README.md
+++ b/README.md
@@ -23,9 +23,7 @@ git clone https://github.com/emrgnt-cmplxty/sciphi.git
 cd sciphi
 # Install dependencies
 # pip3 install poetry (if you don't have it)
-poetry install -E openai_support
-# Add other optional dependencies
-# poetry install -E openai_support -E anthropic_support -E hf_support ...
+poetry install -E all
 # Setup your environment
 cp .env.example .env && vim .env
 ```
diff --git a/pyproject.toml b/pyproject.toml
index 488de54..b69fe24 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,21 +17,36 @@ python = ">=3.10,<3.12"
 python-dotenv = "^1.0.0"
 pandas = "^2.1.0"
 # Begin optional dependencies
-accelerate = { version = "^0.23.0", optional = true }
+# anthropic
 anthropic = { version = "^0.3.10", optional = true }
+# hf
+accelerate = { version = "^0.23.0", optional = true }
 datasets = { version = "^2.14.5", optional = true }
-llama-index = { version = "^0.8.29.post1", optional = true }
-openai = { version = "0.27.8", optional = true }
 torch = { version = "^2.0.1", optional = true }
 transformers = { version = "^4.33.1", optional = true }
+# openai
+openai = { version = "0.27.8", optional = true }
+matplotlib = {version = "^3.8.0", optional = true}
+plotly = {version = "^5.17.0", optional = true}
+scipy = {version = "^1.11.2", optional = true}
+scikit-learn = {version = "^1.3.1", optional = true}
+# vllm
 vllm = { version = "0.1.7", optional = true }
+# llama-index
+llama-index = { version = "^0.8.29.post1", optional = true }
+# chroma
+chromadb = { version = "^0.4.12", optional = true }
 
 [tool.poetry.extras]
 anthropic_support = ["anthropic"]
 hf_support = ["accelerate", "datasets", "torch", "transformers"]
 openai_support = ["openai"]
+openai_utils_support = ["openai"]
 vllm_support = ["accelerate", "torch", "vllm"]
 llama_index_support = ["llama-index"]
+chroma_support = ["chromadb"]
+# omit "vllm", since it can cause failures w/out cuda
+all = ["anthropic", "accelerate", "datasets", "torch", "transformers", "openai", "matplotlib", "plotly", "scipy", "scikit-learn", "llama-index", "chromadb"]
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
diff --git a/sciphi/examples/helpers.py b/sciphi/examples/helpers.py
index 52ef0d5..a58eadd 100644
--- a/sciphi/examples/helpers.py
+++ b/sciphi/examples/helpers.py
@@ -95,7 +95,7 @@ def parse_arguments() -> argparse.Namespace:
     parser.add_argument(
         "--max_tokens_to_sample",
         type=int,
-        default=None,
+        default=1_024,
         help="Max tokens to sample for each completion from the provided model.",
     )
     parser.add_argument(
diff --git a/sciphi/examples/populate_chroma/__init__.py b/sciphi/examples/populate_chroma/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sciphi/examples/populate_chroma/runner.py b/sciphi/examples/populate_chroma/runner.py
new file mode 100644
index 0000000..21d7bcb
--- /dev/null
+++ b/sciphi/examples/populate_chroma/runner.py
@@ -0,0 +1,133 @@
+import os
+
+import chromadb
+import dotenv
+import openai
+from chromadb.config import Settings
+from datasets import load_dataset
+from openai.embeddings_utils import get_embeddings
+
+from sciphi.core.utils import get_configured_logger
+
+dotenv.load_dotenv()
+
+
+def chunk_text(text: str, chunk_size: int) -> list[str]:
+    return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]
+
+
+if __name__ == "__main__":
+    chroma_addr = os.environ["CHROMA_REMOTE_ADDR"]
+    chroma_port = os.environ["CHROMA_REMOTE_PORT"]
+    chroma_token = os.environ["CHROMA_TOKEN"]
+    chroma_auth_provider = os.environ["CHROMA_AUTH_PROVIDER"]
+    openai_api_key = os.environ["OPENAI_API_KEY"]
+
+    openai.api_key = openai_api_key
+    dataset_name = "vikp/pypi_clean"
+    chunk_size = 2048
+    batch_size = 64
+    sample_log_interval = 10
+    collection_name = f"{dataset_name.replace('/', '_')}_chunk_size_eq_2048"
+    log_level = "INFO"
+
+    logger = get_configured_logger("populate_chroma_db", log_level)
+    logger.info("Starting to populate ChromaDB")
+
+    if not chroma_token or not chroma_addr or not chroma_port:
+        raise ValueError(
+            f"ChromaDB environment variables not set correctly, found: chroma_token={chroma_token}, chroma_addr={chroma_addr}, chroma_port={chroma_port}"
+        )
+
+    if not openai_api_key:
+        raise ValueError(
+            "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
+        )
+
+    client = chromadb.HttpClient(
+        host=chroma_addr,
+        port=chroma_port,
+        settings=Settings(
+            chroma_client_auth_provider=chroma_auth_provider,
+            chroma_client_auth_credentials=chroma_token,
+        ),
+    )
+
+    try:
+        collection = client.create_collection(name=collection_name)
+    except Exception as e:
+        logger.info(
+            f"Collection {collection_name} likely already exists, skipping creation. For completeness, here is the exception: {e}"
+        )
+        collection = client.get_collection(name=collection_name)
+
+    parsed_ids = collection.get(include=[])["ids"]
+
+    dataset = load_dataset(dataset_name, streaming=True)
+    n_samples_iter = 0
+
+    # Count the number of chunks we have already parsed
+    n_entries = len(parsed_ids)
+    if n_entries > 0:
+        logger.info(f"Loaded {n_entries} entries from ChromaDB")
+
+    buffer: dict[str, list] = {
+        "documents": [],
+        "embeddings": [],
+        "metadatas": [],
+        "ids": [],
+    }
+
+    for entry in dataset["train"]:
+        chunks = chunk_text(entry["code"], chunk_size)
+        raw_ids = [
+            f"id_{i}" for i in range(n_entries, n_entries + len(chunks))
+        ]
+        n_entries += len(chunks)
+        n_samples_iter += 1
+        if n_samples_iter % sample_log_interval == 0:
+            logger.info(
+                f"Processed {n_samples_iter} samples, total chunks = {n_entries}"
+            )
+            logger.info(f"Current max id = {raw_ids[-1]}")
+            logger.info(
+                "Logging buffer info:\n"
+                + "\n".join(
+                    [
+                        f"Sanity check -- There are {len(buffer[key])} entries in {key}"
+                        for key in buffer
+                    ]
+                )
+            )
+        if set(raw_ids).issubset(set(parsed_ids)):
+            logger.debug(f"Skipping ids = {raw_ids} as they already exist")
+            continue
+
+        buffer["documents"].extend(chunks)
+        buffer["embeddings"].extend(get_embeddings(chunks))
+        buffer["metadatas"].extend(
+            [
+                {
+                    "package": entry["package"],
+                    "path": entry["path"],
+                    "filename": entry["filename"],
+                }
+            ]
+            * len(chunks)
+        )
+        buffer["ids"].extend(raw_ids)
+
+        if len(buffer["documents"]) >= batch_size:
+            logger.debug(f"Inserting ids = {buffer['ids']}")
+            collection.add(
+                embeddings=buffer["embeddings"],
+                documents=buffer["documents"],
+                metadatas=buffer["metadatas"],
+                ids=buffer["ids"],
+            )
+            buffer = {
+                "documents": [],
+                "embeddings": [],
+                "metadatas": [],
+                "ids": [],
+            }
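For context on how a collection populated by runner.py might be consumed, here is a minimal retrieval sketch; it is illustrative only and not part of the patch above. The query text, the n_results value, and the explicit "text-embedding-ada-002" engine are assumptions (runner.py relies on the default engine of get_embeddings), while the client settings and collection name mirror the values used in the script.

# Illustrative retrieval sketch, not part of the patch above.
# Assumes the same .env variables as runner.py; the query text, n_results,
# and the "text-embedding-ada-002" engine are assumptions for illustration.
import os

import chromadb
import dotenv
import openai
from chromadb.config import Settings
from openai.embeddings_utils import get_embedding

dotenv.load_dotenv()
openai.api_key = os.environ["OPENAI_API_KEY"]

client = chromadb.HttpClient(
    host=os.environ["CHROMA_REMOTE_ADDR"],
    port=os.environ["CHROMA_REMOTE_PORT"],
    settings=Settings(
        chroma_client_auth_provider=os.environ["CHROMA_AUTH_PROVIDER"],
        chroma_client_auth_credentials=os.environ["CHROMA_TOKEN"],
    ),
)

# Same name that runner.py derives from the dataset name and chunk size.
collection = client.get_collection(name="vikp_pypi_clean_chunk_size_eq_2048")

# Embed an example query and retrieve the nearest code chunks with metadata.
query_embedding = get_embedding(
    "parse a requirements.txt file", engine="text-embedding-ada-002"
)
results = collection.query(query_embeddings=[query_embedding], n_results=3)
for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
    print(meta["package"], meta["path"], doc[:80])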