Skip to content

Commit

Permalink
fix(AbstractGraph): Bedrock init issues
Browse files Browse the repository at this point in the history
Closes #633
  • Loading branch information
f-aguzzi committed Sep 5, 2024
1 parent 50c9c6b commit 63a5d18
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 3 additions & 1 deletion scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _create_llm(self, llm_config: dict) -> object:
return llm_params["model_instance"]

known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
"ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
"hugging_face", "deepseek", "ernie", "fireworks", "togetherai"}

split_model_provider = llm_params["model"].split("/", 1)
Expand All @@ -146,6 +146,8 @@ def _create_llm(self, llm_config: dict) -> object:

try:
if llm_params["model_provider"] not in {"oneapi", "nvidia", "ernie", "deepseek", "togetherai"}:
if llm_params["model_provider"] == "bedrock":
llm_params["model_kwargs"] = { "temperature" : llm_params.pop("temperature") }
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return init_chat_model(**llm_params)
Expand Down
2 changes: 2 additions & 0 deletions tests/graphs/abstract_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_ollama import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_aws import ChatBedrock



Expand Down Expand Up @@ -71,6 +72,7 @@ class TestAbstractGraph:
({"model": "ollama/llama2"}, ChatOllama),
({"model": "oneapi/qwen-turbo", "api_key": "oneapi-api-key"}, OneApi),
({"model": "deepseek/deepseek-coder", "api_key": "deepseek-api-key"}, DeepSeek),
({"model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "region_name": "IDK"}, ChatBedrock),
])

def test_create_llm(self, llm_config, expected_model):
Expand Down

0 comments on commit 63a5d18

Please sign in to comment.