Skip to content

Commit

Permalink
save
Browse files: browse the repository at this point in the history
  • Loading branch information
nsthorat committed Feb 29, 2024
1 parent e829a49 commit 12439de
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
9 changes: 3 additions & 6 deletions lilac/embeddings/vector_store_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,10 @@ def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray
else:
locs = self._key_to_label.loc[cast(list[str], keys)].values

def _get_nparrays(index, locs) -> list[np.ndarray]:
items = np.array(index.get_items(locs), dtype=np.float32)
return [np.squeeze(vector) for vector in np.split(items, items.shape[0])]

for loc_chunk in chunks(locs, HNSW_RETRIEVAL_BATCH_SIZE):
item_chunk = _get_nparrays(self._index, loc_chunk)
yield from item_chunk
chunk_items = np.array(self._index.get_items(loc_chunk), dtype=np.float32)
for vector in np.split(chunk_items, chunk_items.shape[0]):
yield np.squeeze(vector)

@override
def topk(
Expand Down
14 changes: 9 additions & 5 deletions lilac/embeddings/vector_store_numpy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""NumpyVectorStore class for storing vectors in numpy arrays."""

import os
from typing import Iterable, Optional, cast
from typing import Iterable, Iterator, Optional, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -73,14 +73,18 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
self._key_to_index = new_key_to_label

@override
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
assert (
self._embeddings is not None and self._key_to_index is not None
), 'The vector store has no embeddings. Call load() or add() first.'
if not keys:
return self._embeddings
locs = self._key_to_index.loc[cast(list[str], keys)]
return self._embeddings.take(locs, axis=0)
embeddings = self._embeddings
else:
locs = self._key_to_index.loc[cast(list[str], keys)]
embeddings = self._embeddings.take(locs, axis=0)

for vector in np.split(embeddings, embeddings.shape[0]):
yield np.squeeze(vector)

@override
def topk(
Expand Down

0 comments on commit 12439de

Please sign in to comment.