From 63a5d18486789ce1b4a8f5ea661fc83779fceca2 Mon Sep 17 00:00:00 2001 From: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com> Date: Thu, 5 Sep 2024 10:19:47 +0200 Subject: [PATCH] fix(AbstractGraph): Bedrock init issues Closes #633 --- scrapegraphai/graphs/abstract_graph.py | 4 +++- tests/graphs/abstract_graph_test.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 27ba9f2c..1a4f1e6a 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -128,7 +128,7 @@ def _create_llm(self, llm_config: dict) -> object: return llm_params["model_instance"] known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai", - "ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai", + "ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks", "togetherai"} split_model_provider = llm_params["model"].split("/", 1) @@ -146,6 +146,8 @@ def _create_llm(self, llm_config: dict) -> object: try: if llm_params["model_provider"] not in {"oneapi", "nvidia", "ernie", "deepseek", "togetherai"}: + if llm_params["model_provider"] == "bedrock" and "temperature" in llm_params: + llm_params.setdefault("model_kwargs", {})["temperature"] = llm_params.pop("temperature") with warnings.catch_warnings(): warnings.simplefilter("ignore") return init_chat_model(**llm_params) diff --git a/tests/graphs/abstract_graph_test.py b/tests/graphs/abstract_graph_test.py index 606f1346..54349d22 100644 --- a/tests/graphs/abstract_graph_test.py +++ b/tests/graphs/abstract_graph_test.py @@ -12,6 +12,7 @@ from langchain_openai import ChatOpenAI, AzureChatOpenAI from langchain_ollama import ChatOllama from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_aws import ChatBedrock @@ -71,6 +72,7 @@ class TestAbstractGraph: ({"model": "ollama/llama2"}, ChatOllama), ({"model": "oneapi/qwen-turbo", 
"api_key": "oneapi-api-key"}, OneApi), ({"model": "deepseek/deepseek-coder", "api_key": "deepseek-api-key"}, DeepSeek), + ({"model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "region_name": "IDK"}, ChatBedrock), ]) def test_create_llm(self, llm_config, expected_model):