Skip to content

Commit

Permalink
Merge pull request #26 from zhudotexe/tool-calls
Browse files Browse the repository at this point in the history
ToolCall refactor
  • Loading branch information
zhudotexe authored Nov 8, 2023
2 parents a2ef623 + 6afe648 commit 42ff983
Show file tree
Hide file tree
Showing 20 changed files with 454 additions and 111 deletions.
5 changes: 5 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ Common Models
:exclude-members: model_config, model_fields
:class-doc-from: class

.. autoclass:: kani.ToolCall
:members:
:exclude-members: model_config, model_fields
:class-doc-from: class

.. autoclass:: kani.MessagePart
:members:
:exclude-members: model_config, model_fields
Expand Down
4 changes: 2 additions & 2 deletions docs/customization/chat_history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ For example, here's how you might extend :meth:`.Kani.add_to_history` to log eve
super().__init__(*args, **kwargs)
self.log_file = open("kani-log.jsonl", "w")
async def add_to_history(self, message):
await super().add_to_history(message)
async def add_to_history(self, message, *args, **kwargs):
await super().add_to_history(message, *args, **kwargs)
self.log_file.write(message.model_dump_json())
self.log_file.write("\n")
Expand Down
4 changes: 2 additions & 2 deletions docs/customization/function_call.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ during a conversation, and how often it was successful:
self.successful_calls = collections.Counter()
self.failed_calls = collections.Counter()
async def do_function_call(self, call):
async def do_function_call(self, call, *args, **kwargs):
try:
result = await super().do_function_call(call)
result = await super().do_function_call(call, *args, **kwargs)
self.successful_calls[call.name] += 1
return result
except FunctionCallException:
Expand Down
4 changes: 2 additions & 2 deletions docs/customization/function_exception.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ Here's an example of providing custom prompts on an exception:
:emphasize-lines: 2-10
class CustomExceptionPromptKani(Kani):
async def handle_function_call_exception(self, call, err, attempt):
async def handle_function_call_exception(self, call, err, attempt, *args, **kwargs):
# get the standard retry logic...
result = await super().handle_function_call_exception(call, err, attempt)
result = await super().handle_function_call_exception(call, err, attempt, *args, **kwargs)
# but override the returned message with our own
result.message = ChatMessage.system(
"The call encountered an error. "
Expand Down
9 changes: 8 additions & 1 deletion docs/engines/implementing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ You'll need to implement two methods: :meth:`.BaseEngine.predict` and :meth:`.Ba
build such a prompt. :meth:`.BaseEngine.function_token_reserve` tells kani how many tokens that prompt takes, so the
context window management can ensure it never sends too many tokens.

You'll also need to add previous function calls into the prompt (e.g. in the few-shot function calling example).
When you're building the prompt, you'll need to iterate over :attr:`.ChatMessage.tool_calls` if it exists, and add
your model's appropriate function calling prompt.

To parse the model's requests to call a function, you also do this in :meth:`.BaseEngine.predict`. After generating the
model's completion (usually a string, or a list of token IDs that decodes into a string), separate the model's
conversational content from the structured function call:
Expand All @@ -56,4 +60,7 @@ conversational content from the structured function call:
:align: center

Finally, return a :class:`.Completion` with the ``.message`` attribute set to a :class:`.ChatMessage` with the
appropriate :attr:`.ChatMessage.content` and :attr:`.ChatMessage.function_call`.
appropriate :attr:`.ChatMessage.content` and :attr:`.ChatMessage.tool_calls`.

.. note::
See :ref:`functioncall_v_toolcall` for more information about ToolCalls vs FunctionCalls.
129 changes: 97 additions & 32 deletions docs/function_calling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ Next Actor
After a function call returns, kani will hand control back to the LM to generate a response by default. If instead
control should be given to the human (i.e. return from the chat round), set ``after=ChatRole.USER``.

.. note::
If the model calls multiple tools in parallel, the model will be allowed to generate a response if *any* function
has ``after=ChatRole.ASSISTANT`` (the default) once all function calls are complete.

Complete Example
----------------
Here's the full example of how you might implement a function to get weather that we built in the last few steps:
Expand Down Expand Up @@ -182,38 +186,77 @@ prompt a model, we can mock these returns in the chat history using :meth:`.Chat
For example, here's how you might prompt the model to give the temperature in both Fahrenheit and Celsius without
the user having to ask:

.. code-block:: python
from kani import ChatMessage, FunctionCall
fewshot = [
ChatMessage.user("What's the weather in Philadelphia?"),
# first, the model should ask for the weather in fahrenheit
ChatMessage.assistant(
content=None,
function_call=FunctionCall.with_args(
"get_weather", location="Philadelphia, PA", unit="fahrenheit"
)
),
# and we mock the function's response to the model
ChatMessage.function(
"get_weather",
"Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.",
),
# repeat in celsius
ChatMessage.assistant(
content=None,
function_call=FunctionCall.with_args(
"get_weather", location="Philadelphia, PA", unit="celsius"
)
),
ChatMessage.function(
"get_weather",
"Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.",
),
# finally, give the result to the user
ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."),
]
ai = MyKani(engine, chat_history=fewshot)
.. tab:: ToolCall API

.. code-block:: python
# build the chat history with examples
fewshot = [
ChatMessage.user("What's the weather in Philadelphia?"),
ChatMessage.assistant(
content=None,
# use a walrus operator to save a reference to the tool call here...
tool_calls=[
tc := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="fahrenheit")
],
),
ChatMessage.function(
"get_weather",
"Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.",
# ...so this function result knows which call it's responding to
tc.id
),
# and repeat for the other unit
ChatMessage.assistant(
content=None,
tool_calls=[
tc2 := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="celsius")
],
),
ChatMessage.function(
"get_weather",
"Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.",
tc2.id
),
ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."),
]
# and give it to the kani when you initialize it
ai = MyKani(engine, chat_history=fewshot)
.. tab:: FunctionCall API (deprecated)

.. code-block:: python
from kani import ChatMessage, FunctionCall
fewshot = [
ChatMessage.user("What's the weather in Philadelphia?"),
# first, the model should ask for the weather in fahrenheit
ChatMessage.assistant(
content=None,
function_call=FunctionCall.with_args(
"get_weather", location="Philadelphia, PA", unit="fahrenheit"
)
),
# and we mock the function's response to the model
ChatMessage.function(
"get_weather",
"Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.",
),
# repeat in celsius
ChatMessage.assistant(
content=None,
function_call=FunctionCall.with_args(
"get_weather", location="Philadelphia, PA", unit="celsius"
)
),
ChatMessage.function(
"get_weather",
"Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.",
),
# finally, give the result to the user
ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."),
]
ai = MyKani(engine, chat_history=fewshot)
.. code-block:: pycon
Expand Down Expand Up @@ -254,4 +297,26 @@ passing params with invalid, non-coercible types) or the function raises an exce
error in a message to the model by default, allowing it up to *retry_attempts* to correct itself and retry the
call.

.. note::
If the model calls multiple tools in parallel, the model will be allowed a retry if *any* exception handler
allows it. This will only count as 1 retry attempt regardless of the number of functions that raised an exception.

In the next section, we'll discuss how to customize this behaviour, along with other parts of the kani interface.

.. _functioncall_v_toolcall:

Internal Representation
-----------------------

.. versionchanged:: v0.6.0

As of Nov 6, 2023, OpenAI added the ability for a single assistant message to request calling multiple functions in
parallel, and wrapped all function calls in a :class:`.ToolCall` wrapper. In order to add support for this in kani while
maintaining backwards compatibility with OSS function calling models, a :class:`.ChatMessage` actually maintains the
following internal representation:

:attr:`.ChatMessage.function_call` is actually an alias for ``ChatMessage.tool_calls[0].function``. If there is more
than one tool call in the message, kani will raise an exception.

A ToolCall is effectively a named wrapper around a :class:`.FunctionCall`, associating the request with a generated
ID so that its response can be linked to the request in future rounds of prompting.
12 changes: 7 additions & 5 deletions examples/2_function_calling_fewshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from typing import Annotated

from kani import AIParam, ChatMessage, FunctionCall, Kani, ai_function, chat_in_terminal
from kani import AIParam, ChatMessage, Kani, ToolCall, ai_function, chat_in_terminal
from kani.engines.openai import OpenAIEngine

api_key = os.getenv("OPENAI_API_KEY")
Expand Down Expand Up @@ -35,14 +35,16 @@ def get_weather(
ChatMessage.user("What's the weather in Philadelphia?"),
ChatMessage.assistant(
content=None,
function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="fahrenheit"),
# use a walrus operator to save a reference to the tool call here...
tool_calls=[tc := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="fahrenheit")],
),
ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit."),
# so this function result knows which call it's responding to
ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.", tc.id),
ChatMessage.assistant(
content=None,
function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="celsius"),
tool_calls=[tc2 := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="celsius")],
),
ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius."),
ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.", tc2.id),
ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."),
]
# and give it to the kani when you initialize it
Expand Down
4 changes: 2 additions & 2 deletions examples/3_customization_exception_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@


class CustomExceptionPromptKani(Kani):
async def handle_function_call_exception(self, call, err, attempt):
async def handle_function_call_exception(self, call, err, attempt, *args, **kwargs):
# get the standard retry logic...
result = await super().handle_function_call_exception(call, err, attempt)
result = await super().handle_function_call_exception(call, err, attempt, *args, **kwargs)
# but override the returned message with our own
result.message = ChatMessage.system(
f"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}"
Expand Down
4 changes: 2 additions & 2 deletions examples/3_customization_track_function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def __init__(self, *args, **kwargs):
self.successful_calls = collections.Counter()
self.failed_calls = collections.Counter()

async def do_function_call(self, call):
async def do_function_call(self, call, *args, **kwargs):
try:
result = await super().do_function_call(call)
result = await super().do_function_call(call, *args, **kwargs)
self.successful_calls[call.name] += 1
return result
except FunctionCallException:
Expand Down
12 changes: 6 additions & 6 deletions examples/colab_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,9 @@
"\n",
"\n",
"class CustomExceptionKani(Kani):\n",
" async def handle_function_call_exception(self, call, err, attempt):\n",
" async def handle_function_call_exception(self, call, err, attempt, *args, **kwargs):\n",
" # get the standard retry logic...\n",
" result = await super().handle_function_call_exception(call, err, attempt)\n",
" result = await super().handle_function_call_exception(call, err, attempt, *args, **kwargs)\n",
" # but override the returned message with our own\n",
" result.message = ChatMessage.system(\n",
" f\"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}\"\n",
Expand Down Expand Up @@ -677,13 +677,13 @@
" self.successful_calls = collections.Counter()\n",
" self.failed_calls = collections.Counter()\n",
"\n",
" async def handle_function_call_exception(self, call, err, attempt):\n",
" msg = ChatMessage.system(str(err))\n",
" async def handle_function_call_exception(self, call, err, attempt, tool_call_id=None):\n",
" msg = ChatMessage.function(name=call.name, content=str(err), tool_call_id=tool_call_id)\n",
" return ExceptionHandleResult(should_retry=attempt < self.retry_attempts, message=msg)\n",
"\n",
" async def do_function_call(self, call):\n",
" async def do_function_call(self, call, *args, **kwargs):\n",
" try:\n",
" res = await super().do_function_call(call)\n",
" res = await super().do_function_call(call, *args, **kwargs)\n",
" self.successful_calls[call.name] += 1\n",
" return res\n",
" except FunctionCallException:\n",
Expand Down
2 changes: 1 addition & 1 deletion kani/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .ai_function import AIFunction, AIParam, ai_function
from .internal import ExceptionHandleResult, FunctionCallResult
from .kani import Kani
from .models import ChatMessage, ChatRole, FunctionCall, MessagePart
from .models import ChatMessage, ChatRole, FunctionCall, MessagePart, ToolCall
from .utils.cli import chat_in_terminal, chat_in_terminal_async

# declare that kani is also a namespace package
Expand Down
38 changes: 32 additions & 6 deletions kani/engines/openai/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import asyncio
import warnings
from typing import Literal, overload

import aiohttp
import pydantic

from .models import ChatCompletion, Completion, FunctionSpec, OpenAIChatMessage, SpecificFunctionCall
from .models import (
ChatCompletion,
Completion,
FunctionSpec,
OpenAIChatMessage,
ResponseFormat,
SpecificFunctionCall,
ToolChoice,
ToolSpec,
)
from ..httpclient import BaseClient, HTTPException, HTTPStatusException, HTTPTimeout


Expand Down Expand Up @@ -77,6 +87,7 @@ async def request(self, method: str, route: str, headers=None, retry=None, **kwa
async def create_completion(
self,
model: str,
*,
prompt: str = "<|endoftext|>",
suffix: str = None,
max_tokens: int = 16,
Expand All @@ -85,6 +96,7 @@ async def create_completion(
n: int = 1,
logprobs: int = None,
echo: bool = False,
seed: int | None = None,
stop: str | list[str] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
Expand All @@ -107,34 +119,48 @@ async def create_chat_completion(
self,
model: str,
messages: list[OpenAIChatMessage],
functions: list[FunctionSpec] | None = None,
function_call: SpecificFunctionCall | Literal["auto"] | Literal["none"] | None = None,
*,
tools: list[ToolSpec] | None = None,
tool_choice: ToolChoice | Literal["auto"] | Literal["none"] | None = None,
temperature: float = 1.0,
top_p: float = 1.0,
n: int = 1,
response_format: ResponseFormat | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
max_tokens: int | None = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
logit_bias: dict | None = None,
user: str | None = None,
# deprecated
functions: list[FunctionSpec] | None = None,
function_call: SpecificFunctionCall | Literal["auto"] | Literal["none"] | None = None,
) -> ChatCompletion: ...

async def create_chat_completion(
self,
model: str,
messages: list[OpenAIChatMessage],
functions: list[FunctionSpec] | None = None,
*,
tools: list[ToolSpec] | None = None,
**kwargs,
) -> ChatCompletion:
"""Create a chat completion.
See https://platform.openai.com/docs/api-reference/chat/create.
"""
# transform pydantic models
if functions:
kwargs["functions"] = [f.model_dump(exclude_unset=True) for f in functions]
if tools:
kwargs["tools"] = [t.model_dump(exclude_unset=True) for t in tools]
if "tool_choice" in kwargs and isinstance(kwargs["tool_choice"], SpecificFunctionCall):
kwargs["tool_choice"] = kwargs["tool_choice"].model_dump(exclude_unset=True)
# deprecated function calling
if "functions" in kwargs:
warnings.warn("The functions parameter is deprecated. Use tools instead.", DeprecationWarning)
kwargs["functions"] = [f.model_dump(exclude_unset=True) for f in kwargs["functions"]]
if "function_call" in kwargs and isinstance(kwargs["function_call"], SpecificFunctionCall):
warnings.warn("The function_call parameter is deprecated. Use tool_choice instead.", DeprecationWarning)
kwargs["function_call"] = kwargs["function_call"].model_dump(exclude_unset=True)
# call API
data = await self.post(
Expand Down
Loading

0 comments on commit 42ff983

Please sign in to comment.