diff --git a/buster/documents_manager/base.py b/buster/documents_manager/base.py
index 9e504dd..d6f5d39 100644
--- a/buster/documents_manager/base.py
+++ b/buster/documents_manager/base.py
@@ -1,4 +1,5 @@
 import logging
+import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Optional
@@ -66,12 +67,31 @@ def _check_required_columns(self, df: pd.DataFrame):
         if not all(col in df.columns for col in self.required_columns):
             raise ValueError(f"DataFrame is missing one or more of {self.required_columns=}")
 
+    def _checkpoint_csv(self, df, csv_filename: str, csv_overwrite: bool = True):
+        import os
+
+        if csv_overwrite:
+            df.to_csv(csv_filename)
+            logger.info(f"Saved DataFrame with embeddings to {csv_filename}")
+
+        else:
+            if os.path.exists(csv_filename):
+                # append to the existing file
+                append_df = pd.read_csv(csv_filename)
+                append_df = pd.concat([append_df, df])
+            else:
+                # will create a new file
+                append_df = df.copy()
+            append_df.to_csv(csv_filename)
+            logger.info(f"Appended DataFrame embeddings to {csv_filename}")
+
     def add(
         self,
         df: pd.DataFrame,
         num_workers: int = 16,
         embedding_fn: callable = get_embedding_openai,
-        csv_checkpoint: Optional[str] = None,
+        csv_filename: Optional[str] = None,
+        csv_overwrite: bool = True,
         **add_kwargs,
     ):
         """Write documents from a DataFrame into the DocumentManager store.
@@ -88,7 +108,8 @@ def add(
             embedding_fn (callable, optional): A function that computes embeddings for a given input string.
                 Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
 
-            csv_checkpoint: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
+            csv_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
+            csv_overwrite: (bool, optional) = If csv_filename is specified, whether to overwrite any existing file.
 
             **add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
 
@@ -101,12 +122,78 @@ def add(
         if "embedding" not in df.columns:
             df["embedding"] = compute_embeddings_parallelized(df, embedding_fn=embedding_fn, num_workers=num_workers)
 
-        if csv_checkpoint is not None:
-            df.to_csv(csv_checkpoint)
-            logger.info(f"Saving DataFrame with embeddings to {csv_checkpoint}")
+        if csv_filename is not None:
+            self._checkpoint_csv(df, csv_filename=csv_filename, csv_overwrite=csv_overwrite)
 
         self._add_documents(df, **add_kwargs)
 
+    def batch_add(
+        self,
+        df: pd.DataFrame,
+        batch_size: int = 3000,
+        min_time_interval: int = 60,
+        num_workers: int = 16,
+        embedding_fn: callable = get_embedding_openai,
+        csv_filename: Optional[str] = None,
+        csv_overwrite: bool = False,
+        **add_kwargs,
+    ):
+        """
+        This function takes a DataFrame and adds its data to a DocumentsManager instance in batches.
+        It ensures that a minimum time interval is maintained between successive batches
+        to prevent timeouts or excessive load. This is useful for rate-limited APIs such as OpenAI's.
+
+        Args:
+            df (pandas.DataFrame): The input DataFrame containing data to be added.
+            batch_size (int, optional): The size of each batch. Defaults to 3000.
+            min_time_interval (int, optional): The minimum time interval (in seconds) between batches.
+                Defaults to 60.
+            num_workers (int, optional): The number of parallel workers to use when adding data.
+                Defaults to 16.
+            embedding_fn (callable, optional): A function that computes embeddings for a given input string.
+                Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
+            csv_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
+            csv_overwrite: (bool, optional) = If csv_filename is specified, whether to overwrite any existing file.
+                When using batches, set to False to keep all embeddings in the same file. You may want to manually remove the file when experimenting.
+
+            **add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
+
+        Returns:
+            None
+        """
+        total_batches = (len(df) // batch_size) + 1
+
+        logger.info(f"Adding {len(df)} documents with {batch_size=} for {total_batches=}")
+
+        for batch_idx in range(total_batches):
+            logger.info(f"Processing batch {batch_idx + 1}/{total_batches}")
+            start_time = time.time()
+
+            # Calculate batch indices and extract the batch DataFrame
+            start_idx = batch_idx * batch_size
+            end_idx = min((batch_idx + 1) * batch_size, len(df))
+            batch_df = df.iloc[start_idx:end_idx]
+
+            # Add the batch using the specified parameters
+            self.add(
+                batch_df,
+                num_workers=num_workers,
+                csv_filename=csv_filename,
+                csv_overwrite=csv_overwrite,
+                embedding_fn=embedding_fn,
+                **add_kwargs,
+            )
+
+            elapsed_time = time.time() - start_time
+            sleep_time = max(0, min_time_interval - elapsed_time)
+
+            # Sleep to ensure the minimum time interval is maintained
+            if sleep_time > 0:
+                logger.info(f"Sleeping for {round(sleep_time)} seconds...")
+                time.sleep(sleep_time)
+
+        logger.info("All batches processed.")
+
     @abstractmethod
     def _add_documents(self, df: pd.DataFrame, **add_kwargs):
         """Abstract method to be implemented by each inherited member.
