Skip to content

Commit

Permalink
Transcription msg + transcribe function for node
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Aug 1, 2024
1 parent d581098 commit 361e3ac
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 28 deletions.
6 changes: 5 additions & 1 deletion whisper_demos/whisper_demos/whisper_demo_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def listen(self) -> None:

rclpy.spin_until_future_complete(self, get_result_future)
result: STT.Result = get_result_future.result().result
self.get_logger().info(f"I hear: {result.text}")
self.get_logger().info(f"I hear: {result.transcription.text}")
self.get_logger().info(
f"Audio time: {result.transcription.audio_time}")
self.get_logger().info(
f"Transcription time: {result.transcription.transcription_time}")


def main():
Expand Down
3 changes: 2 additions & 1 deletion whisper_msgs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)

rosidl_generate_interfaces(${PROJECT_NAME}
"action/STT.action"
"msg/Transcription.msg"
"msg/GrammarConfig.msg"
"srv/SetGrammar.srv"
"srv/SetInitPrompt.srv"
"action/STT.action"
)

ament_package()
2 changes: 1 addition & 1 deletion whisper_msgs/action/STT.action
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
string prompt
GrammarConfig grammar_config
---
string text
Transcription transcription
---
3 changes: 3 additions & 0 deletions whisper_msgs/msg/Transcription.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Transcribed text, after trimming by the node.
string text
# Length of the transcribed audio in seconds (number of samples divided by WHISPER_SAMPLE_RATE).
float32 audio_time
# Time in seconds the node spent producing the transcription, measured with the node clock.
float32 transcription_time
3 changes: 3 additions & 0 deletions whisper_ros/include/whisper_ros/whisper_base_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <memory>
#include <rclcpp/rclcpp.hpp>

#include "whisper_msgs/msg/transcription.hpp"
#include "whisper_ros/whisper.hpp"

namespace whisper_ros {
Expand All @@ -38,6 +39,8 @@ class WhisperBaseNode : public rclcpp::Node {
protected:
std::string language;
std::shared_ptr<Whisper> whisper;

whisper_msgs::msg::Transcription transcribe(const std::vector<float> &audio);
};

} // namespace whisper_ros
Expand Down
4 changes: 2 additions & 2 deletions whisper_ros/include/whisper_ros/whisper_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
#include <memory>

#include <std_msgs/msg/float32_multi_array.hpp>
#include <std_msgs/msg/string.hpp>
#include <std_srvs/srv/empty.hpp>

