Skip to content

Commit

Permalink
Add assistant.tool_execution_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
andreibondarev committed Nov 25, 2024
1 parent 5bae916 commit 2bf029a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,13 @@ Note that streaming is not currently supported for all LLMs.
* `tool_choice`: Specifies how tools should be selected. Default: "auto". A specific tool function name can be passed. This will force the Assistant to **always** use this function.
* `parallel_tool_calls`: Whether to make multiple parallel tool calls. Default: true
* `add_message_callback`: A callback function (proc, lambda) that is called when any message is added to the conversation (optional)
```ruby
assistant.add_message_callback = -> (message) { puts "New message: #{message}" }
```
* `tool_execution_callback`: A callback function (proc, lambda) that is called right before a tool is executed (optional)
```ruby
assistant.tool_execution_callback = -> (tool_call_id, tool_name, method_name, tool_arguments) { puts "Executing tool_call_id: #{tool_call_id}, tool_name: #{tool_name}, method_name: #{method_name}, tool_arguments: #{tool_arguments}" }
```

### Key Methods
* `add_message`: Adds a user message to the messages array
Expand Down
21 changes: 16 additions & 5 deletions lib/langchain/assistant.rb
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Assistant

attr_accessor :tools,
:add_message_callback,
:tool_execution_callback,
:parallel_tool_calls

# Create a new assistant
Expand All @@ -35,14 +36,17 @@ class Assistant
# @param parallel_tool_calls [Boolean] Whether or not to run tools in parallel
# @param messages [Array<Langchain::Assistant::Messages::Base>] The messages
# @param add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
# @param tool_execution_callback [Proc] A callback function (Proc or lambda) that is called right before a tool function is executed
def initialize(
llm:,
tools: [],
instructions: nil,
tool_choice: "auto",
parallel_tool_calls: true,
messages: [],
# Callbacks
add_message_callback: nil,
tool_execution_callback: nil,
&block
)
unless tools.is_a?(Array) && tools.all? { |tool| tool.class.singleton_class.included_modules.include?(Langchain::ToolDefinition) }
Expand All @@ -52,11 +56,8 @@ def initialize(
@llm = llm
@llm_adapter = LLM::Adapter.build(llm)

# TODO: Validate that it is, indeed, a Proc or lambda
if !add_message_callback.nil? && !add_message_callback.respond_to?(:call)
raise ArgumentError, "add_message_callback must be a callable object, like Proc or lambda"
end
@add_message_callback = add_message_callback
@add_message_callback = add_message_callback if validate_callback!("add_message_callback", add_message_callback)
@tool_execution_callback = tool_execution_callback if validate_callback!("tool_execution_callback", tool_execution_callback)

self.messages = messages
@tools = tools
Expand Down Expand Up @@ -359,6 +360,8 @@ def run_tools(tool_calls)
t.class.tool_name == tool_name
end or raise ArgumentError, "Tool: #{tool_name} not found in assistant.tools"

# Call the callback if set
tool_execution_callback.call(tool_call_id, tool_name, method_name, tool_arguments) if tool_execution_callback # rubocop:disable Style/SafeNavigation
output = tool_instance.send(method_name, **tool_arguments)

submit_tool_output(tool_call_id: tool_call_id, output: output)
Expand Down Expand Up @@ -392,5 +395,13 @@ def record_used_tokens(prompt_tokens, completion_tokens, total_tokens_from_opera
def available_tool_names
llm_adapter.available_tool_names(tools)
end

def validate_callback!(attr_name, callback)
if !callback.nil? && !callback.respond_to?(:call)
raise ArgumentError, "#{attr_name} must be a callable object, like Proc or lambda"
end

true
end
end
end
10 changes: 10 additions & 0 deletions spec/langchain/assistant/assistant_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
end
end

describe "#tool_execution_callback" do
it "raises an error if the callback is not a Proc" do
expect { described_class.new(llm: llm, tool_execution_callback: "foo") }.to raise_error(ArgumentError)
end

it "does not raise an error if the callback is a Proc" do
expect { described_class.new(llm: llm, tool_execution_callback: -> {}) }.not_to raise_error
end
end

it "raises an error if LLM class does not implement `chat()` method" do
llm = Langchain::LLM::Replicate.new(api_key: "123")
expect { described_class.new(llm: llm) }.to raise_error(ArgumentError)
Expand Down

0 comments on commit 2bf029a

Please sign in to comment.