From 7c2578bd5599278383beb1a2299ea1f40015f884 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 26 Mar 2024 22:33:36 +0100 Subject: [PATCH] langchain[patch]: Add async methods to EmbeddingRouterChain (#19603) --- .../chains/router/embedding_router.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/chains/router/embedding_router.py b/libs/langchain/langchain/chains/router/embedding_router.py index b8f9d975bd953..b7fb59f8522f2 100644 --- a/libs/langchain/langchain/chains/router/embedding_router.py +++ b/libs/langchain/langchain/chains/router/embedding_router.py @@ -2,7 +2,10 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type -from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.callbacks import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import Extra @@ -40,6 +43,15 @@ def _call( results = self.vectorstore.similarity_search(_input, k=1) return {"next_inputs": inputs, "destination": results[0].metadata["name"]} + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _input = ", ".join([inputs[k] for k in self.routing_keys]) + results = await self.vectorstore.asimilarity_search(_input, k=1) + return {"next_inputs": inputs, "destination": results[0].metadata["name"]} + @classmethod def from_names_and_descriptions( cls, @@ -57,3 +69,21 @@ def from_names_and_descriptions( ) vectorstore = vectorstore_cls.from_documents(documents, embeddings) return cls(vectorstore=vectorstore, **kwargs) + + @classmethod + async def afrom_names_and_descriptions( + cls, + names_and_descriptions: Sequence[Tuple[str, Sequence[str]]], + vectorstore_cls: Type[VectorStore], + embeddings: Embeddings, + **kwargs: Any, + ) -> EmbeddingRouterChain: + """Convenience constructor.""" + documents = [] + for name, descriptions in names_and_descriptions: + for description in descriptions: + documents.append( + Document(page_content=description, metadata={"name": name}) + ) + vectorstore = await vectorstore_cls.afrom_documents(documents, embeddings) + return cls(vectorstore=vectorstore, **kwargs)