diff --git a/buster/documents_manager/deeplake.py b/buster/documents_manager/deeplake.py
index cc4a3eb..ce3ea73 100644
--- a/buster/documents_manager/deeplake.py
+++ b/buster/documents_manager/deeplake.py
@@ -6,7 +6,7 @@
 
 from buster.utils import zip_contents
 
-from .base import DocumentsManager, get_embedding_openai
+from .base import DocumentsManager
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -28,6 +28,9 @@ def __init__(
             **vector_store_kwargs,
         )
 
+    def __len__(self):
+        return len(self.vector_store)
+
     @classmethod
     def _extract_metadata(cls, df: pd.DataFrame) -> dict:
         """extract the metadata from the dataframe in deeplake dict format"""
diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py
index 2b86e76..5a69507 100644
--- a/tests/test_chatbot.py
+++ b/tests/test_chatbot.py
@@ -152,7 +152,7 @@ def vector_store_path(tmp_path_factory):
     # Add the documents (will generate embeddings)
     dm = DeepLakeDocumentsManager(vector_store_path=dm_path)
     df = pd.read_csv(DOCUMENTS_CSV)
-    dm.add(df, num_workers=1)
+    dm.add(df, num_workers=NUM_WORKERS)
 
     return dm_path
 
diff --git a/tests/test_documents.py b/tests/test_documents.py
index 9dc14b7..f82c8eb 100644
--- a/tests/test_documents.py
+++ b/tests/test_documents.py
@@ -1,3 +1,5 @@
+import os
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -154,3 +156,40 @@ def test_generate_embeddings_parallelized():
     embeddings_arr = np.array(embeddings.to_list())
 
     assert np.allclose(embeddings_parallel, embeddings_arr, atol=1e-3)
+
+
+def test_add_batches(tmp_path):
+    dm_path = tmp_path / "deeplake_store"
+    num_samples = 20
+    batch_size = 16
+    csv_filename = os.path.join(tmp_path, "embedding_")
+
+    dm = DeepLakeDocumentsManager(vector_store_path=dm_path)
+
+    # Create fake data
+    df = pd.DataFrame.from_dict(
+        {
+            "title": ["test"] * num_samples,
+            "url": ["http://url.com"] * num_samples,
+            "content": ["cool text" + str(x) for x in range(num_samples)],
"source": ["my_source"] * num_samples, + } + ) + + dm.batch_add( + df, + embedding_fn=get_fake_embedding, + num_workers=NUM_WORKERS, + batch_size=batch_size, + min_time_interval=0, + csv_filename=csv_filename, + ) + + csv_files = [f for f in os.listdir(tmp_path) if f.endswith(".csv")] + + # check that we registered the good number of doucments and that files were generated + assert len(dm) == num_samples + + df_saved = pd.read_csv(csv_filename) + assert len(df_saved) == num_samples + assert "embedding" in df_saved.columns