Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich committed Dec 17, 2024
1 parent fb54cc7 commit e83280a
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,21 @@ class BedrockConverse(FunctionCallingLLM):
default=60.0,
description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
)
guardrail_identifier: Optional[str] = Field(
description="The unique identifier of the guardrail that you want to use. If you don’t provide a value, no guardrail is applied to the invocation."
),
guardrail_version: Optional[str] = Field(
description="The version number for the guardrail. The value can also be DRAFT"
),
trace: Optional[str] = Field(
description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace."
),
guardrail_identifier: Optional[str] = (
Field(
description="The unique identifier of the guardrail that you want to use. If you don’t provide a value, no guardrail is applied to the invocation."
),
)
guardrail_version: Optional[str] = (
Field(
description="The version number for the guardrail. The value can also be DRAFT"
),
)
trace: Optional[str] = (
Field(
description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace."
),
)
additional_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description="Additional kwargs for the bedrock invokeModel request.",
Expand Down Expand Up @@ -192,7 +198,7 @@ def __init__(
botocore_config=botocore_config,
guardrail_identifier=guardrail_identifier,
guardrail_version=guardrail_version,
trace=trace
trace=trace,
)

self._config = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,13 +327,23 @@ def converse_with_retry(
if tool_config := kwargs.get("tools"):
converse_kwargs["toolConfig"] = tool_config
if guardrail_identifier and guardrail_version:
converse_kwargs['guardrailConfig'] = {}
converse_kwargs["guardrailConfig"] = {}
converse_kwargs["guardrailConfig"]["guardrailIdentifier"] = guardrail_identifier
converse_kwargs["guardrailConfig"]["guardrailVersion"] = guardrail_version
if trace:
converse_kwargs["guardrailConfig"]["trace"] = trace
converse_kwargs = join_two_dicts(
converse_kwargs, {k: v for k, v in kwargs.items() if (k != "tools" or k != "guardrail_identifier" or k != "guardrail_version" or k != "trace")}
converse_kwargs,
{
k: v
for k, v in kwargs.items()
if (
k != "tools"
or k != "guardrail_identifier"
or k != "guardrail_version"
or k != "trace"
)
},
)

@retry_decorator
Expand Down Expand Up @@ -375,13 +385,23 @@ async def converse_with_retry_async(
if tool_config := kwargs.get("tools"):
converse_kwargs["toolConfig"] = tool_config
if guardrail_identifier and guardrail_version:
converse_kwargs['guardrailConfig'] = {}
converse_kwargs["guardrailConfig"] = {}
converse_kwargs["guardrailConfig"]["guardrailIdentifier"] = guardrail_identifier
converse_kwargs["guardrailConfig"]["guardrailVersion"] = guardrail_version
if trace:
converse_kwargs["guardrailConfig"]["trace"] = trace
converse_kwargs = join_two_dicts(
converse_kwargs, {k: v for k, v in kwargs.items() if (k != "tools" or k != "guardrail_identifier" or k != "guardrail_version" or k != "trace")}
converse_kwargs,
{
k: v
for k, v in kwargs.items()
if (
k != "tools"
or k != "guardrail_identifier"
or k != "guardrail_version"
or k != "trace"
)
},
)

## NOTE: Returning the generator directly from converse_stream doesn't work
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,21 @@ class Bedrock(LLM):
default=60.0,
description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
)
guardrail_identifier: Optional[str] = Field(
description="The unique identifier of the guardrail that you want to use. If you don’t provide a value, no guardrail is applied to the invocation."
),
guardrail_version: Optional[str] = Field(
description="The version number for the guardrail. The value can also be DRAFT"
),
trace: Optional[str] = Field(
description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace."
),
guardrail_identifier: Optional[str] = (
Field(
description="The unique identifier of the guardrail that you want to use. If you don’t provide a value, no guardrail is applied to the invocation."
),
)
guardrail_version: Optional[str] = (
Field(
description="The version number for the guardrail. The value can also be DRAFT"
),
)
trace: Optional[str] = (
Field(
description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace."
),
)
additional_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description="Additional kwargs for the bedrock invokeModel request.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def completion_with_retry(
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
if stream:
if (guardrail_identifier == None or guardrail_version == None):
if guardrail_identifier is None or guardrail_version is None:
return client.invoke_model_with_response_stream(
modelId=model,
body=request_body,
Expand All @@ -321,21 +321,18 @@ def _completion_with_retry(**kwargs: Any) -> Any:
body=request_body,
guardrailIdentifier=guardrail_identifier,
guardrailVersion=guardrail_version,
trace=trace
trace=trace,
)
else:
if (guardrail_identifier == None or guardrail_version == None):
return client.invoke_model(
modelId=model,
body=request_body
)
if guardrail_identifier is None or guardrail_version is None:
return client.invoke_model(modelId=model, body=request_body)
else:
return client.invoke_model(
modelId=model,
body=request_body,
guardrailIdentifier=guardrail_identifier,
guardrailVersion=guardrail_version,
trace=trace
trace=trace,
)

return _completion_with_retry(**kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,25 @@ def test_model_basic(
bedrock_stubber.add_response(
"invoke_model",
get_invoke_model_response(response_body),
{"body": complete_request, "modelId": model, "guardrailIdentifier": "test", "guardrailVersion": "test", "trace": "ENABLED"},
{
"body": complete_request,
"modelId": model,
"guardrailIdentifier": "test",
"guardrailVersion": "test",
"trace": "ENABLED",
},
)
# response for llm.chat()
bedrock_stubber.add_response(
"invoke_model",
get_invoke_model_response(response_body),
{"body": chat_request, "modelId": model, "guardrailIdentifier": "test", "guardrailVersion": "test", "trace": "ENABLED"},
{
"body": chat_request,
"modelId": model,
"guardrailIdentifier": "test",
"guardrailVersion": "test",
"trace": "ENABLED",
},
)

bedrock_stubber.activate()
Expand Down

0 comments on commit e83280a

Please sign in to comment.