Support for multiple indexes in a vector store #16972

Open
wants to merge 5 commits into base: main

2 changes: 2 additions & 0 deletions llama-index-core/llama_index/core/indices/base.py
@@ -152,6 +152,8 @@ def set_index_id(self, index_id: str) -> None:
# add the new index struct
self._index_struct.index_id = index_id
self._storage_context.index_store.add_index_struct(self._index_struct)
if hasattr(self._vector_store, "move_nodes"):
self._vector_store.move_nodes(from_index_id=old_id, to_index_id=index_id)

@property
def docstore(self) -> BaseDocumentStore:
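Note for reviewers: a minimal standalone sketch of the optional hook this change depends on. set_index_id only re-tags vectors when the underlying store exposes move_nodes; the class below is hypothetical (not part of this PR) and only illustrates that contract.

from typing import Dict, List


class InMemoryIndexedStore:
    """Hypothetical store used only to illustrate the optional move_nodes hook."""

    def __init__(self) -> None:
        # node ids grouped by the logical index they currently belong to
        self._nodes_by_index: Dict[str, List[str]] = {}

    def move_nodes(self, from_index_id: str, to_index_id: str) -> None:
        # re-tag everything stored under the old index id
        moved = self._nodes_by_index.pop(from_index_id, [])
        self._nodes_by_index.setdefault(to_index_id, []).extend(moved)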
@@ -378,6 +378,7 @@ def as_retriever(
):
sub_retrievers.append(
VectorContextRetriever(
index_id=self.index_id,
graph_store=self.property_graph_store,
vector_store=self.vector_store,
include_text=include_text,
@@ -189,7 +189,7 @@ async def _async_add_nodes_to_index(
nodes_batch = await self._aget_node_with_embedding(
nodes_batch, show_progress
)
new_ids = await self._vector_store.async_add(nodes_batch, **insert_kwargs)
new_ids = await self._vector_store.async_add(nodes_batch, index_id=index_struct.index_id, **insert_kwargs)

# if the vector store doesn't store text, we need to add the nodes to the
# index struct and document store
@@ -230,7 +230,7 @@ def _add_nodes_to_index(

for nodes_batch in iter_batch(nodes, self._insert_batch_size):
nodes_batch = self._get_node_with_embedding(nodes_batch, show_progress)
new_ids = self._vector_store.add(nodes_batch, **insert_kwargs)
new_ids = self._vector_store.add(nodes_batch, index_id=index_struct.index_id, **insert_kwargs)

if not self._vector_store.stores_text or self._store_nodes_override:
# NOTE: if the vector store doesn't store text,
@@ -177,7 +177,7 @@ def _get_nodes_with_embeddings(
self, query_bundle_with_embeddings: QueryBundle
) -> List[NodeWithScore]:
query = self._build_vector_store_query(query_bundle_with_embeddings)
query_result = self._vector_store.query(query, **self._kwargs)
query_result = self._vector_store.query(query, index_id=self._index.index_id, **self._kwargs)
return self._build_node_list_from_query_result(query_result)

async def _aget_nodes_with_embeddings(
8 changes: 6 additions & 2 deletions llama-index-core/llama_index/core/vector_stores/simple.py
@@ -240,14 +240,16 @@ def get_nodes(
self,
node_ids: Optional[List[str]] = None,
filters: Optional[MetadataFilters] = None,
index_id: Optional[str] = None,
) -> List[BaseNode]:
"""Get nodes."""
raise NotImplementedError("SimpleVectorStore does not store nodes directly.")

def add(
self,
nodes: Sequence[BaseNode],
**add_kwargs: Any,
index_id: Optional[str] = None,
**kwargs: Any,
) -> List[str]:
"""Add nodes to index."""
for node in nodes:
@@ -287,6 +289,7 @@ def delete_nodes(
self,
node_ids: Optional[List[str]] = None,
filters: Optional[MetadataFilters] = None,
index_id: Optional[str] = None,
**delete_kwargs: Any,
) -> None:
filter_fn = _build_metadata_filter_fn(
@@ -310,13 +313,14 @@ def node_filter_fn(node_id: str) -> bool:
del self.data.text_id_to_ref_doc_id[node_id]
self.data.metadata_dict.pop(node_id, None)

def clear(self) -> None:
def clear(self, index_id: Optional[str] = None) -> None:
"""Clear the store."""
self.data = SimpleVectorStoreData()

def query(
self,
query: VectorStoreQuery,
index_id: Optional[str] = None,
**kwargs: Any,
) -> VectorStoreQueryResult:
"""Get nodes for response."""
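Usage sketch for the updated SimpleVectorStore signature (assumes the in-memory store shipped with llama-index-core): the new keyword is accepted for interface compatibility, but this store does not partition by it.

from llama_index.core.schema import TextNode
from llama_index.core.vector_stores import SimpleVectorStore

store = SimpleVectorStore()
node = TextNode(text="hello world", embedding=[0.1, 0.2, 0.3])

# index_id is now part of the signature; SimpleVectorStore currently ignores it
store.add([node], index_id="my-index")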
27 changes: 18 additions & 9 deletions llama-index-core/llama_index/core/vector_stores/types.py
@@ -333,10 +333,14 @@ class BasePydanticVectorStore(BaseComponent, ABC):
def client(self) -> Any:
"""Get client."""

def move_nodes(self, from_index_id: str, to_index_id: str):
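"""Move nodes stored under from_index_id to to_index_id (no-op by default)."""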
pass

def get_nodes(
self,
node_ids: Optional[List[str]] = None,
filters: Optional[MetadataFilters] = None,
index_id: Optional[str] = None,
) -> List[BaseNode]:
"""Get nodes from vector store."""
raise NotImplementedError("get_nodes not implemented")
@@ -345,29 +349,32 @@ async def aget_nodes(
self,
node_ids: Optional[List[str]] = None,
filters: Optional[MetadataFilters] = None,
index_id: Optional[str] = None,
) -> List[BaseNode]:
"""Asynchronously get nodes from vector store."""
return self.get_nodes(node_ids, filters)
return self.get_nodes(node_ids, filters, index_id)

@abstractmethod
def add(
self,
nodes: Sequence[BaseNode],
index_id: Optional[str] = None,
**kwargs: Any,
) -> List[str]:
"""Add nodes to vector store."""

async def async_add(
self,
nodes: Sequence[BaseNode],
index_id: Optional[str] = None,
**kwargs: Any,
) -> List[str]:
"""
Asynchronously add nodes to vector store.
NOTE: this is not implemented for all vector stores. If not implemented,
it will just call add synchronously.
"""
return self.add(nodes, **kwargs)
return self.add(nodes, index_id, **kwargs)

@abstractmethod
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
@@ -386,6 +393,7 @@ def delete_nodes(
self,
node_ids: Optional[List[str]] = None,
filters: Optional[MetadataFilters] = None,
index_id: Optional[str] = None,
**delete_kwargs: Any,
) -> None:
"""Delete nodes from vector store."""
@@ -395,32 +403,33 @@ async def adelete_nodes(
self,
node_ids: Optional[List[str]] = None,
filters: Optional[MetadataFilters] = None,
index_id: Optional[str] = None,
**delete_kwargs: Any,
) -> None:
"""Asynchronously delete nodes from vector store."""
self.delete_nodes(node_ids, filters)
self.delete_nodes(node_ids, filters, index_id)

def clear(self) -> None:
def clear(self, index_id: Optional[str] = None) -> None:
"""Clear all nodes from configured vector store."""
raise NotImplementedError("clear not implemented")

async def aclear(self) -> None:
async def aclear(self, index_id: Optional[str] = None) -> None:
"""Asynchronously clear all nodes from configured vector store."""
self.clear()
self.clear(index_id=index_id)

@abstractmethod
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
def query(self, query: VectorStoreQuery, index_id: Optional[str] = None, **kwargs: Any) -> VectorStoreQueryResult:
"""Query vector store."""

async def aquery(
self, query: VectorStoreQuery, **kwargs: Any
self, query: VectorStoreQuery, index_id: Optional[str] = None, **kwargs: Any
) -> VectorStoreQueryResult:
"""
Asynchronously query vector store.
NOTE: this is not implemented for all vector stores. If not implemented,
it will just call query synchronously.
"""
return self.query(query, **kwargs)
return self.query(query, index_id, **kwargs)

def persist(
self, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
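To make the new contract concrete, here is a minimal, hypothetical subclass sketch exercising the index_id-aware signatures introduced above. The class name and the dict-based partitioning are invented for illustration and are not part of this PR.

from typing import Any, Dict, List, Optional, Sequence

from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import (
    BasePydanticVectorStore,
    VectorStoreQuery,
    VectorStoreQueryResult,
)


class PartitionedVectorStore(BasePydanticVectorStore):
    """Toy store that keeps a separate bucket of nodes per index_id."""

    stores_text: bool = True
    _partitions: Dict[str, Dict[str, BaseNode]] = PrivateAttr(default_factory=dict)

    @property
    def client(self) -> Any:
        return None  # no external client in this sketch

    def add(
        self, nodes: Sequence[BaseNode], index_id: Optional[str] = None, **kwargs: Any
    ) -> List[str]:
        bucket = self._partitions.setdefault(index_id or "default", {})
        for node in nodes:
            bucket[node.node_id] = node
        return [node.node_id for node in nodes]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        for bucket in self._partitions.values():
            for node_id in [nid for nid, n in bucket.items() if n.ref_doc_id == ref_doc_id]:
                del bucket[node_id]

    def query(
        self, query: VectorStoreQuery, index_id: Optional[str] = None, **kwargs: Any
    ) -> VectorStoreQueryResult:
        # naive "retrieval": return the first top_k nodes from the requested partition
        nodes = list(self._partitions.get(index_id or "default", {}).values())
        top_k = query.similarity_top_k or len(nodes)
        return VectorStoreQueryResult(
            nodes=nodes[:top_k], ids=[n.node_id for n in nodes[:top_k]]
        )

    def move_nodes(self, from_index_id: str, to_index_id: str) -> None:
        moved = self._partitions.pop(from_index_id, {})
        self._partitions.setdefault(to_index_id, {}).update(moved)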
@@ -294,6 +294,12 @@ def _create_index(self, index_name: Optional[str]) -> None:
filterable=True,
hidden=self._field_mapping["id"] in self._hidden_field_keys,
),
SimpleField(
name=self._field_mapping["index_id"],
type="Edm.String",
filterable=True,
hidden=self._field_mapping["index_id"] in self._hidden_field_keys,
),
SearchableField(
name=self._field_mapping["chunk"],
type="Edm.String",
@@ -741,6 +747,7 @@ def __init__(
# Default field mapping
field_mapping = {
"id": id_field_key,
"index_id": "index_id",
"chunk": chunk_field_key,
"embedding": embedding_field_key,
"metadata": metadata_string_field_key,
@@ -796,9 +803,17 @@ def _default_index_mapping(

return index_doc

def move_nodes(self, from_index_id: str, to_index_id: str):
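"""Re-tag all documents under from_index_id with to_index_id via a merge update."""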
nodes = self.get_nodes(index_id=from_index_id)
updates = [
{"id": n.id_, "index_id": to_index_id} for n in nodes
]
self._search_client.merge_documents(updates)

def add(
self,
nodes: List[BaseNode],
index_id: Optional[str] = None,
**add_kwargs: Any,
) -> List[str]:
"""
@@ -825,7 +840,7 @@ def add(
logger.debug(f"Processing embedding: {node.node_id}")
ids.append(node.node_id)

index_document = self._create_index_document(node)
index_document = self._create_index_document(node, index_id)
document_size = len(json.dumps(index_document).encode("utf-8"))
documents.append(index_document)
accumulated_size += document_size
@@ -857,6 +872,7 @@ async def async_add(
async def async_add(
self,
nodes: List[BaseNode],
index_id: Optional[str] = None,
**add_kwargs: Any,
) -> List[str]:
"""
@@ -891,7 +907,7 @@ async def async_add(
logger.debug(f"Processing embedding: {node.node_id}")
ids.append(node.node_id)

index_document = self._create_index_document(node)
index_document = self._create_index_document(node, index_id)
document_size = len(json.dumps(index_document).encode("utf-8"))
documents.append(index_document)
accumulated_size += document_size
@@ -920,10 +936,11 @@

return ids

def _create_index_document(self, node: BaseNode) -> Dict[str, Any]:
def _create_index_document(self, node: BaseNode, index_id: str) -> Dict[str, Any]:
"""Create AI Search index document from embedding result."""
doc: Dict[str, Any] = {}
doc["id"] = node.node_id
doc["index_id"] = index_id
doc["chunk"] = node.get_content(metadata_mode=MetadataMode.NONE) or ""
doc["embedding"] = node.get_embedding()
doc["doc_id"] = node.ref_doc_id
@@ -1004,6 +1021,7 @@ def delete_nodes(
self,
node_ids: Optional[List[str]] = None,
filters: Optional[MetadataFilters] = None,
index_id: Optional[str] = None,
**delete_kwargs: Any,
) -> None:
"""
@@ -1012,7 +1030,10 @@
if node_ids is None and filters is None:
raise ValueError("Either node_ids or filters must be provided")

filter = self._build_filter_delete_query(node_ids, filters)
user_filter = self._build_filter_delete_query(node_ids, filters)
filter = f'({self._field_mapping["index_id"]} eq \'{index_id}\')'
if user_filter:
filter += f' and ({user_filter})'

batch_size = 1000

@@ -1038,6 +1059,7 @@ async def adelete_nodes(
async def adelete_nodes(
self,
node_ids: Optional[List[str]] = None,
filters: Optional[MetadataFilters] = None,
index_id: Optional[str] = None,
**delete_kwargs: Any,
) -> None:
@@ -1047,7 +1069,10 @@
if node_ids is None and filters is None:
raise ValueError("Either node_ids or filters must be provided")

filter = self._build_filter_delete_query(node_ids, filters)
user_filter = self._build_filter_delete_query(node_ids, filters)
filter = f'({self._field_mapping["index_id"]} eq \'{index_id}\')'
if user_filter:
filter += f' and ({user_filter})'

batch_size = 1000

@@ -1155,10 +1180,11 @@ def _create_odata_filter(self, metadata_filters: MetadataFilters) -> str:

return odata_expr

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
odata_filter = None
def query(self, query: VectorStoreQuery, index_id: Optional[str] = None, **kwargs: Any) -> VectorStoreQueryResult:
odata_filter = f'{self._field_mapping["index_id"]} eq \'{index_id}\''
if query.filters is not None:
odata_filter = self._create_odata_filter(query.filters)
odata_filter = f'({odata_filter}) and ({self._create_odata_filter(query.filters)})'

azure_query_result_search: AzureQueryResultSearchBase = (
AzureQueryResultSearchDefault(
query, self._field_mapping, odata_filter, self._search_client
@@ -1179,17 +1205,17 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
return azure_query_result_search.search()

async def aquery(
self, query: VectorStoreQuery, **kwargs: Any
self, query: VectorStoreQuery, index_id: Optional[str] = None, **kwargs: Any
) -> VectorStoreQueryResult:
odata_filter = None
odata_filter = f'{self._field_mapping["index_id"]} eq \'{index_id}\''

# NOTE: users can provide odata_filters directly to the query
odata_filters = kwargs.get("odata_filters")
if odata_filters is not None:
odata_filter = odata_filters
odata_filter = f'({odata_filter}) and ({odata_filters})'
else:
if query.filters is not None:
odata_filter = self._create_odata_filter(query.filters)
odata_filter = f'({odata_filter}) and ({self._create_odata_filter(query.filters)})'

azure_query_result_search: AzureQueryResultSearchBase = (
AzureQueryResultSearchDefault(
@@ -1245,6 +1271,7 @@ def get_nodes(
self,
node_ids: Optional[List[str]] = None,
filters: Optional[MetadataFilters] = None,
index_id: Optional[str] = None,
limit: Optional[int] = None,
) -> List[BaseNode]:
"""Get nodes from the Azure AI Search index.
@@ -1260,7 +1287,11 @@
if not self._search_client:
raise ValueError("Search client not initialized")

filter_str = self._build_filter_str(self._field_mapping, node_ids, filters)
user_filter_str = self._build_filter_str(self._field_mapping, node_ids, filters)
if user_filter_str:
filter_str = f'({self._field_mapping["index_id"]} eq \'{index_id}\') and ({user_filter_str})'
else:
filter_str = f'{self._field_mapping["index_id"]} eq \'{index_id}\''
nodes = []
batch_size = 1000 # Azure Search batch size limit

@@ -1291,6 +1322,7 @@ async def aget_nodes(
self,
node_ids: Optional[List[str]] = None,
filters: Optional[MetadataFilters] = None,
index_id: Optional[str] = None,
limit: Optional[int] = None,
) -> List[BaseNode]:
"""Get nodes asynchronously from the Azure AI Search index.
@@ -1306,7 +1338,8 @@
if not self._async_search_client:
raise ValueError("Async Search client not initialized")

filter_str = self._build_filter_str(self._field_mapping, node_ids, filters)
user_filter_str = self._build_filter_str(self._field_mapping, node_ids, filters)
filter_str = f'{self._field_mapping["index_id"]} eq \'{index_id}\''
if user_filter_str:
filter_str = f'({filter_str}) and ({user_filter_str})'
nodes = []
batch_size = 1000 # Azure Search batch size limit

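Side note on the filtering scheme used throughout this file: every read, query, and delete is scoped to the logical index by AND-ing an index_id equality clause onto any user-supplied OData filter. A standalone sketch of that composition, assuming the default field mapping added above (field name "index_id"):

from typing import Optional


def scoped_odata_filter(index_id: Optional[str], user_filter: Optional[str] = None) -> str:
    # always scope the operation to the logical index ...
    base = f"(index_id eq '{index_id}')"
    # ... and narrow further if the caller supplied their own OData filter
    return f"{base} and ({user_filter})" if user_filter else base


# scoped_odata_filter("idx-a", "doc_id eq 'doc-1'")
# -> "(index_id eq 'idx-a') and (doc_id eq 'doc-1')"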