diff --git a/Cargo.lock b/Cargo.lock index 58f1edf..0ae417b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -333,7 +333,7 @@ dependencies = [ "base64 0.21.2", "derive_builder", "futures", - "rand", + "rand 0.8.5", "reqwest", "reqwest-eventsource", "secrecy", @@ -465,7 +465,7 @@ dependencies = [ "getrandom", "instant", "pin-project-lite", - "rand", + "rand 0.8.5", "tokio", ] @@ -1519,6 +1519,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fuchsia-cprng" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba" + [[package]] name = "fuchsia-zircon" version = "0.3.3" @@ -1951,6 +1957,7 @@ dependencies = [ "async-openai", "async-trait", "azure_tts", + "base64 0.21.2", "bitflags 2.3.3", "bytes", "chrono", @@ -1972,7 +1979,7 @@ dependencies = [ "prost-reflect", "prost-reflect-build", "prost-types", - "rand", + "rand 0.8.5", "reqwest", "rodio", "rplidar_driver", @@ -1982,6 +1989,7 @@ dependencies = [ "serde_yaml", "serialport", "sha2", + "tempdir", "thiserror", "tokio", "toml 0.7.6", @@ -3002,7 +3010,7 @@ dependencies = [ "num-integer", "num-iter", "num-traits", - "rand", + "rand 0.8.5", "smallvec", "zeroize", ] @@ -3810,7 +3818,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c8bb234e70c863204303507d841e7fa2295e95c822b2bb4ca8ebf57f17b1cb" dependencies = [ "bytes", - "rand", + "rand 0.8.5", "ring", "rustc-hash", "rustls", @@ -3843,6 +3851,19 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "552840b97013b1a26992c11eac34bdd778e464601a4c2054b5f0bff7c6761293" +dependencies = [ + "fuchsia-cprng", + "libc", + "rand_core 0.3.1", + "rdrand", + "winapi 0.3.9", +] + [[package]] name = "rand" version = "0.8.5" @@ -3851,7 +3872,7 @@ checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", "rand_chacha", - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -3861,9 +3882,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", ] +[[package]] +name = "rand_core" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b" +dependencies = [ + "rand_core 0.4.2", +] + +[[package]] +name = "rand_core" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc" + [[package]] name = "rand_core" version = "0.6.4" @@ -3926,6 +3962,15 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "rdrand" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2" +dependencies = [ + "rand_core 0.3.1", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -4008,6 +4053,15 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" +[[package]] +name = "remove_dir_all" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" +dependencies = [ + "winapi 0.3.9", +] + [[package]] name = "reqwest" version = "0.11.18" @@ -4143,7 +4197,7 @@ dependencies = [ "num-traits", "pkcs1", "pkcs8", - "rand_core", + "rand_core 0.6.4", "signature", "subtle", "zeroize", @@ -4595,7 +4649,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e1788eed21689f9cf370582dfc467ef36ed9c707f073528ddafa8d83e3b8500" dependencies = [ "digest", - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -4816,6 +4870,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempdir" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15f2b5fb00ccdf689e0149d1b1b3c03fead81c2b37735d812fa8bddbbf41b6d8" +dependencies = [ + "rand 0.4.6", + "remove_dir_all", +] + [[package]] name = "tempfile" version = "3.7.0" @@ -5169,7 +5233,7 @@ dependencies = [ "http", "httparse", "log", - "rand", + "rand 0.8.5", "sha1", "thiserror", "url", @@ -5208,7 +5272,7 @@ dependencies = [ "humantime", "lazy_static", "log", - "rand", + "rand 0.8.5", "serde", "spin 0.9.8", ] @@ -5932,7 +5996,7 @@ dependencies = [ "log", "ordered-float 3.7.0", "petgraph", - "rand", + "rand 0.8.5", "regex", "rustc_version", "serde", @@ -6033,7 +6097,7 @@ checksum = "41e2b1c3850d7f052daa24823c4b496610d0b150d3624e08d9383d7b9c97714b" dependencies = [ "aes", "hmac", - "rand", + "rand 0.8.5", "rand_chacha", "sha3", "zenoh-result", @@ -6047,7 +6111,7 @@ checksum = "ecf3b237a6d2ac9df0b4204fe5a178ddb586a4a7e7dd73fe680e0cb0517dab86" dependencies = [ "hashbrown 0.13.2", "keyed-set", - "rand", + "rand 0.8.5", "serde", "token-cell", "zenoh-result", @@ -6253,7 +6317,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e52290acad7decbae601bce3b6c394d584c41ca93d948aac27be964eeb38ee48" dependencies = [ "hex", - "rand", + "rand 0.8.5", "serde", "uhlc", "uuid", @@ -6301,7 +6365,7 @@ dependencies = [ "log", "lz4_flex", "paste", - "rand", + "rand 0.8.5", "ringbuffer-spsc", "rsa", "serde", diff --git a/Cargo.toml b/Cargo.toml index e10671e..48e8d48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ thiserror = "1.0" walkdir = "2.3.3" bytes = "1.4" reqwest = {version = "0.11", features = ["json"]} +tempdir = "0.3.7" # async futures = "0.3" @@ -88,7 +89,7 @@ serde_yaml = "0.9.25" toml = "0.7.6" schemars = "0.8.12" cobs-rs = "1.1.1" - +base64 = "0.21.0" # other projects azure_tts = {git = "https://github.com/dmweis/azure_tts", branch = "main", optional = true} diff --git a/src/audio_transcribe.rs b/src/audio_transcribe.rs new file mode 100644 index 0000000..e207bc2 --- /dev/null +++ b/src/audio_transcribe.rs @@ -0,0 +1,133 @@ +use anyhow::Context; +use async_openai::{config::OpenAIConfig, types::CreateTranscriptionRequestArgs, Client}; +use base64::{engine::general_purpose, Engine}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tempdir::TempDir; +use tokio::select; +use zenoh::prelude::r#async::*; + +use crate::{ + error::HopperError, + zenoh_consts::HOPPER_OPENAI_VOICE_COMMAND_SUBSCRIBER, openai::OpenAiService, ioc_container::IocContainer, +}; + +const VOICE_TO_TEXT_TRANSCRIBE_MODEL: &str = "whisper-1"; +const VOICE_TO_TEXT_TRANSCRIBE_MODEL_ENGLISH_LANGUAGE: &str = "en"; + +pub async fn start_audio_transcribe_service( + openai_api_key: &str, + zenoh_session: Arc, +) -> anyhow::Result<()> { + let config = OpenAIConfig::new().with_api_key(openai_api_key); + let client = Client::with_config(config); + + let audio_command_subscriber = zenoh_session + .declare_subscriber(HOPPER_OPENAI_VOICE_COMMAND_SUBSCRIBER) + .res() + .await + .map_err(HopperError::ZenohError)?; + + + + // this is because of the transcribe future not being send + std::thread::spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + loop { + let res: anyhow::Result<()> = async { + select! { + audio_command_msg = audio_command_subscriber.recv_async() => { + let audio_command_msg = audio_command_msg?; + let audio_command_msg_json: String = audio_command_msg.value.try_into()?; + let encoded_audio_command: Base64AudioMessage = serde_json::from_str(&audio_command_msg_json)?; + let audio_command: DecodedAudioMessage = encoded_audio_command.try_into()?; + let text = transcribe(audio_command, "Audio command for a hexapod robot called Hopper", &client).await?; + + IocContainer::global_instance() + .service::()? + .send_command(&text) + .await?; + + } + } + Ok(()) + } + .await; + if let Err(e) = res { + tracing::error!("Error in speech controller: {}", e); + } + } + }); + }); + + Ok(()) +} + +pub async fn transcribe( + audio: DecodedAudioMessage, + prompt: &str, + open_ai_client: &Client, +) -> anyhow::Result { + let temp_dir = TempDir::new("audio_message_temp_dir")?; + let temp_auido_file = temp_dir + .path() + .join(format!("recorded.{}", audio.format_extension)); + + tokio::fs::write(&temp_auido_file, &audio.data).await?; + + let request = CreateTranscriptionRequestArgs::default() + .file(temp_auido_file) + .model(VOICE_TO_TEXT_TRANSCRIBE_MODEL) + .language(VOICE_TO_TEXT_TRANSCRIBE_MODEL_ENGLISH_LANGUAGE) + .prompt(prompt) + .build()?; + + let response = open_ai_client.audio().transcribe(request).await?; + Ok(response.text) +} + +#[derive(Deserialize, Serialize, Debug, Clone, Default)] +pub struct Base64AudioMessage { + pub data: String, + pub format: String, +} + +pub struct DecodedAudioMessage { + pub data: Vec, + /// .wav, .mp3, etc + pub format_extension: String, +} + +impl DecodedAudioMessage { + pub fn new(data: Vec, format_extension: &str) -> Self { + Self { + data, + format_extension: format_extension.to_string(), + } + } +} + +impl TryFrom for DecodedAudioMessage { + type Error = anyhow::Error; + + fn try_from(audio_message: Base64AudioMessage) -> anyhow::Result { + let data = base64_to_binary(&audio_message.data)?; + let format_extension = audio_message.format; + Ok(Self { + data, + format_extension, + }) + } +} + +pub fn base64_to_binary(base64: &str) -> anyhow::Result> { + let decoded_file = general_purpose::STANDARD + .decode(base64) + .context("Failed to parse base64")?; + Ok(decoded_file) +} diff --git a/src/bin/hopper.rs b/src/bin/hopper.rs index c2b8522..2023929 100644 --- a/src/bin/hopper.rs +++ b/src/bin/hopper.rs @@ -1,6 +1,7 @@ use anyhow::Result; use clap::Parser; use hopper_rust::{ + audio_transcribe::start_audio_transcribe_service, body_controller, body_controller::BodyController, camera::start_camera, @@ -158,7 +159,12 @@ async fn main() -> Result<()> { start_camera(zenoh_session.clone(), &app_config.camera).await?; - start_openai_controller(&app_config.openai.api_key, zenoh_session.clone()).await?; + let open_ai_service = + start_openai_controller(&app_config.openai.api_key, zenoh_session.clone()).await?; + + ioc_container.register(open_ai_service); + + start_audio_transcribe_service(&app_config.openai.api_key, zenoh_session.clone()).await?; // hopper_rust::udp_remote::udp_controller_handler(&mut motion_controller) // .await diff --git a/src/lib.rs b/src/lib.rs index bfbd14b..73bcfe1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ #![doc = include_str!("../README.md")] +pub mod audio_transcribe; pub mod body_controller; pub mod camera; pub mod configuration; diff --git a/src/openai.rs b/src/openai.rs index f504210..0c3b6a5 100644 --- a/src/openai.rs +++ b/src/openai.rs @@ -51,10 +51,22 @@ Give short and concise answers. // Add to make robot speak slovak // You can also speak slovak. +#[derive(Clone)] +pub struct OpenAiService { + sender: tokio::sync::mpsc::Sender, +} + +impl OpenAiService { + pub async fn send_command(&self, command: &str) -> anyhow::Result<()> { + self.sender.send(command.to_string()).await?; + Ok(()) + } +} + pub async fn start_openai_controller( openai_api_key: &str, zenoh_session: Arc, -) -> anyhow::Result<()> { +) -> anyhow::Result { let config = OpenAIConfig::new().with_api_key(openai_api_key); let client = Client::with_config(config); @@ -80,15 +92,24 @@ pub async fn start_openai_controller( .await .map_err(HopperError::ZenohError)?; + let (sender, mut receiver) = tokio::sync::mpsc::channel::(10); + tokio::spawn(async move { loop { let res: anyhow::Result<()> = async { select! { text_command_msg = simple_text_command_subscriber.recv_async() => { + info!("Received new zenoh text command"); let text_command_msg = text_command_msg?; let text_command: String = text_command_msg.value.try_into()?; process_simple_text_command(&text_command, chat_gpt_conversation.clone(), client.clone()).await?; } + text_command = receiver.recv() => { + if let Some(text_command) = text_command { + info!("Received new text command"); + process_simple_text_command(&text_command, chat_gpt_conversation.clone(), client.clone()).await?; + } + } } Ok(()) } @@ -98,7 +119,10 @@ pub async fn start_openai_controller( } } }); - Ok(()) + + let open_ai_service = OpenAiService { sender }; + + Ok(open_ai_service) } async fn process_simple_text_command( diff --git a/src/zenoh_consts.rs b/src/zenoh_consts.rs index 261516b..78d1cdd 100644 --- a/src/zenoh_consts.rs +++ b/src/zenoh_consts.rs @@ -28,3 +28,4 @@ pub const HOPPER_CONTROL_LOOP_RATE: &str = "hopper/metrics/control_loop/rate"; // openai pub const HOPPER_OPENAI_COMMAND_SUBSCRIBER: &str = "hopper/openai/simple/text/command"; +pub const HOPPER_OPENAI_VOICE_COMMAND_SUBSCRIBER: &str = "z/audio_to_mqtt/windows/simple";