From 3479fbed010e01bacb090437bdb45d214c368623 Mon Sep 17 00:00:00 2001 From: Aaron Chen Date: Sat, 16 Mar 2024 13:45:29 -0500 Subject: [PATCH 1/3] Pass aws_region_name to get_aws_service_client() --- .../llama_index/llms/sagemaker_endpoint/base.py | 4 ++-- .../llms/llama-index-llms-sagemaker-endpoint/pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py index e4234175067d8..d064a1439902f 100644 --- a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py +++ b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py @@ -83,7 +83,7 @@ def __init__( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, aws_session_token: Optional[str] = None, - region_name: Optional[str] = None, + aws_region_name: Optional[str] = None, max_retries: Optional[int] = 3, timeout: Optional[float] = 60.0, temperature: Optional[float] = 0.5, @@ -112,7 +112,7 @@ def __init__( self._client = get_aws_service_client( service_name="sagemaker-runtime", profile_name=profile_name, - region_name=region_name, + region_name=aws_region_name, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, diff --git a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/pyproject.toml index e40a9a0969cdb..6fc51919bfd54 100644 --- a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-sagemaker-endpoint" readme = "README.md" -version = "0.1.3" +version = "0.1.4" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" From f096f5fe762989c74d20bed668f695c49a5ef591 Mon Sep 17 00:00:00 2001 From: Aaron Chen Date: Sun, 7 Apr 2024 01:00:11 -0500 Subject: [PATCH 2/3] Allow for old kwarg to be passed with deprecation warning --- .../llama_index/llms/sagemaker_endpoint/base.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py index d064a1439902f..34294119213cd 100644 --- a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py +++ b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py @@ -29,6 +29,8 @@ ) from llama_index.llms.sagemaker_endpoint.utils import BaseIOHandler, IOHandler +import warnings + DEFAULT_IO_HANDLER = IOHandler() LLAMA_MESSAGES_TO_PROMPT = messages_to_prompt LLAMA_COMPLETION_TO_PROMPT = completion_to_prompt @@ -109,6 +111,17 @@ def __init__( model_kwargs["temperature"] = temperature content_handler = content_handler self._completion_to_prompt = completion_to_prompt + + region_name = kwargs.pop('region_name', None) + if region_name is not None: + warnings.warn( + "Kwarg `region_name` is deprecated and will be removed in a future version. " + "Please use `aws_region_name` instead.", + DeprecationWarning + ) + if not aws_region_name: + aws_region_name = region_name + self._client = get_aws_service_client( service_name="sagemaker-runtime", profile_name=profile_name, From c0c5fd694cb778e4da7ee2753b475a02e6474587 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Mon, 30 Dec 2024 11:30:19 -0600 Subject: [PATCH 3/3] linting --- .../llama_index/llms/sagemaker_endpoint/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py index 68a3f22fd8a68..cf49109186ce8 100644 --- a/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py +++ b/llama-index-integrations/llms/llama-index-llms-sagemaker-endpoint/llama_index/llms/sagemaker_endpoint/base.py @@ -158,13 +158,13 @@ def __init__( model_kwargs["temperature"] = temperature content_handler = content_handler self._completion_to_prompt = completion_to_prompt - - region_name = kwargs.pop('region_name', None) + + region_name = kwargs.pop("region_name", None) if region_name is not None: warnings.warn( "Kwarg `region_name` is deprecated and will be removed in a future version. " "Please use `aws_region_name` instead.", - DeprecationWarning + DeprecationWarning, ) if not aws_region_name: aws_region_name = region_name