add batch_add method (#123)
* add batch_add method

* save all files to single .csv with batch add
jerpint authored Aug 11, 2023
1 parent c662476 commit e0fbbd6
Showing 4 changed files with 136 additions and 7 deletions.
97 changes: 92 additions & 5 deletions buster/documents_manager/base.py
@@ -1,4 +1,5 @@
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
@@ -66,12 +67,31 @@ def _check_required_columns(self, df: pd.DataFrame):
        if not all(col in df.columns for col in self.required_columns):
            raise ValueError(f"DataFrame is missing one or more of {self.required_columns=}")

    def _checkpoint_csv(self, df, csv_filename: str, csv_overwrite: bool = True):
        import os

        if csv_overwrite:
            df.to_csv(csv_filename)
            logger.info(f"Saved DataFrame with embeddings to {csv_filename}")

        else:
            if os.path.exists(csv_filename):
                # append to existing file
                append_df = pd.read_csv(csv_filename)
                append_df = pd.concat([append_df, df])
            else:
                # will create the new file
                append_df = df.copy()
            append_df.to_csv(csv_filename)
            logger.info(f"Appending DataFrame embeddings to {csv_filename}")

    def add(
        self,
        df: pd.DataFrame,
        num_workers: int = 16,
        embedding_fn: callable = get_embedding_openai,
        csv_checkpoint: Optional[str] = None,
        csv_filename: Optional[str] = None,
        csv_overwrite: bool = True,
        **add_kwargs,
    ):
        """Write documents from a DataFrame into the DocumentManager store.
@@ -88,7 +108,8 @@ def add(
            embedding_fn (callable, optional): A function that computes embeddings for a given input string.
                Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
            csv_checkpoint: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
            csv_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
            csv_overwrite: (bool, optional) = If csv_filename is specified, whether to overwrite the existing file.
            **add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
@@ -101,12 +122,78 @@ def add(
if "embedding" not in df.columns:
df["embedding"] = compute_embeddings_parallelized(df, embedding_fn=embedding_fn, num_workers=num_workers)

if csv_checkpoint is not None:
df.to_csv(csv_checkpoint)
logger.info(f"Saving DataFrame with embeddings to {csv_checkpoint}")
if csv_filename is not None:
self._checkpoint_csv(df, csv_filename=csv_filename, csv_overwrite=csv_overwrite)

self._add_documents(df, **add_kwargs)

    def batch_add(
        self,
        df: pd.DataFrame,
        batch_size: int = 3000,
        min_time_interval: int = 60,
        num_workers: int = 16,
        embedding_fn: callable = get_embedding_openai,
        csv_filename: Optional[str] = None,
        csv_overwrite: bool = False,
        **add_kwargs,
    ):
        """
        This method adds the contents of a DataFrame to the DocumentsManager instance in batches.
        It ensures that a minimum time interval is maintained between successive batches
        to prevent timeouts or excessive load. This is useful for rate-limited APIs such as OpenAI's.
        Args:
            df (pandas.DataFrame): The input DataFrame containing data to be added.
            batch_size (int, optional): The size of each batch. Defaults to 3000.
            min_time_interval (int, optional): The minimum time interval (in seconds) between batches.
                Defaults to 60.
            num_workers (int, optional): The number of parallel workers to use when adding data.
                Defaults to 16.
            embedding_fn (callable, optional): A function that computes embeddings for a given input string.
                Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
            csv_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
            csv_overwrite: (bool, optional) = If csv_filename is specified, whether to overwrite the existing file.
                When using batches, set to False to keep all embeddings in the same file. You may want to manually remove the file if experimenting.
            **add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
        Returns:
            None
        """
        total_batches = (len(df) // batch_size) + 1

        logger.info(f"Adding {len(df)} documents with {batch_size=} for {total_batches=}")

        for batch_idx in range(total_batches):
            logger.info(f"Processing batch {batch_idx + 1}/{total_batches}")
            start_time = time.time()

            # Calculate batch indices and extract batch DataFrame
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, len(df))
            batch_df = df.iloc[start_idx:end_idx]

            # Add the batch data using the specified parameters
            self.add(
                batch_df,
                num_workers=num_workers,
                csv_filename=csv_filename,
                csv_overwrite=csv_overwrite,
                embedding_fn=embedding_fn,
                **add_kwargs,
            )

            elapsed_time = time.time() - start_time
            sleep_time = max(0, min_time_interval - elapsed_time)

            # Sleep to ensure the minimum time interval is maintained
            if sleep_time > 0:
                logger.info(f"Sleeping for {round(sleep_time)} seconds...")
                time.sleep(sleep_time)

        logger.info("All batches processed.")

    @abstractmethod
    def _add_documents(self, df: pd.DataFrame, **add_kwargs):
        """Abstract method to be implemented by each inherited member.
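Taken together, the new _checkpoint_csv helper, the reworked add signature, and batch_add make it possible to embed a large DataFrame in rate-limited chunks while accumulating every batch's embeddings in a single CSV. The sketch below is illustrative only: the import path is inferred from the file layout shown in this commit, the file paths are made up, and the default embedding_fn is assumed to require OpenAI credentials.

import pandas as pd

from buster.documents_manager.deeplake import DeepLakeDocumentsManager

# Hypothetical input; the DataFrame must contain the manager's required columns
# (the tests below use title, url, content and source).
df = pd.read_csv("my_documents.csv")

dm = DeepLakeDocumentsManager(vector_store_path="deeplake_store")

# Embeddings are computed batch by batch; csv_overwrite=False makes
# _checkpoint_csv append each batch's embeddings to the same file instead of
# replacing it, and min_time_interval spaces batches out for rate-limited APIs.
dm.batch_add(
    df,
    batch_size=3000,
    min_time_interval=60,
    num_workers=16,
    csv_filename="embeddings_checkpoint.csv",
    csv_overwrite=False,
)

Setting min_time_interval=0, as the new test below does, disables the pause entirely, which is convenient when using a fake embedding function.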
5 changes: 4 additions & 1 deletion buster/documents_manager/deeplake.py
@@ -6,7 +6,7 @@

from buster.utils import zip_contents

from .base import DocumentsManager, get_embedding_openai
from .base import DocumentsManager

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@@ -28,6 +28,9 @@ def __init__(
            **vector_store_kwargs,
        )

    def __len__(self):
        return len(self.vector_store)

    @classmethod
    def _extract_metadata(cls, df: pd.DataFrame) -> dict:
        """extract the metadata from the dataframe in deeplake dict format"""
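The new __len__ simply delegates to the underlying DeepLake vector store, which is what allows the new test to assert on document counts. A minimal sketch, assuming a manager and DataFrame set up as in the earlier example:

dm = DeepLakeDocumentsManager(vector_store_path="deeplake_store")
dm.add(df, num_workers=1)

# __len__ delegates to the vector store, so this holds after a successful add.
assert len(dm) == len(df)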
2 changes: 1 addition & 1 deletion tests/test_chatbot.py
@@ -152,7 +152,7 @@ def vector_store_path(tmp_path_factory):
    # Add the documents (will generate embeddings)
    dm = DeepLakeDocumentsManager(vector_store_path=dm_path)
    df = pd.read_csv(DOCUMENTS_CSV)
    dm.add(df, num_workers=1)
    dm.add(df, num_workers=NUM_WORKERS)
    return dm_path


39 changes: 39 additions & 0 deletions tests/test_documents.py
@@ -1,3 +1,5 @@
import os

import numpy as np
import pandas as pd
import pytest
@@ -154,3 +156,40 @@ def test_generate_embeddings_parallelized():
    embeddings_arr = np.array(embeddings.to_list())

    assert np.allclose(embeddings_parallel, embeddings_arr, atol=1e-3)


def test_add_batches(tmp_path):
    dm_path = tmp_path / "deeplake_store"
    num_samples = 20
    batch_size = 16
    csv_filename = os.path.join(tmp_path, "embedding_")

    dm = DeepLakeDocumentsManager(vector_store_path=dm_path)

    # Create fake data
    df = pd.DataFrame.from_dict(
        {
            "title": ["test"] * num_samples,
            "url": ["http://url.com"] * num_samples,
            "content": ["cool text" + str(x) for x in range(num_samples)],
            "source": ["my_source"] * num_samples,
        }
    )

    dm.batch_add(
        df,
        embedding_fn=get_fake_embedding,
        num_workers=NUM_WORKERS,
        batch_size=batch_size,
        min_time_interval=0,
        csv_filename=csv_filename,
    )

    csv_files = [f for f in os.listdir(tmp_path) if f.endswith(".csv")]

    # check that we registered the correct number of documents and that csv files were generated
    assert len(dm) == num_samples

    df_saved = pd.read_csv(csv_filename)
    assert len(df_saved) == num_samples
    assert "embedding" in df_saved.columns
