Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Amazon bedrock guardrails #17281

Merged
merged 10 commits into from
Dec 17, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,21 @@ class BedrockConverse(FunctionCallingLLM):
default=60.0,
description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
)
guardrail_identifier: Optional[str] = Field(
    default=None,
    description="The unique identifier of the guardrail that you want to use. If you don’t provide a value, no guardrail is applied to the invocation.",
)
guardrail_version: Optional[str] = Field(
    default=None,
    description="The version number for the guardrail. The value can also be DRAFT",
)
trace: Optional[str] = Field(
    default=None,
    description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace.",
)
additional_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description="Additional kwargs for the bedrock invokeModel request.",
Expand Down Expand Up @@ -145,6 +160,9 @@ def __init__(
completion_to_prompt: Optional[Callable[[str], str]] = None,
pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
output_parser: Optional[BaseOutputParser] = None,
guardrail_identifier: Optional[str] = None,
guardrail_version: Optional[str] = None,
trace: Optional[str] = None,
) -> None:
additional_kwargs = additional_kwargs or {}
callback_manager = callback_manager or CallbackManager([])
Expand Down Expand Up @@ -178,6 +196,9 @@ def __init__(
region_name=region_name,
botocore_session=botocore_session,
botocore_config=botocore_config,
guardrail_identifier=guardrail_identifier,
guardrail_version=guardrail_version,
trace=trace,
)

self._config = None
Expand Down Expand Up @@ -292,6 +313,9 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
system_prompt=self.system_prompt,
max_retries=self.max_retries,
stream=False,
guardrail_identifier=self.guardrail_identifier,
guardrail_version=self.guardrail_version,
trace=self.trace,
**all_kwargs,
)

Expand Down Expand Up @@ -336,6 +360,9 @@ def stream_chat(
system_prompt=self.system_prompt,
max_retries=self.max_retries,
stream=True,
guardrail_identifier=self.guardrail_identifier,
guardrail_version=self.guardrail_version,
trace=self.trace,
**all_kwargs,
)

Expand Down Expand Up @@ -416,6 +443,9 @@ async def achat(
system_prompt=self.system_prompt,
max_retries=self.max_retries,
stream=False,
guardrail_identifier=self.guardrail_identifier,
guardrail_version=self.guardrail_version,
trace=self.trace,
**all_kwargs,
)

