Skip to content

Commit

Permalink
langchain[patch]: Add async methods to EmbeddingRouterChain (langchai…
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet authored Mar 26, 2024
1 parent b3d7b5a commit 7c2578b
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion libs/langchain/langchain/chains/router/embedding_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from typing import Any, Dict, List, Optional, Sequence, Tuple, Type

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Extra
Expand Down Expand Up @@ -40,6 +43,15 @@ def _call(
results = self.vectorstore.similarity_search(_input, k=1)
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}

async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_input = ", ".join([inputs[k] for k in self.routing_keys])
results = await self.vectorstore.asimilarity_search(_input, k=1)
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}

@classmethod
def from_names_and_descriptions(
cls,
Expand All @@ -57,3 +69,21 @@ def from_names_and_descriptions(
)
vectorstore = vectorstore_cls.from_documents(documents, embeddings)
return cls(vectorstore=vectorstore, **kwargs)

@classmethod
async def afrom_names_and_descriptions(
cls,
names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
vectorstore_cls: Type[VectorStore],
embeddings: Embeddings,
**kwargs: Any,
) -> EmbeddingRouterChain:
"""Convenience constructor."""
documents = []
for name, descriptions in names_and_descriptions:
for description in descriptions:
documents.append(
Document(page_content=description, metadata={"name": name})
)
vectorstore = await vectorstore_cls.afrom_documents(documents, embeddings)
return cls(vectorstore=vectorstore, **kwargs)

0 comments on commit 7c2578b

Please sign in to comment.