Skip to content

Commit

Permalink
lifecycle nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Aug 20, 2024
1 parent 8a22285 commit 00276e1
Show file tree
Hide file tree
Showing 10 changed files with 205 additions and 64 deletions.
3 changes: 3 additions & 0 deletions whisper_ros/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ endif()
find_package(ament_cmake REQUIRED)
find_package(rclcpp REQUIRED)
find_package(rclcpp_action REQUIRED)
find_package(rclcpp_lifecycle REQUIRED)
find_package(std_msgs REQUIRED)
find_package(std_srvs REQUIRED)
find_package(whisper_msgs REQUIRED)
Expand All @@ -28,6 +29,7 @@ target_link_libraries(whisper_node
)
ament_target_dependencies(whisper_node
rclcpp
rclcpp_lifecycle
std_msgs
std_srvs
whisper_msgs
Expand All @@ -47,6 +49,7 @@ target_link_libraries(whisper_server_node
ament_target_dependencies(whisper_server_node
rclcpp
rclcpp_action
rclcpp_lifecycle
std_msgs
std_srvs
whisper_msgs
Expand Down
27 changes: 26 additions & 1 deletion whisper_ros/include/whisper_ros/whisper_base_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,47 @@

#include <memory>
#include <rclcpp/rclcpp.hpp>
#include <rclcpp_lifecycle/lifecycle_node.hpp>

#include "whisper_msgs/msg/transcription.hpp"
#include "whisper_ros/whisper.hpp"

namespace whisper_ros {

class WhisperBaseNode : public rclcpp::Node {
using CallbackReturn =
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn;

class WhisperBaseNode : public rclcpp_lifecycle::LifecycleNode {

public:
WhisperBaseNode();

// Lifecycle transition callbacks. They are marked `override` so the compiler
// verifies that each signature matches the virtual declared by the lifecycle
// node interface instead of silently declaring an unrelated function.

// Reads the declared parameters into the stored whisper/context params.
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_configure(const rclcpp_lifecycle::State &) override;
// Creates the Whisper instance and calls activate_ros_interfaces().
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_activate(const rclcpp_lifecycle::State &) override;
// Releases the Whisper instance and calls deactivate_ros_interfaces().
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_deactivate(const rclcpp_lifecycle::State &) override;
// Only logs the transition.
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_cleanup(const rclcpp_lifecycle::State &) override;
// Only logs the transition.
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
on_shutdown(const rclcpp_lifecycle::State &) override;

// Hook called from on_activate() after the Whisper instance has been created.
// The default does nothing; derived nodes override it to create their ROS
// interfaces.
virtual void activate_ros_interfaces() {}
// Hook called from on_deactivate(). The default does nothing; derived nodes
// override it to release the interfaces created on activation.
virtual void deactivate_ros_interfaces() {}

protected:
std::string language;
std::shared_ptr<Whisper> whisper;

whisper_msgs::msg::Transcription transcribe(const std::vector<float> &audio);

private:
std::string model;
std::string openvino_encode_device;
int n_processors;
struct whisper_context_params cparams = whisper_context_default_params();
struct whisper_full_params wparams;
};

} // namespace whisper_ros
Expand Down
3 changes: 3 additions & 0 deletions whisper_ros/include/whisper_ros/whisper_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class WhisperNode : public WhisperBaseNode {
public:
WhisperNode();

void activate_ros_interfaces();
void deactivate_ros_interfaces();

private:
rclcpp::Publisher<whisper_msgs::msg::Transcription>::SharedPtr publisher_;
rclcpp::Subscription<std_msgs::msg::Float32MultiArray>::SharedPtr
Expand Down
3 changes: 3 additions & 0 deletions whisper_ros/include/whisper_ros/whisper_server_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class WhisperServerNode : public WhisperBaseNode {
public:
WhisperServerNode();

void activate_ros_interfaces();
void deactivate_ros_interfaces();

protected:
void enable_silero(bool enable);

Expand Down
2 changes: 2 additions & 0 deletions whisper_ros/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
<test_depend>ament_lint_common</test_depend>

<depend>rclcpp</depend>
<depend>rclcpp_action</depend>
<depend>rclcpp_lifecycle</depend>
<depend>std_msgs</depend>
<depend>std_srvs</depend>
<depend>audio_common</depend>
Expand Down
9 changes: 8 additions & 1 deletion whisper_ros/src/whisper_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,15 @@ using namespace whisper_ros;

int main(int argc, char *argv[]) {
rclcpp::init(argc, argv);

auto node = std::make_shared<WhisperNode>();
rclcpp::spin(node);
node->configure();
node->activate();

rclcpp::executors::SingleThreadedExecutor executor;
executor.add_node(node->get_node_base_interface());
executor.spin();

rclcpp::shutdown();
return 0;
}
174 changes: 115 additions & 59 deletions whisper_ros/src/whisper_ros/whisper_base_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,8 @@ using namespace whisper_ros;
using std::placeholders::_1;
using std::placeholders::_2;

WhisperBaseNode::WhisperBaseNode() : rclcpp::Node("whisper_node") {

std::string model;
std::string openvino_encode_device;
int n_processors;

std::string sampling_strategy;
struct whisper_context_params cparams = whisper_context_default_params();
WhisperBaseNode::WhisperBaseNode()
: rclcpp_lifecycle::LifecycleNode("whisper_node") {

this->declare_parameters<int32_t>("", {
{"n_threads", 8},
Expand Down Expand Up @@ -88,79 +82,141 @@ WhisperBaseNode::WhisperBaseNode() : rclcpp::Node("whisper_node") {
{"suppress_non_speech_tokens", false},
{"use_gpu", true},
});
}

rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
WhisperBaseNode::on_configure(const rclcpp_lifecycle::State &) {

RCLCPP_INFO(get_logger(), "[%s] Configuring...", this->get_name());

// get sampling method and create default params
std::string sampling_strategy;
this->get_parameter("sampling_strategy", sampling_strategy);

struct whisper_full_params wparams;

if (sampling_strategy == "greedy") {
wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
this->wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
} else if (sampling_strategy == "beam_search") {
wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
this->wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
}

// get params
this->get_parameter("model", model);
this->get_parameter("openvino_encode_device", openvino_encode_device);

this->get_parameter("n_threads", wparams.n_threads);
this->get_parameter("n_max_text_ctx", wparams.n_max_text_ctx);
this->get_parameter("offset_ms", wparams.offset_ms);
this->get_parameter("duration_ms", wparams.duration_ms);

this->get_parameter("translate", wparams.translate);
this->get_parameter("no_context", wparams.no_context);
this->get_parameter("no_timestamps", wparams.no_timestamps);
this->get_parameter("single_segment", wparams.single_segment);
this->get_parameter("print_special", wparams.print_special);
this->get_parameter("print_progress", wparams.print_progress);
this->get_parameter("print_realtime", wparams.print_realtime);
this->get_parameter("print_timestamps", wparams.print_timestamps);

this->get_parameter("token_timestamps", wparams.token_timestamps);
this->get_parameter("thold_pt", wparams.thold_pt);
this->get_parameter("thold_ptsum", wparams.thold_ptsum);
this->get_parameter("max_len", wparams.max_len);
this->get_parameter("split_on_word", wparams.split_on_word);
this->get_parameter("max_tokens", wparams.max_tokens);

this->get_parameter("audio_ctx", wparams.audio_ctx);
this->get_parameter("tinydiarize", wparams.tdrz_enable);
this->get_parameter("model", this->model);
this->get_parameter("openvino_encode_device", this->openvino_encode_device);

this->get_parameter("n_threads", this->wparams.n_threads);
this->get_parameter("n_max_text_ctx", this->wparams.n_max_text_ctx);
this->get_parameter("offset_ms", this->wparams.offset_ms);
this->get_parameter("duration_ms", this->wparams.duration_ms);

this->get_parameter("translate", this->wparams.translate);
this->get_parameter("no_context", this->wparams.no_context);
this->get_parameter("no_timestamps", this->wparams.no_timestamps);
this->get_parameter("single_segment", this->wparams.single_segment);
this->get_parameter("print_special", this->wparams.print_special);
this->get_parameter("print_progress", this->wparams.print_progress);
this->get_parameter("print_realtime", this->wparams.print_realtime);
this->get_parameter("print_timestamps", this->wparams.print_timestamps);

this->get_parameter("token_timestamps", this->wparams.token_timestamps);
this->get_parameter("thold_pt", this->wparams.thold_pt);
this->get_parameter("thold_ptsum", this->wparams.thold_ptsum);
this->get_parameter("max_len", this->wparams.max_len);
this->get_parameter("split_on_word", this->wparams.split_on_word);
this->get_parameter("max_tokens", this->wparams.max_tokens);

this->get_parameter("audio_ctx", this->wparams.audio_ctx);
this->get_parameter("tinydiarize", this->wparams.tdrz_enable);

this->get_parameter("language", this->language);
wparams.language = this->language.c_str();
this->get_parameter("detect_language", wparams.detect_language);
this->wparams.language = this->language.c_str();
this->get_parameter("detect_language", this->wparams.detect_language);

this->get_parameter("suppress_blank", wparams.suppress_blank);
this->get_parameter("suppress_blank", this->wparams.suppress_blank);
this->get_parameter("suppress_non_speech_tokens",
wparams.suppress_non_speech_tokens);
this->wparams.suppress_non_speech_tokens);

this->get_parameter("temperature", wparams.temperature);
this->get_parameter("max_initial_ts", wparams.max_initial_ts);
this->get_parameter("length_penalty", wparams.length_penalty);
this->get_parameter("temperature", this->wparams.temperature);
this->get_parameter("max_initial_ts", this->wparams.max_initial_ts);
this->get_parameter("length_penalty", this->wparams.length_penalty);

this->get_parameter("temperature_inc", wparams.temperature_inc);
this->get_parameter("entropy_thold", wparams.entropy_thold);
this->get_parameter("logprob_thold", wparams.logprob_thold);
this->get_parameter("no_speech_thold", wparams.no_speech_thold);
this->get_parameter("temperature_inc", this->wparams.temperature_inc);
this->get_parameter("entropy_thold", this->wparams.entropy_thold);
this->get_parameter("logprob_thold", this->wparams.logprob_thold);
this->get_parameter("no_speech_thold", this->wparams.no_speech_thold);

this->get_parameter("greedy_best_of", wparams.greedy.best_of);
this->get_parameter("beam_search_beam_size", wparams.beam_search.beam_size);
this->get_parameter("beam_search_patience", wparams.beam_search.patience);
this->get_parameter("greedy_best_of", this->wparams.greedy.best_of);
this->get_parameter("beam_search_beam_size",
this->wparams.beam_search.beam_size);
this->get_parameter("beam_search_patience",
this->wparams.beam_search.patience);

this->get_parameter("n_processors", n_processors);
this->get_parameter("use_gpu", cparams.use_gpu);
this->get_parameter("gpu_device", cparams.gpu_device);
this->get_parameter("n_processors", this->n_processors);
this->get_parameter("use_gpu", this->cparams.use_gpu);
this->get_parameter("gpu_device", this->cparams.gpu_device);

// check threads number
if (wparams.n_threads < 0) {
wparams.n_threads = std::thread::hardware_concurrency();
if (this->wparams.n_threads < 0) {
this->wparams.n_threads = std::thread::hardware_concurrency();
}

RCLCPP_INFO(get_logger(), "[%s] Configured", this->get_name());

return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::
CallbackReturn::SUCCESS;
}

rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
WhisperBaseNode::on_activate(const rclcpp_lifecycle::State &) {

RCLCPP_INFO(get_logger(), "[%s] Activating...", this->get_name());

// create whisper
this->whisper = std::make_shared<Whisper>(model, openvino_encode_device,
n_processors, cparams, wparams);
this->whisper = std::make_shared<Whisper>(
this->model, this->openvino_encode_device, this->n_processors,
this->cparams, this->wparams);

this->activate_ros_interfaces();

RCLCPP_INFO(get_logger(), "[%s] Activated", this->get_name());

return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::
CallbackReturn::SUCCESS;
}

// Lifecycle callback for the deactivate transition: tears down the ROS
// interfaces of the derived node and releases the Whisper instance created in
// on_activate(). Always reports SUCCESS.
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
WhisperBaseNode::on_deactivate(const rclcpp_lifecycle::State &) {

  RCLCPP_INFO(get_logger(), "[%s] Deactivating...", this->get_name());

  // Remove the ROS interfaces before releasing the model, so that no
  // subscription/service/action callback (which presumably transcribes through
  // this->whisper) can still be reachable once the pointer is null.
  this->deactivate_ros_interfaces();

  // reset() already leaves the shared_ptr empty; no extra nullptr assignment
  // is needed.
  this->whisper.reset();

  RCLCPP_INFO(get_logger(), "[%s] Deactivated", this->get_name());

  return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::
      CallbackReturn::SUCCESS;
}

rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
WhisperBaseNode::on_cleanup(const rclcpp_lifecycle::State &) {

RCLCPP_INFO(get_logger(), "[%s] Cleaning up...", this->get_name());
RCLCPP_INFO(get_logger(), "[%s] Cleaned up", this->get_name());

return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::
CallbackReturn::SUCCESS;
}

// Lifecycle callback for the shutdown transition. It only logs the transition
// and reports SUCCESS.
rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
WhisperBaseNode::on_shutdown(const rclcpp_lifecycle::State &) {

  RCLCPP_INFO(get_logger(), "[%s] Shutting down...", this->get_name());
  // Message corrected from the ungrammatical "Shutted down".
  RCLCPP_INFO(get_logger(), "[%s] Shut down", this->get_name());

  return rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::
      CallbackReturn::SUCCESS;
}

whisper_msgs::msg::Transcription
Expand Down
24 changes: 23 additions & 1 deletion whisper_ros/src/whisper_ros/whisper_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ using std::placeholders::_1;
using std::placeholders::_2;

WhisperNode::WhisperNode() : WhisperBaseNode() {
RCLCPP_INFO(this->get_logger(), "Whisper node started");
}

void WhisperNode::activate_ros_interfaces() {
// services
this->set_grammar_service_ =
this->create_service<whisper_msgs::srv::SetGrammar>(
Expand All @@ -55,8 +58,27 @@ WhisperNode::WhisperNode() : WhisperBaseNode() {
this->subscription_ =
this->create_subscription<std_msgs::msg::Float32MultiArray>(
"vad", 10, std::bind(&WhisperNode::vad_callback, this, _1));
}

RCLCPP_INFO(this->get_logger(), "Whisper node started");
// Releases every ROS interface held by this node: the four grammar/prompt
// services, the transcription publisher and the "vad" subscription.
// reset() already leaves each shared pointer empty, so the former follow-up
// `= nullptr` assignments were redundant and have been dropped.
void WhisperNode::deactivate_ros_interfaces() {

  // services
  this->set_grammar_service_.reset();
  this->reset_grammar_service_.reset();
  this->set_init_prompt_service_.reset();
  this->reset_init_prompt_service_.reset();

  // publisher and subscription
  this->publisher_.reset();
  this->subscription_.reset();
}

void WhisperNode::vad_callback(
Expand Down
Loading

0 comments on commit 00276e1

Please sign in to comment.