Expand Down Expand Up @@ -461,6 +491,9 @@ async def astream_chat(
system_prompt=self.system_prompt,
max_retries=self.max_retries,
stream=True,
guardrail_identifier=self.guardrail_identifier,
guardrail_version=self.guardrail_version,
trace=self.trace,
**all_kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ def converse_with_retry(
max_tokens: int = 1000,
temperature: float = 0.1,
stream: bool = False,
guardrail_identifier: Optional[str] = None,
guardrail_version: Optional[str] = None,
trace: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
Expand All @@ -323,8 +326,24 @@ def converse_with_retry(
converse_kwargs["system"] = [{"text": system_prompt}]
if tool_config := kwargs.get("tools"):
converse_kwargs["toolConfig"] = tool_config
if guardrail_identifier and guardrail_version:
converse_kwargs["guardrailConfig"] = {}
converse_kwargs["guardrailConfig"]["guardrailIdentifier"] = guardrail_identifier
converse_kwargs["guardrailConfig"]["guardrailVersion"] = guardrail_version
if trace:
converse_kwargs["guardrailConfig"]["trace"] = trace
converse_kwargs = join_two_dicts(
converse_kwargs, {k: v for k, v in kwargs.items() if k != "tools"}
converse_kwargs,
{
k: v
for k, v in kwargs.items()
if k not in ("tools", "guardrail_identifier", "guardrail_version", "trace")
},
)

@retry_decorator
Expand All @@ -346,6 +365,9 @@ async def converse_with_retry_async(
max_tokens: int = 1000,
temperature: float = 0.1,
stream: bool = False,
guardrail_identifier: Optional[str] = None,
guardrail_version: Optional[str] = None,
trace: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
Expand All @@ -362,8 +384,24 @@ async def converse_with_retry_async(
converse_kwargs["system"] = [{"text": system_prompt}]
if tool_config := kwargs.get("tools"):
converse_kwargs["toolConfig"] = tool_config
if guardrail_identifier and guardrail_version:
converse_kwargs["guardrailConfig"] = {}
converse_kwargs["guardrailConfig"]["guardrailIdentifier"] = guardrail_identifier
converse_kwargs["guardrailConfig"]["guardrailVersion"] = guardrail_version
if trace:
converse_kwargs["guardrailConfig"]["trace"] = trace
converse_kwargs = join_two_dicts(
converse_kwargs, {k: v for k, v in kwargs.items() if k != "tools"}
converse_kwargs,
{
k: v
for k, v in kwargs.items()
if k not in ("tools", "guardrail_identifier", "guardrail_version", "trace")
},
)

## NOTE: Returning the generator directly from converse_stream doesn't work
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-bedrock-converse"
readme = "README.md"
version = "0.4.1"
version = "0.4.2"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
EXP_MAX_TOKENS = 100
EXP_TEMPERATURE = 0.7
EXP_MODEL = "anthropic.claude-v2"
EXP_GUARDRAIL_ID = "IDENTIFIER"
EXP_GUARDRAIL_VERSION = "DRAFT"
EXP_GUARDRAIL_TRACE = "ENABLED"

# Reused chat message and prompt
messages = [ChatMessage(role=MessageRole.USER, content="Test")]
Expand Down Expand Up @@ -88,6 +91,9 @@ def bedrock_converse(mock_boto3_session, mock_aioboto3_session):
model=EXP_MODEL,
max_tokens=EXP_MAX_TOKENS,
temperature=EXP_TEMPERATURE,
guardrail_identifier=EXP_GUARDRAIL_ID,
guardrail_version=EXP_GUARDRAIL_VERSION,
trace=EXP_GUARDRAIL_TRACE,
callback_manager=CallbackManager(),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ class Bedrock(LLM):
default=60.0,
description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
)
guardrail_identifier: Optional[str] = Field(
    default=None,
    description="The unique identifier of the guardrail that you want to use. If you don’t provide a value, no guardrail is applied to the invocation.",
)
guardrail_version: Optional[str] = Field(
    default=None,
    description="The version number for the guardrail. The value can also be DRAFT",
)
trace: Optional[str] = Field(
    default=None,
    description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace.",
)
additional_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description="Additional kwargs for the bedrock invokeModel request.",
Expand Down Expand Up @@ -125,6 +140,9 @@ def __init__(
completion_to_prompt: Optional[Callable[[str], str]] = None,
pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
output_parser: Optional[BaseOutputParser] = None,
guardrail_identifier: Optional[str] = None,
guardrail_version: Optional[str] = None,
trace: Optional[str] = None,
**kwargs: Any,
) -> None:
if context_size is None and model not in BEDROCK_FOUNDATION_LLMS:
Expand Down Expand Up @@ -187,6 +205,9 @@ def __init__(
completion_to_prompt=completion_to_prompt,
pydantic_program_mode=pydantic_program_mode,
output_parser=output_parser,
guardrail_identifier=guardrail_identifier,
guardrail_version=guardrail_version,
trace=trace,
)
self._provider = get_provider(model)
self.messages_to_prompt = (
Expand Down Expand Up @@ -257,6 +278,9 @@ def complete(
model=self.model,
request_body=request_body_str,
max_retries=self.max_retries,
guardrail_identifier=self.guardrail_identifier,
guardrail_version=self.guardrail_version,
trace=self.trace,
**all_kwargs,
)
response_body = response["body"].read()
Expand Down Expand Up @@ -287,6 +311,9 @@ def stream_complete(
request_body=request_body_str,
max_retries=self.max_retries,
stream=True,
guardrail_identifier=self.guardrail_identifier,
guardrail_version=self.guardrail_version,
trace=self.trace,
**all_kwargs,
)
response_body = response["body"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ def completion_with_retry(
request_body: str,
max_retries: int,
stream: bool = False,
guardrail_identifier: Optional[str] = None,
guardrail_version: Optional[str] = None,
trace: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
Expand All @@ -307,9 +310,29 @@ def completion_with_retry(
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
if stream:
return client.invoke_model_with_response_stream(
modelId=model, body=request_body
)
return client.invoke_model(modelId=model, body=request_body)
if guardrail_identifier is None or guardrail_version is None:
return client.invoke_model_with_response_stream(
modelId=model,
body=request_body,
)
else:
return client.invoke_model_with_response_stream(
modelId=model,
body=request_body,
guardrailIdentifier=guardrail_identifier,
guardrailVersion=guardrail_version,
trace=trace or "DISABLED",  # boto3 rejects None; Bedrock's default is DISABLED
)
else:
if guardrail_identifier is None or guardrail_version is None:
return client.invoke_model(modelId=model, body=request_body)
else:
return client.invoke_model(
modelId=model,
body=request_body,
guardrailIdentifier=guardrail_identifier,
guardrailVersion=guardrail_version,
trace=trace or "DISABLED",  # boto3 rejects None; Bedrock's default is DISABLED
)

return _completion_with_retry(**kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-bedrock"
readme = "README.md"
version = "0.3.2"
version = "0.3.3"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def test_model_basic(
profile_name=None,
region_name="us-east-1",
aws_access_key_id="test",
guardrail_identifier="test",
guardrail_version="test",
trace="ENABLED",
)

bedrock_stubber = Stubber(llm._client)
Expand All @@ -155,13 +158,25 @@ def test_model_basic(
bedrock_stubber.add_response(
"invoke_model",
get_invoke_model_response(response_body),
{"body": complete_request, "modelId": model},
{
"body": complete_request,
"modelId": model,
"guardrailIdentifier": "test",
"guardrailVersion": "test",
"trace": "ENABLED",
},
)
# response for llm.chat()
bedrock_stubber.add_response(
"invoke_model",
get_invoke_model_response(response_body),
{"body": chat_request, "modelId": model},
{
"body": chat_request,
"modelId": model,
"guardrailIdentifier": "test",
"guardrailVersion": "test",
"trace": "ENABLED",
},
)

bedrock_stubber.activate()
Expand Down
Loading