Feat/indexing faissless #173

Merged · 19 commits · Mar 18, 2024
Changes from 8 commits
4 changes: 4 additions & 0 deletions ragatouille/RAGPretrainedModel.py
@@ -180,6 +180,7 @@ def index(
document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter,
preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None,
bsize: int = 32,
use_faiss: bool = False,
):
"""Build an index from a list of documents.

@@ -215,6 +216,7 @@ def index(
max_document_length=max_document_length,
overwrite=overwrite_index,
bsize=bsize,
use_faiss=use_faiss,
)

def add_to_index(
@@ -227,6 +229,7 @@ def add_to_index(
document_splitter_fn: Optional[Callable] = llama_index_sentence_splitter,
preprocessing_fn: Optional[Union[Callable, list[Callable]]] = None,
bsize: int = 32,
use_faiss: bool = False,
):
"""Add documents to an existing index.

@@ -258,6 +261,7 @@
new_docid_metadata_map=new_docid_metadata_map,
index_name=index_name,
bsize=bsize,
use_faiss=use_faiss,
)

def delete_from_index(
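For context, a minimal usage sketch of the new flag through the public API. The checkpoint name and document strings are illustrative; the calls assume RAGatouille's existing from_pretrained / index / add_to_index entry points, with use_faiss being the parameter this PR adds:

from ragatouille import RAGPretrainedModel

# Load a pretrained ColBERT checkpoint through the usual entry point.
RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")

# Default behaviour after this PR: PyTorch-based k-means, no FAISS required.
RAG.index(
    collection=["ColBERT is a late-interaction retrieval model."],
    index_name="demo_index",
)

# Opt back into FAISS clustering if it is known to work on this machine.
RAG.add_to_index(
    new_collection=["PLAID is ColBERT's optimised indexing engine."],
    index_name="demo_index",
    use_faiss=True,
)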
8 changes: 6 additions & 2 deletions ragatouille/models/colbert.py
@@ -125,6 +125,7 @@ def add_to_index(
new_docid_metadata_map: Optional[List[dict]] = None,
index_name: Optional[str] = None,
bsize: int = 32,
use_faiss: bool = False,
):
self.index_name = index_name if index_name is not None else self.index_name
if self.index_name is None:
@@ -181,6 +182,7 @@ def add_to_index(
new_collection,
verbose=self.verbose != 0,
bsize=bsize,
use_faiss=use_faiss,
)
self.config = self.model_index.config

@@ -294,6 +296,7 @@ def index(
max_document_length: int = 256,
overwrite: Union[bool, str] = "reuse",
bsize: int = 32,
use_faiss: bool = False,
):
self.collection = collection
self.config.doc_maxlen = max_document_length
@@ -341,6 +344,7 @@ def index(
overwrite,
verbose=self.verbose != 0,
bsize=bsize,
use_faiss=use_faiss,
)
self.config = self.model_index.config
self._save_index_metadata()
@@ -364,9 +368,9 @@ def search(
for doc_id in doc_ids:
pids.extend(self.docid_pid_map[doc_id])

force_reload = self.index_name is not None and index_name != self.index_name
force_reload = index_name is not None and index_name != self.index_name
if index_name is not None:
if self.index_name is not None:
if self.index_name is not None and self.index_name != index_name:
print(
f"New index_name received!",
f"Updating current index_name ({self.index_name}) to {index_name}",
(In each changed pair above, the first line is the removed version and the second its replacement: the searcher is now force-reloaded, and the index-switch notice printed, only when a non-None index_name differing from the current one is passed in.)
115 changes: 75 additions & 40 deletions ragatouille/models/index.py
@@ -6,8 +6,11 @@
import srsly
import torch
from colbert import Indexer, IndexUpdater, Searcher
from colbert.indexing.collection_indexer import CollectionIndexer
from colbert.infra import ColBERTConfig

from ragatouille.models import torch_kmeans

IndexType = Literal["FLAT", "HNSW", "PLAID"]


@@ -30,8 +33,7 @@ def construct(
overwrite: Union[bool, str] = "reuse",
verbose: bool = True,
**kwargs,
) -> "ModelIndex":
...
) -> "ModelIndex": ...

@staticmethod
@abstractmethod
@@ -41,8 +43,7 @@ def load_from_file(
index_config: dict[str, Any],
config: ColBERTConfig,
verbose: bool = True,
) -> "ModelIndex":
...
) -> "ModelIndex": ...

@abstractmethod
def build(
@@ -52,8 +53,7 @@ def build(
index_name: Optional["str"] = None,
overwrite: Union[bool, str] = "reuse",
verbose: bool = True,
) -> None:
...
) -> None: ...

@abstractmethod
def search(
@@ -68,16 +68,13 @@ def search(
pids: Optional[List[int]] = None,
force_reload: bool = False,
**kwargs,
) -> list[tuple[list, list, list]]:
...
) -> list[tuple[list, list, list]]: ...

@abstractmethod
def _search(self, query: str, k: int, pids: Optional[List[int]] = None):
...
def _search(self, query: str, k: int, pids: Optional[List[int]] = None): ...

@abstractmethod
def _batch_search(self, query: list[str], k: int):
...
def _batch_search(self, query: list[str], k: int): ...

@abstractmethod
def add(
@@ -90,8 +87,7 @@ def add(
new_collection: List[str],
verbose: bool = True,
**kwargs,
) -> None:
...
) -> None: ...

@abstractmethod
def delete(
@@ -102,12 +98,10 @@
index_name: str,
pids_to_remove: Union[TypeVar("T"), List[TypeVar("T")]],
verbose: bool = True,
) -> None:
...
) -> None: ...

@abstractmethod
def _export_config(self) -> dict[str, Any]:
...
def _export_config(self) -> dict[str, Any]: ...

def export_metadata(self) -> dict[str, Any]:
config = self._export_config()
@@ -168,21 +162,6 @@ def build(
bsize = kwargs.get("bsize", PLAIDModelIndex._DEFAULT_INDEX_BSIZE)
assert isinstance(bsize, int)

if torch.cuda.is_available():
import faiss

if not hasattr(faiss, "StandardGpuResources"):
print(
"________________________________________________________________________________\n"
"WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n",
"This means that indexing will be slow. To make use of your GPU.\n"
"Please install `faiss-gpu` by running:\n"
"pip uninstall --y faiss-cpu & pip install faiss-gpu\n",
"________________________________________________________________________________",
)
print("Will continue with CPU indexing in 5 seconds...")
time.sleep(5)

nbits = 2
if len(collection) < 5000:
nbits = 8
@@ -201,13 +180,69 @@

# Instruct colbert-ai to disable forking if nranks == 1
self.config.avoid_fork_if_possible = True
indexer = Indexer(
checkpoint=checkpoint,
config=self.config,
verbose=verbose,
)
indexer.configure(avoid_fork_if_possible=True)
indexer.index(name=index_name, collection=collection, overwrite=overwrite)

# Monkey-patch colbert-ai to avoid using FAISS
monkey_patching = False
if len(collection) < 500000 and kwargs.get("use_faiss", False) is False:
print(
"---- WARNING! You are using PLAID with an experimental replacement for FAISS for greater compatibility ----"
)
print("This is a behaviour change from RAGatouille 0.8.0 onwards.")
print(
"This works fine for most users, but is slower than FAISS and slightly more approximate."
)
print(
"If you're confident with FAISS working issue-free on your machine, pass use_faiss=True to revert to the FAISS-using behaviour."
)
print("--------------------")
CollectionIndexer._original_train_kmeans = CollectionIndexer._train_kmeans
Review thread on this line (resolved; a sketch of the class-attribute toggle discussed here follows the index.py diff below):

Collaborator: Can we use another variable to track this? Would avoid directly setting object attributes!

Collaborator Author: We could, yeah! This was mostly for the convenience of checking with hasattr later on, but it might be better practice to set it to a new object instead. I'll change it.

Collaborator Author: Mentioned on Discord, but I'm actually thinking this is a relatively sane way of doing it, because we need to keep it alive for the entirety of the session -- we're monkey-patching the colbert-ai indexer itself, and we want to be able to revert any time someone needs to use faiss, so local variables wouldn't cut it.

Collaborator: Right, yeah, makes sense! In that case, can we assign the faiss and non-faiss k-means functions as class attributes of PLAIDModelIndex? Then, at build time, we can just toggle between them (setting CollectionIndexer._train_kmeans) based on the monkey_patch flag. Wdyt?

Collaborator: Curious to get your thoughts here too, @jlscheerer!

Collaborator: I like the idea of having it be a class attribute on PLAIDModelIndex and toggling it on use (+ persisting the flag). This would perhaps provide more consistent behaviour when rebuilding/adding to an already persisted index (e.g., if we decide to rebuild as part of add_to_index).

Collaborator Author (bclavie, Mar 18, 2024): Great suggestion! Implemented now.

CollectionIndexer._train_kmeans = torch_kmeans._train_kmeans
monkey_patching = True
try:
indexer = Indexer(
checkpoint=checkpoint,
config=self.config,
verbose=verbose,
)
indexer.configure(avoid_fork_if_possible=True)
indexer.index(
name=index_name, collection=collection, overwrite=overwrite
)
except Exception:
print(
"PyTorch-based indexing did not succeed! Reverting to using FAISS and attempting again..."
)
CollectionIndexer._train_kmeans = (
CollectionIndexer._original_train_kmeans
)
monkey_patching = False
if monkey_patching is False:
if hasattr(CollectionIndexer, "_original_train_kmeans"):
CollectionIndexer._train_kmeans = (
CollectionIndexer._original_train_kmeans
)
if torch.cuda.is_available():
import faiss

if not hasattr(faiss, "StandardGpuResources"):
print(
"________________________________________________________________________________\n"
"WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n",
"This means that indexing will be slow. To make use of your GPU.\n"
"Please install `faiss-gpu` by running:\n"
"pip uninstall --y faiss-cpu & pip install faiss-gpu\n",
"________________________________________________________________________________",
)
print("Will continue with CPU indexing in 5 seconds...")
time.sleep(5)
indexer = Indexer(
checkpoint=checkpoint,
config=self.config,
verbose=verbose,
)
indexer.configure(avoid_fork_if_possible=True)
indexer.index(name=index_name, collection=collection, overwrite=overwrite)

return self

def _load_searcher(
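The review thread above settled on hoisting both k-means trainers into class attributes and toggling between them at build time. A minimal sketch of that design, assuming it lives on PLAIDModelIndex; the attribute names faiss_train_kmeans / pytorch_train_kmeans are illustrative, not the merged implementation:

from colbert.indexing.collection_indexer import CollectionIndexer

from ragatouille.models import torch_kmeans


class PLAIDModelIndex:
    # Keep both trainers alive as class attributes for the whole session;
    # staticmethod stops Python from re-binding them to instances of this class.
    faiss_train_kmeans = staticmethod(CollectionIndexer._train_kmeans)
    pytorch_train_kmeans = staticmethod(torch_kmeans._train_kmeans)

    def build(self, use_faiss: bool = False) -> None:
        # Point colbert-ai at whichever trainer matches the flag, then index
        # as in the diff above; reverting is just flipping the flag on the
        # next build, so no hasattr bookkeeping is needed.
        if use_faiss:
            CollectionIndexer._train_kmeans = self.faiss_train_kmeans
        else:
            CollectionIndexer._train_kmeans = self.pytorch_train_kmeans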
119 changes: 119 additions & 0 deletions ragatouille/models/torch_kmeans.py
@@ -0,0 +1,119 @@
import time

import torch


def _train_kmeans(self, sample, shared_lists): # noqa: ARG001
if self.use_gpu:
torch.cuda.empty_cache()
centroids = compute_pytorch_kmeans(
sample,
self.config.dim,
self.num_partitions,
self.config.kmeans_niters,
self.use_gpu,
)
centroids = torch.nn.functional.normalize(centroids, dim=-1)
if self.use_gpu:
centroids = centroids.half()
else:
centroids = centroids.float()
return centroids


def compute_pytorch_kmeans(
sample,
dim,
num_partitions,
kmeans_niters,
use_gpu,
batch_size=512000,
tol=1e-4,
verbose=1,
):
device = torch.device("cuda" if use_gpu else "cpu")
sample = sample.to(device)
total_size = sample.shape[0]

# Initialize centroids randomly
centroids = torch.randn(num_partitions, dim, dtype=sample.dtype, device=device)

# Convert to half-precision if GPU is available
if use_gpu:
sample = sample.half()
centroids = centroids.half()
else:
sample = sample.float()
centroids = centroids.float()

# Precompute the squared norms of data points
sample_norms = torch.sum(sample.pow(2), dim=1, keepdim=True)

start_time = time.time()
for i in range(kmeans_niters):
iter_time = time.time()

# Shuffle the data points
permutation = torch.randperm(total_size, device=device)
sample = sample[permutation]
sample_norms = sample_norms[permutation]

if total_size <= batch_size:
# Compute distances and assignments for the entire dataset
distances = (
sample_norms
- 2 * torch.mm(sample, centroids.t())
+ torch.sum(centroids.pow(2), dim=1).unsqueeze(0)
)
assignments = torch.min(distances, dim=1)[1]

# Update centroids by taking the mean of assigned data points
for j in range(num_partitions):
assigned_points = sample[assignments == j]
if len(assigned_points) > 0:
centroids[j] = assigned_points.mean(dim=0)

# Compute the error (sum of squared distances)
error = torch.sum((sample - centroids[assignments]).pow(2))
else:
# Process the data points in batches
error = 0.0
for batch_start in range(0, total_size, batch_size):
batch_end = min(batch_start + batch_size, total_size)
batch = sample[batch_start:batch_end]
batch_norms = sample_norms[batch_start:batch_end]

# Compute distances and assignments for the batch
distances = (
batch_norms
- 2 * torch.mm(batch, centroids.t())
+ torch.sum(centroids.pow(2), dim=1).unsqueeze(0)
)
assignments = torch.min(distances, dim=1)[1]

# Update centroids with an exponential moving average of the
# batch means (0.9 * old + 0.1 * new), smoothing across batches
for j in range(num_partitions):
assigned_points = batch[assignments == j]
if len(assigned_points) > 0:
centroids[j] = (
centroids[j] * 0.9 + assigned_points.mean(dim=0) * 0.1
)

# Accumulate the error for the batch
error += torch.sum((batch - centroids[assignments]).pow(2))

if verbose >= 2:
print(
f"Iteration: {i+1}, Error: {error.item():.4f}, Time: {time.time() - iter_time:.4f}s"
)

# Check for convergence (unlikely to early stop, but still useful!)
if error <= tol:
break

if verbose >= 1:
print(
f"Used {i+1} iterations ({time.time() - start_time:.4f}s) to cluster {total_size} items into {num_partitions} clusters"
)

return centroids
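
The assignment step above relies on the expansion ||x - c||^2 = ||x||^2 - 2 x·c + ||c||^2, which turns each distance computation into a single matrix multiply against the precomputed sample norms. A standalone sketch of calling the clusterer directly (shapes and iteration count are illustrative):

import torch

from ragatouille.models.torch_kmeans import compute_pytorch_kmeans

# Hypothetical input: 10,000 embeddings of dimension 128, clustered on CPU.
sample = torch.randn(10_000, 128)
centroids = compute_pytorch_kmeans(
    sample, dim=128, num_partitions=64, kmeans_niters=4, use_gpu=False
)
print(centroids.shape)  # torch.Size([64, 128])

# Note: in the indexing pipeline, _train_kmeans additionally L2-normalises
# these centroids before handing them back to colbert-ai.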