Commit

update mistral AI agent and new examples
Phicks-debug committed Nov 11, 2024
1 parent 485be61 commit 59900a0
Showing 19 changed files with 591 additions and 190 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -120,7 +120,7 @@ and more to come, we are working on it :)

## Requirements

- Python 3.9+
- pydantic>=2.0.0
- boto3>=1.18.0
- botocore>=1.21.0
87 changes: 55 additions & 32 deletions build/lib/bedrock_llm/agent.py
@@ -5,7 +5,7 @@
from .types.enums import ModelName, StopReason
from .config.base import RetryConfig
from .config.model import ModelConfig
from .schema.message import MessageBlock, ToolUseBlock, ToolResultBlock, ToolCallBlock
from .schema.tools import ToolMetadata

from typing import Dict, Any, AsyncGenerator, Tuple, Optional, List, Union
@@ -66,7 +66,7 @@ def __init__(

async def __process_tools(
self,
        tools_list: Union[List[ToolUseBlock], List[ToolCallBlock]]
    ) -> Union[MessageBlock, List[MessageBlock]]:
"""
Process a list of tool use requests and return the results.
@@ -84,34 +84,56 @@
If a tool is not found or an error occurs during execution, an error message
is included in the result.
"""
message = MessageBlock(role="user", content=[])

if isinstance(tools_list[0], ToolUseBlock):
message = MessageBlock(role="user", content=[])
state=1
else:
message = []
state=0

for tool in tools_list:
            if not isinstance(tool, (ToolUseBlock, ToolCallBlock)):
continue

            if state:  # Process tool the Claude/Llama way
                tool_name = tool.name
                tool_data = self.tool_functions.get(tool_name)

                if tool_data:
                    try:
                        result = await tool_data["function"](**tool.input) if tool_data["is_async"] else tool_data["function"](**tool.input)
                        is_error = False
                    except Exception as e:
                        result = str(e)
                        is_error = True
                else:
                    result = f"Tool {tool_name} not found"
                    is_error = True

                message.content.append(
                    ToolResultBlock(
                        type="tool_result",
                        tool_use_id=tool.id,
                        is_error=is_error,
                        content=str(result)
                    )
                )
            else:  # Process tool the Mistral AI/Jamba way
                tool_fn = tool.function
                # Parse the JSON-encoded arguments instead of eval()
                # (assumes `import json` at the top of this module)
                tool_params = json.loads(tool_fn["arguments"])
                tool_data = self.tool_functions.get(tool_fn["name"])

                if tool_data:
                    try:
                        result = await tool_data["function"](**tool_params) if tool_data["is_async"] else tool_data["function"](**tool_params)
                    except Exception as e:
                        result = str(e)
                else:
                    result = f"Tool {tool_fn['name']} not found"

                message.append(
                    MessageBlock(role="tool", name=tool_fn["name"], content=str(result), tool_call_id=tool.id)
                )

return message
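For reference, `__process_tools` resolves each tool through `self.tool_functions`. A minimal sketch (not part of this commit) of the registry shape it assumes — each name maps to a callable plus an `is_async` flag; the helper functions here are hypothetical:

```python
import asyncio

def get_weather(city: str) -> str:
    # synchronous tool: called directly by __process_tools
    return f"Sunny in {city}"

async def search_docs(query: str) -> str:
    # asynchronous tool: awaited when is_async is True
    await asyncio.sleep(0)  # stand-in for real async I/O
    return f"No results for {query!r}"

tool_functions = {
    "get_weather": {"function": get_weather, "is_async": False},
    "search_docs": {"function": search_docs, "is_async": True},
}
```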

@@ -168,9 +190,13 @@ async def generate_and_action_async(
yield token, None, None, None
elif stop_reason == StopReason.TOOL_USE:
yield None, stop_reason, response, None
                    result = await self.__process_tools(response.content if not response.tool_calls else response.tool_calls)
                    if isinstance(result, list):
                        yield None, None, None, result
                        self.memory.extend(result)
                    else:
                        yield None, None, None, result.content
                        self.memory.append(result.model_dump())
break
else:
yield None, stop_reason, response, None
@@ -188,7 +214,4 @@ def _update_memory(self, prompt: Union[str, MessageBlock, List[MessageBlock]]) -
elif isinstance(prompt, list):
self.memory.extend(prompt)
else:
            raise ValueError("Invalid prompt format")
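A hedged sketch of how a caller might drain the four-slot stream that `generate_and_action_async` yields; the `agent` object and the exact call signature are assumed, since the constructor is not shown in this diff:

```python
from bedrock_llm.types.enums import StopReason  # module path assumed from this diff

async def run(agent, prompt: str):
    # Slots: streamed token, stop reason, full assistant response,
    # and tool results that were appended back into agent memory.
    async for token, stop_reason, response, tool_result in agent.generate_and_action_async(prompt):
        if token:
            print(token, end="", flush=True)
        if tool_result:
            print(f"\n[tool results] {tool_result}")
        if stop_reason is not None and stop_reason != StopReason.TOOL_USE:
            break  # END_TURN, MAX_TOKENS, or ERROR ends the exchange
```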
2 changes: 1 addition & 1 deletion build/lib/bedrock_llm/models/ai21.py
@@ -3,7 +3,7 @@
from typing import Any, AsyncGenerator, Tuple, List, Dict, Optional, Union

from ..models.base import BaseModelImplementation, ModelConfig
from ..schema.message import MessageBlock, SystemBlock
from ..schema.tools import ToolMetadata
from ..types.enums import StopReason

147 changes: 116 additions & 31 deletions build/lib/bedrock_llm/models/mistral.py
@@ -5,7 +5,7 @@
from typing import Any, AsyncGenerator, Optional, List, Dict, Tuple, Union

from ..models.base import BaseModelImplementation, ModelConfig
from ..schema.message import MessageBlock, SystemBlock, ToolCallBlock, TextBlock
from ..schema.tools import ToolMetadata
from ..types.enums import ToolChoiceEnum, StopReason

@@ -15,26 +15,79 @@ class MistralChatImplementation(BaseModelImplementation):
Read more: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-large-2407.html
"""

def _parse_tool_metadata(self, tool: Union[ToolMetadata, Dict[str, Any]]) -> Dict[str, Any]:
"""
Parse a ToolMetadata object or a dictionary into the format required by the Mistral model.
"""

if isinstance(tool, dict):
# Handle all dictionary inputs consistently
if "type" in tool and tool["type"] == "function":
function_data = tool.get("function", {})
else:
function_data = tool

return {
"type": "function",
"function": {
"name": function_data.get("name", "unnamed_function"),
"description": function_data.get("description", "No description provided"),
"parameters": function_data.get("input_schema", {
"type": "object",
"properties": {},
"required": []
})
}
}

if isinstance(tool, ToolMetadata):
mistral_tool = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
}

if tool.input_schema:
for prop_name, prop_attr in tool.input_schema.properties.items():
mistral_tool["function"]["parameters"]["properties"][prop_name] = {
"type": prop_attr.type,
"description": prop_attr.description
}

if tool.input_schema.required:
mistral_tool["function"]["parameters"]["required"] = tool.input_schema.required

return mistral_tool

raise ValueError(f"Unsupported tool type: {type(tool)}. Expected Dict or ToolMetadata.")

def prepare_request(
self,
config: ModelConfig,
prompt: Union[str, MessageBlock, List[Dict]],
system: Optional[Union[str, SystemBlock]] = None,
tools: Optional[Union[List[ToolMetadata], List[Dict], ToolMetadata, Dict]] = None,
tool_choice: Optional[ToolChoiceEnum] = None,
**kwargs
) -> Dict[str, Any]:
if tools and not isinstance(tools, (list, dict, ToolMetadata)):
raise ValueError("Tools must be a list, dictionary, or ToolMetadata object.")

messages = []
if isinstance(prompt, str):
messages.append(MessageBlock(role="user", content=prompt).model_dump())
elif isinstance(prompt, MessageBlock):
messages.append(prompt.model_dump())
        elif isinstance(prompt, list):
            messages.extend([msg.model_dump() if isinstance(msg, MessageBlock) else msg for msg in prompt])

if system is not None:
system_content = system.text if isinstance(system, SystemBlock) else system
@@ -49,7 +102,18 @@ def prepare_request(
}

if tools is not None:
request_body["tools"] = tools if isinstance(tools, list) else [tools]
if isinstance(tools, (dict, ToolMetadata)):
parsed_tools = [self._parse_tool_metadata(tools)]
elif isinstance(tools, list):
parsed_tools = []
for tool in tools:
if isinstance(tool, (dict, ToolMetadata)):
parsed_tools.append(self._parse_tool_metadata(tool))
else:
raise ValueError(f"Unsupported tool type in list: {type(tool)}. Expected Dict or ToolMetadata.")
else:
raise ValueError(f"Unsupported tools type: {type(tools)}. Expected List, Dict, or ToolMetadata.")
request_body["tools"] = parsed_tools

if tool_choice is not None:
request_body["tool_choice"] = tool_choice
@@ -61,27 +125,21 @@ async def prepare_request_async(
config: ModelConfig,
prompt: Union[str, MessageBlock, List[Dict]],
system: Optional[Union[str, SystemBlock]] = None,
tools: Optional[Union[List[ToolMetadata], List[Dict], ToolMetadata, Dict]] = None,
tool_choice: Optional[ToolChoiceEnum] = None,
**kwargs
) -> Dict[str, Any]:

if tools and not isinstance(tools, (list, dict, ToolMetadata)):
raise ValueError("Tools must be a list, dictionary, or ToolMetadata object.")

messages = []
if isinstance(prompt, str):
            messages.append(MessageBlock(role="user", content=prompt).model_dump())
        elif isinstance(prompt, MessageBlock):
            messages.append(prompt.model_dump())
        elif isinstance(prompt, list):
            messages.extend([msg.model_dump() if isinstance(msg, MessageBlock) else msg for msg in prompt])

if system is not None:
if isinstance(system, SystemBlock):
@@ -101,13 +159,18 @@ async def prepare_request_async(

# Conditionally add tools and tool_choice if they are not None
        if tools is not None:
            if isinstance(tools, (dict, ToolMetadata)):
                parsed_tools = [self._parse_tool_metadata(tools)]
            elif isinstance(tools, list):
                parsed_tools = []
                for tool in tools:
                    if isinstance(tool, (dict, ToolMetadata)):
                        parsed_tools.append(self._parse_tool_metadata(tool))
                    else:
                        raise ValueError(f"Unsupported tool type in list: {type(tool)}. Expected Dict or ToolMetadata.")
            else:
                raise ValueError(f"Unsupported tools type: {type(tools)}. Expected List, Dict, or ToolMetadata.")
            request_body["tools"] = parsed_tools

        if tool_choice is not None:
            request_body["tool_choice"] = tool_choice

        return request_body

def parse_response(
@@ -139,23 +202,45 @@ async def parse_stream_response(
for event in stream:
chunk = json.loads(event["chunk"]["bytes"])
chunk = chunk["choices"][0]
if chunk["stop_reason"]:
if chunk["stop_reason"]:
content = "".join(full_response) if full_response else ""
message = MessageBlock(
role="assistant",
content="".join(full_response)
content=[TextBlock(type="text", text=content)] if content else None
)
if chunk["stop_reason"] == "stop":
yield None, StopReason.END_TURN, message
elif chunk["stop_reason"] == "tool_calls":
if "tool_calls" in chunk["message"]:
tool_calls = [
ToolCallBlock(
id=tool_call["id"],
type=tool_call["type"],
function=tool_call["function"]
) for tool_call in chunk["message"]["tool_calls"]
]
message.tool_calls = tool_calls
yield None, StopReason.TOOL_USE, message
elif chunk["stop_reason"] == "length":
yield None, StopReason.MAX_TOKENS, message
else:
yield None, StopReason.ERROR, message
return
else:
yield chunk["message"]["content"], None, None
full_response.append(chunk["message"]["content"])
if "content" in chunk["message"] and chunk["message"]["content"]:
yield chunk["message"]["content"], None, None
full_response.append(chunk["message"]["content"])
elif "tool_calls" in chunk["message"]:
# Handle streaming tool calls
tool_calls = [
ToolCallBlock(
id=tool_call["id"],
type=tool_call["type"],
function=tool_call["function"]
) for tool_call in chunk["message"]["tool_calls"]
]
message = MessageBlock(role="assistant", content=[TextBlock(type="text", text="")], tool_calls=tool_calls)
yield None, None, message
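A hedged consumption sketch for `parse_stream_response`; `impl` and `stream` (the raw Bedrock event stream) are assumed to exist. Mid-stream tool-call deltas arrive as `(None, None, message)` and are skipped here:

```python
async def drain(impl, stream):
    async for token, stop_reason, message in impl.parse_stream_response(stream):
        if token:
            print(token, end="", flush=True)
        if stop_reason is not None:
            # On StopReason.TOOL_USE, message.tool_calls carries the
            # parsed ToolCallBlock objects assembled above.
            return stop_reason, message
```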


class MistralInstructImplementation(BaseModelImplementation):