From 61e309cd0c7433a6b0d8b785b429284e50cc1dfb Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 9 Apr 2024 17:23:56 -0400 Subject: [PATCH] x --- langchain_benchmarks/schema.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/langchain_benchmarks/schema.py b/langchain_benchmarks/schema.py index 6c8e111..d0b41bf 100644 --- a/langchain_benchmarks/schema.py +++ b/langchain_benchmarks/schema.py @@ -256,7 +256,13 @@ def add(self, task: BaseTask) -> None: Provider = Literal["fireworks", "openai", "anthropic", "anyscale"] ModelType = Literal["chat", "llm"] -AUTHORIZED_NAMESPACES = {"langchain", "langchain_google_genai"} +AUTHORIZED_NAMESPACES = { + "langchain", + "langchain_google_genai", + "langchain_openai", + "langchain_anthropic", + "langchain_fireworks", +} def _get_model_class_from_path( @@ -273,7 +279,15 @@ def _get_model_class_from_path( ) # Import the module dynamically - module = importlib.import_module(module_name) + try: + module = importlib.import_module(module_name) + except ImportError: + raise ImportError( + f"Could not import module {module_name}. " + f"Perhaps you need to run to pip install the package? " + f"`pip install {module_name}`." + ) + model_class = getattr(module, attribute_name) if not issubclass(model_class, (BaseLanguageModel, BaseChatModel)): raise ValueError( @@ -285,13 +299,13 @@ def _get_model_class_from_path( def _get_default_path(provider: str, type_: ModelType) -> str: """Get the default path for a model.""" paths = { - ("fireworks", "chat"): "langchain.chat_models.fireworks.ChatFireworks", - ("fireworks", "llm"): "langchain.llms.fireworks.Fireworks", + ("anthropic", "chat"): "langchain_anthropic.ChatAnthropic", ("anyscale", "chat"): "langchain.chat_models.anyscale.ChatAnyscale", ("anyscale", "llm"): "langchain.llms.anyscale.Anyscale", - ("openai", "chat"): "langchain.chat_models.openai.ChatOpenAI", - ("openai", "llm"): "langchain.llms.openai.OpenAI", - ("anthropic", "chat"): "langchain.chat_models.anthropic.ChatAnthropic", + ("fireworks", "chat"): "langchain_fireworks.ChatFireworks", + ("fireworks", "llm"): "langchain_fireworks.Fireworks", + ("openai", "chat"): "langchain_openai.ChatOpenAI", + ("openai", "llm"): "langchain_openai.OpenAI", ( "google-genai", "chat",