#include "whisper_msgs/msg/transcription.hpp"
#include "whisper_msgs/srv/set_grammar.hpp"
#include "whisper_msgs/srv/set_init_prompt.hpp"
#include "whisper_ros/whisper_base_node.hpp"
Expand All @@ -41,7 +41,7 @@ class WhisperNode : public WhisperBaseNode {
WhisperNode();

private:
rclcpp::Publisher<std_msgs::msg::String>::SharedPtr publisher_;
rclcpp::Publisher<whisper_msgs::msg::Transcription>::SharedPtr publisher_;
rclcpp::Subscription<std_msgs::msg::Float32MultiArray>::SharedPtr
subscription_;

Expand Down
7 changes: 4 additions & 3 deletions whisper_ros/include/whisper_ros/whisper_server_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <std_srvs/srv/set_bool.hpp>

#include "whisper_msgs/action/stt.hpp"
#include "whisper_msgs/msg/transcription.hpp"
#include "whisper_ros/whisper_base_node.hpp"

namespace whisper_ros {
Expand All @@ -48,9 +49,9 @@ class WhisperServerNode : public WhisperBaseNode {
void enable_silero(bool enable);

private:
std::string text;
std::mutex text_mutex;
std::condition_variable text_cond;
whisper_msgs::msg::Transcription transcription_msg;
std::mutex transcription_mutex;
std::condition_variable transcription_cond;

std::shared_ptr<GoalHandleSTT> goal_handle_;
rclcpp::Subscription<std_msgs::msg::Float32MultiArray>::SharedPtr
Expand Down
18 changes: 18 additions & 0 deletions whisper_ros/src/whisper_ros/whisper_base_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,21 @@ WhisperBaseNode::WhisperBaseNode() : rclcpp::Node("whisper_node") {
this->whisper = std::make_shared<Whisper>(model, openvino_encode_device,
n_processors, cparams, wparams);
}

/**
 * @brief Transcribe an audio buffer and package the result as a message.
 *
 * Runs the Whisper model on the given samples, trims the resulting text and
 * records both the duration of the audio and the time spent transcribing it.
 *
 * @param audio Audio samples; the buffer length is interpreted at
 *              WHISPER_SAMPLE_RATE samples per second.
 * @return Transcription message holding the trimmed text, the audio duration
 *         in seconds and the transcription time in seconds.
 */
whisper_msgs::msg::Transcription
WhisperBaseNode::transcribe(const std::vector<float> &audio) {

  // Time the whole transcribe + trim step using the node clock.
  const auto start_time = this->get_clock()->now();
  RCLCPP_INFO(this->get_logger(), "Transcribing");
  transcription_output result = this->whisper->transcribe(audio);
  std::string text = this->whisper->trim(result.text);
  const auto end_time = this->get_clock()->now();

  RCLCPP_INFO(this->get_logger(), "Text heard: %s", text.c_str());

  whisper_msgs::msg::Transcription msg;
  msg.text = text;
  // Divide in floating point: an integer division of the sample count by the
  // sample rate would truncate the duration to whole seconds (and report 0
  // for any clip shorter than one second).
  msg.audio_time = static_cast<float>(audio.size()) /
                   static_cast<float>(WHISPER_SAMPLE_RATE);
  // Duration::seconds() yields a double; the message field is float32.
  msg.transcription_time =
      static_cast<float>((end_time - start_time).seconds());

  return msg;
}
13 changes: 4 additions & 9 deletions whisper_ros/src/whisper_ros/whisper_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ WhisperNode::WhisperNode() : WhisperBaseNode() {
_2));

// pubs, subs
this->publisher_ = this->create_publisher<std_msgs::msg::String>("text", 10);
this->publisher_ = this->create_publisher<whisper_msgs::msg::Transcription>(
"transcription", 10);
this->subscription_ =
this->create_subscription<std_msgs::msg::Float32MultiArray>(
"vad", 10, std::bind(&WhisperNode::vad_callback, this, _1));
Expand All @@ -61,14 +62,8 @@ WhisperNode::WhisperNode() : WhisperBaseNode() {
void WhisperNode::vad_callback(
const std_msgs::msg::Float32MultiArray::SharedPtr msg) {

RCLCPP_INFO(this->get_logger(), "Transcribing");
transcription_output result = this->whisper->transcribe(msg->data);
std::string text = this->whisper->trim(result.text);
RCLCPP_INFO(this->get_logger(), "Text heard: %s", text.c_str());

std_msgs::msg::String result_msg;
result_msg.data = text;
this->publisher_->publish(result_msg);
auto transcription_msg = this->transcribe(msg->data);
this->publisher_->publish(transcription_msg);
}

void WhisperNode::set_grammar_service_callback(
Expand Down
19 changes: 8 additions & 11 deletions whisper_ros/src/whisper_ros/whisper_server_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,9 @@ void WhisperServerNode::vad_callback(

this->enable_silero(false);

RCLCPP_INFO(this->get_logger(), "Transcribing");
transcription_output result = this->whisper->transcribe(msg->data);
this->text = this->whisper->trim(result.text);
RCLCPP_INFO(this->get_logger(), "Text heard: %s", this->text.c_str());
this->transcription_msg = this->transcribe(msg->data);

this->text_cond.notify_all();
this->transcription_cond.notify_all();
}

rclcpp_action::GoalResponse
Expand All @@ -89,7 +86,7 @@ rclcpp_action::CancelResponse WhisperServerNode::handle_cancel(

RCLCPP_INFO(this->get_logger(), "Received request to cancel Whisper node");
this->enable_silero(false);
this->text_cond.notify_all();
this->transcription_cond.notify_all();

return rclcpp_action::CancelResponse::ACCEPT;
}
Expand All @@ -114,14 +111,14 @@ void WhisperServerNode::execute(
this->whisper->set_init_prompt(goal->prompt);

auto result = std::make_shared<STT::Result>();
this->text.clear();
this->transcription_msg.text.clear();

this->enable_silero(true);

// wait for text
while (this->text.empty() && !goal_handle->is_canceling()) {
std::unique_lock<std::mutex> lock(this->text_mutex);
this->text_cond.wait(lock);
while (this->transcription_msg.text.empty() && !goal_handle->is_canceling()) {
std::unique_lock<std::mutex> lock(this->transcription_mutex);
this->transcription_cond.wait(lock);
}

// reset
Expand All @@ -133,7 +130,7 @@ void WhisperServerNode::execute(
goal_handle->canceled(result);

} else {
result->text = this->text;
result->transcription = this->transcription_msg;
goal_handle->succeed(result);
}
}

0 comments on commit 361e3ac

Please sign in to comment.