diff --git a/.gitignore b/.gitignore index 400100c..e3f6455 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ /.vscode /.embuild -/target -/Cargo.lock +target/ /config.yml /ota.bin diff --git a/Cargo.lock b/Cargo.lock index 827e413..9790de5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -873,6 +873,14 @@ version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" +[[package]] +name = "hass-types" +version = "0.1.0" +dependencies = [ + "heapless 0.8.0", + "serde", +] + [[package]] name = "heapless" version = "0.7.17" @@ -1228,6 +1236,7 @@ dependencies = [ [[package]] name = "rust-mqtt" version = "0.3.0" +source = "git+https://github.com/akosnad/rust-mqtt.git#a2751e543ec9f669bbd828138df7384dabe4650b" dependencies = [ "embedded-io", "embedded-io-async", @@ -1434,6 +1443,26 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "toml" version = "0.8.19" @@ -1504,6 +1533,16 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e87a2ed6b42ec5e28cc3b94c09982969e9227600b2e3dcbc1db927a84c06bd69" +[[package]] +name = "uneval" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63cc5d2fd8648d7e2be86098f60c9ece7045cc710b3c1e226910e2f37d11dc73" +dependencies = [ + "serde", + "thiserror", +] + [[package]] name = "unicode-ident" version = "1.0.14" @@ -1555,6 +1594,7 @@ dependencies = [ "esp-hal-embassy", "esp-println", "esp-wifi", + "hass-types", "heapless 0.8.0", "log", "rust-mqtt", @@ -1565,6 +1605,7 @@ dependencies = [ "smoltcp", "static_cell", "ublox", + "uneval", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 62e5ecf..19473e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,11 @@ version = "0.1.0" authors = ["akosnad"] edition = "2021" +[workspace] +members = [ + "hass-types" +] + [dependencies] esp-backtrace = { version = "0.14.0", features = [ "esp32", @@ -36,7 +41,8 @@ embassy-futures = "0.1.1" atat = { version = "0.23.0", default-features = false, features = ["atat_derive", "derive", "log", "serde_at"] } serde_at = { version = "0.23.0", features = ["alloc"] } embassy-net-ppp = { version = "0.1.0", features = ["log"] } -rust-mqtt = { version = "0.3.0", default-features = false, features = ["no_std"] } +rust-mqtt = { git = "https://github.com/akosnad/rust-mqtt.git", default-features = false, features = ["no_std"] } +hass-types = { path = "hass-types" } [profile.dev] # Rust debug is too slow. @@ -55,9 +61,8 @@ overflow-checks = false [build-dependencies] serde = { version = "1.0.210", features = ["derive"] } serde_yaml = "0.9.34" +hass-types = { path = "hass-types" } +uneval = "0.2.4" [package.metadata.espflash] partition_table = "partitions.csv" - -[patch.crates-io] -rust-mqtt = { path = "../rust-mqtt" } diff --git a/build.rs b/build.rs index f0924c9..8dcbd3e 100644 --- a/build.rs +++ b/build.rs @@ -1,8 +1,11 @@ +use hass_types::DeviceTracker; + #[derive(serde::Deserialize)] struct Config { wifi_ssid: String, wifi_password: String, apn: String, + device_tracker_config: DeviceTracker, } impl Config { @@ -23,4 +26,7 @@ fn main() { serde_yaml::from_str::(&config_string).expect("config.yml is not valid") }; config.export_vars(); + + uneval::to_out_dir(config.device_tracker_config, "device_tracker_config.rs") + .expect("Failed to write device_tracker_config.rs"); } diff --git a/hass-types/Cargo.toml b/hass-types/Cargo.toml new file mode 100644 index 0000000..f191b42 --- /dev/null +++ b/hass-types/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "hass-types" +version = "0.1.0" +authors = ["akosnad"] +edition = "2021" + +[dependencies] +heapless = { version = "0.8.0", features = ["serde"] } +serde = { version = "1.0.215", default-features = false, features = ["derive", "alloc"] } diff --git a/hass-types/src/lib.rs b/hass-types/src/lib.rs new file mode 100644 index 0000000..1d3705b --- /dev/null +++ b/hass-types/src/lib.rs @@ -0,0 +1,71 @@ +#![no_std] + +use alloc::{format, string::String}; +use heapless::Vec; +use serde::{Deserialize, Serialize}; + +extern crate alloc; + +pub const DISCOVERY_PREFIX: &str = "homeassistant"; + +pub trait Discoverable { + fn discovery_topic(&self) -> Topic; +} + +#[derive(Debug, Clone, Copy, Deserialize, Serialize, Default)] +pub enum AvailabilityMode { + #[serde(rename = "all")] + All, + #[serde(rename = "any")] + Any, + #[default] + #[serde(rename = "latest")] + Latest, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Availability { + pub payload_available: Option, + pub payload_not_available: Option, + pub topic: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Device { + pub hw_version: Option, + pub sw_version: Option, + pub name: Option, + pub identifiers: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct UniqueId(pub String); + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Topic(pub String); + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct DeviceTracker { + /// Only first availability is used by us, the rest are ignored. + pub availability: Vec, + pub availability_mode: Option, + pub device: Option, + pub unique_id: UniqueId, + pub name: String, + pub json_attributes_topic: Topic, +} +impl Discoverable for DeviceTracker { + fn discovery_topic(&self) -> Topic { + Topic(format!( + "{}/device_tracker/{}/config", + DISCOVERY_PREFIX, self.unique_id.0 + )) + } +} + +#[derive(Debug, Clone, Copy, Serialize)] +pub struct DeviceTrackerAttributes { + pub longitude: f64, + pub latitude: f64, + pub gps_accuracy: Option, +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..6d56a5c --- /dev/null +++ b/src/config.rs @@ -0,0 +1,17 @@ +use alloc::vec; +use embassy_sync::once_lock::OnceLock; +use hass_types::*; + +static CONFIG: OnceLock = OnceLock::new(); + +#[derive(Debug, Clone)] +pub struct SystemConfig { + pub device_tracker: DeviceTracker, +} +impl SystemConfig { + pub fn get() -> &'static Self { + CONFIG.get_or_init(|| Self { + device_tracker: include!(concat!(env!("OUT_DIR"), "/device_tracker_config.rs")), + }) + } +} diff --git a/src/gps.rs b/src/gps.rs index 89895ec..ff95463 100644 --- a/src/gps.rs +++ b/src/gps.rs @@ -4,6 +4,7 @@ use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex}; use embassy_time::{Duration, Timer}; use embedded_io_async::{Read as _, ReadReady as _, Write as _}; use esp_hal::{peripherals::UART1, uart::Uart, Async}; +use hass_types::DeviceTrackerAttributes; use log::{info, trace}; use ublox::{GpsFix, PacketRef}; @@ -28,6 +29,16 @@ impl From> for GpsCoords { } } +impl From for DeviceTrackerAttributes { + fn from(coords: GpsCoords) -> Self { + Self { + longitude: coords.lon, + latitude: coords.lat, + gps_accuracy: Some(coords.horiz_accuracy as f64), + } + } +} + #[derive(Debug, Default)] enum GpsState { #[default] diff --git a/src/main.rs b/src/main.rs index a630ef1..e3a0e2b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,7 @@ use alloc::boxed::Box; use core::str::FromStr; use embassy_executor::{task, Spawner}; -use embassy_net::{tcp::TcpSocket, ConfigV4, Ipv4Address, StackResources, StaticConfigV4}; +use embassy_net::{ConfigV4, Ipv4Address, StackResources, StaticConfigV4}; use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex; use embassy_time::{Duration, Timer}; use esp_backtrace as _; @@ -21,14 +21,14 @@ use esp_wifi::wifi::{ utils::create_network_interface, ClientConfiguration, Configuration, WifiController, WifiDevice, WifiEvent, WifiStaDevice, WifiState, }; -use log::info; -use rust_mqtt::{client::client::MqttClient, utils::rng_generator::CountingRng}; use static_cell::make_static; extern crate alloc; +mod config; mod gps; mod modem; +mod mqtt; use gps::Gps; @@ -221,73 +221,18 @@ async fn main_task(spawner: Spawner) { .expect("Failed to spawn modem stack config setter"); // MQTT - let mut mqtt_rx = [0u8; 128]; - let mut mqtt_tx = [0u8; 128]; - let mut mqtt_sock = TcpSocket::new(wifi_stack, &mut mqtt_rx, &mut mqtt_tx); - mqtt_sock.set_timeout(Some(Duration::from_secs(10))); - let endpoint = (Ipv4Address::new(10, 20, 0, 1), 1883); - loop { - Timer::after(Duration::from_secs(5)).await; - if let Err(e) = mqtt_sock.connect(endpoint).await { - log::error!("Failed to connect to MQTT broker: {:?}", e); - continue; - } - info!("MQTT socket connected"); - - let mut config = rust_mqtt::client::client_config::ClientConfig::new( - rust_mqtt::client::client_config::MqttVersion::MQTTv5, - CountingRng(20000), - ); - config.add_client_id("voyagesp"); - const MQTT_BUF_SIZE: usize = 128; - config.max_packet_size = (MQTT_BUF_SIZE as u32) - 1; - let mut client_tx = [0u8; MQTT_BUF_SIZE]; - let mut client_rx = [0u8; MQTT_BUF_SIZE]; - let mut client = MqttClient::<_, 5, _>::new( - &mut mqtt_sock, - &mut client_tx, - MQTT_BUF_SIZE, - &mut client_rx, - MQTT_BUF_SIZE, - config, - ); - - if let Err(e) = client.connect_to_broker().await { - log::error!("Failed to connect to MQTT broker: {:?}", e); - continue; - } + let mqtt = make_static!(mqtt::Mqtt::new(wifi_stack, modem_stack, rng)); + spawner + .spawn(mqtt_task(mqtt)) + .expect("Failed to spawn MQTT task"); - loop { - let gps_data = { - let raw = match gps.get_coords().await { - Some(data) => data, - None => { - Timer::after(Duration::from_secs(2)).await; - continue; - } - }; - match serde_json::to_string(&raw) { - Ok(data) => data, - Err(e) => { - log::error!("Failed to serialize GPS data: {:?}", e); - continue; - } - } - }; - let result = client - .send_message( - "voyagesp", - gps_data.as_bytes(), - rust_mqtt::packet::v5::publish_packet::QualityOfService::QoS0, - false, - ) + loop { + if let Some(gps_data) = gps.get_coords().await { + log::info!("GPS data: {:?}", gps_data); + mqtt.send_event(mqtt::Event::DeviceTrackerStateChange(gps_data.into())) .await; - if let Err(e) = result { - log::error!("Failed to send message: {:?}", e); - break; - } - Timer::after(Duration::from_secs(2)).await; } + Timer::after(Duration::from_secs(5)).await; } } @@ -388,3 +333,8 @@ async fn gps_task(gps: &'static Gps) { async fn modem_task(modem: &'static modem::Modem) { modem.run().await; } + +#[task] +async fn mqtt_task(mqtt: &'static mqtt::Mqtt) { + mqtt.run().await; +} diff --git a/src/mqtt.rs b/src/mqtt.rs new file mode 100644 index 0000000..26b08f4 --- /dev/null +++ b/src/mqtt.rs @@ -0,0 +1,247 @@ +use anyhow::anyhow; +use embassy_futures::select::{select3, Either3}; +use embassy_net::{tcp::TcpSocket, Ipv4Address}; +use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, channel::Channel}; +use embassy_time::{Duration, Timer}; +use esp_hal::rng::Rng; +use hass_types::{DeviceTrackerAttributes, Discoverable}; +use rust_mqtt::{ + client::{client::MqttClient, client_config::ClientConfig}, + packet::v5::publish_packet::QualityOfService, +}; +use serde::Serialize; + +const BUF_SIZE: usize = 2048; +const MAX_PROPERTIES: usize = 10; +const EVENT_QUEUE_SIZE: usize = 5; + +type WifiStack = + &'static embassy_net::Stack>; +type ModemStack = &'static embassy_net::Stack<&'static mut embassy_net_ppp::Device<'static>>; + +#[derive(Debug)] +pub enum Event { + DeviceTrackerStateChange(DeviceTrackerAttributes), +} + +pub struct Mqtt { + wifi_stack: WifiStack, + modem_stack: ModemStack, + config: ClientConfig<'static, MAX_PROPERTIES, esp_hal::rng::Rng>, + event_queue: Channel, +} + +impl Mqtt { + pub fn new(wifi_stack: WifiStack, modem_stack: ModemStack, rng: esp_hal::rng::Rng) -> Self { + let system_config = crate::config::SystemConfig::get(); + + let mut config = + ClientConfig::new(rust_mqtt::client::client_config::MqttVersion::MQTTv5, rng); + config.add_client_id(&system_config.device_tracker.unique_id.0); + if let Some(availability) = system_config.device_tracker.availability.first() { + let payload_offline = availability + .payload_not_available + .as_deref() + .unwrap_or("offline"); + config.add_will(&availability.topic, payload_offline.as_bytes(), true); + } + config.keep_alive = 60; + config.max_packet_size = (BUF_SIZE as u32) - 1; + + Self { + wifi_stack, + modem_stack, + config, + event_queue: Channel::new(), + } + } + + /// MQTT stack task + /// + /// Should be only called once + pub async fn run(&self) -> ! { + let endpoint = (Ipv4Address::new(10, 20, 0, 1), 1883); + 'socket_retry: loop { + Timer::after(Duration::from_secs(5)).await; + + let mut mqtt_rx = [0u8; 128]; + let mut mqtt_tx = [0u8; 128]; + let mut mqtt_sock = TcpSocket::new(self.wifi_stack, &mut mqtt_rx, &mut mqtt_tx); + mqtt_sock.set_timeout(Some(Duration::from_secs(10))); + + if let Err(e) = mqtt_sock.connect(endpoint).await { + log::error!("Failed to connect to MQTT broker: {:?}", e); + continue 'socket_retry; + } + log::info!("MQTT socket connected"); + + let mut client_tx = [0u8; BUF_SIZE]; + let mut client_rx = [0u8; BUF_SIZE]; + let mut client = MqttClient::<_, MAX_PROPERTIES, _>::new( + &mut mqtt_sock, + &mut client_tx, + BUF_SIZE, + &mut client_rx, + BUF_SIZE, + self.config.clone(), + ); + + if let Err(e) = client.connect_to_broker().await { + log::error!("Failed to connect to MQTT broker: {:?}", e); + continue 'socket_retry; + } + log::info!("MQTT broker connected"); + + if let Err(e) = self.init_entities(&mut client).await { + log::error!("Failed to init MQTT entities: {:?}", e); + continue 'socket_retry; + } + log::info!("MQTT entities initialized"); + + 'event_loop: loop { + let event = self.event_queue.ready_to_receive(); + let receive = client.receive_message(); + let timeout = Timer::after(Duration::from_secs(5)); + match select3(event, receive, timeout).await { + Either3::First(_) => { + let event = self.event_queue.receive().await; + match self.handle_event(&mut client, event).await { + Ok(_) => {} + Err(e) => { + log::error!("Failed to handle event: {:?}", e); + break 'event_loop; + } + } + } + Either3::Second(Ok((topic, _payload))) => { + log::info!("Received message on topic: {}", topic); + } + Either3::Second(Err(e)) => { + log::error!("Failed to receive MQTT message: {:?}", e); + break 'event_loop; + } + Either3::Third(_) => match client.send_ping().await { + Ok(_) => {} + Err(e) => { + log::error!("Failed to send MQTT ping: {:?}", e); + break 'event_loop; + } + }, + } + } + } + } + + async fn init_entities( + &self, + client: &mut MqttClient<'_, &mut TcpSocket<'_>, MAX_PROPERTIES, Rng>, + ) -> anyhow::Result<()> { + let device_tracker = &crate::config::SystemConfig::get().device_tracker; + + if let Some(device_tracker_availability) = device_tracker.availability.first() { + let online = device_tracker_availability + .payload_available + .as_deref() + .unwrap_or("online"); + log::debug!( + "Sending online ({}) message for {}", + online, + device_tracker_availability.topic + ); + client + .send_message( + device_tracker_availability.topic.as_str(), + online.as_bytes(), + QualityOfService::QoS1, + true, + ) + .await + .map_err(|e| anyhow!("Failed to send MQTT message: {e:?}"))?; + } + + let device_tracker_value = serde_json::to_value(device_tracker) + .map_err(|e| anyhow!("Failed to serialize device tracker: {e:?}"))?; + let device_tracker_str = serde_json::to_string(&SkipNulls(device_tracker_value)) + .map_err(|e| anyhow!("Failed to serialize device tracker: {e:?}"))?; + log::debug!( + "Sending device tracker discovery config: {}", + device_tracker_str + ); + client + .send_message( + device_tracker.discovery_topic().0.as_str(), + device_tracker_str.as_bytes(), + QualityOfService::QoS1, + true, + ) + .await + .map_err(|e| anyhow!("Failed to send MQTT message: {e:?}"))?; + + Ok(()) + } + + async fn handle_event( + &self, + client: &mut MqttClient<'_, &mut TcpSocket<'_>, MAX_PROPERTIES, Rng>, + event: Event, + ) -> anyhow::Result<()> { + log::debug!("Handling event: {:?}", event); + match event { + Event::DeviceTrackerStateChange(data) => { + let data_value = serde_json::to_value(data) + .map_err(|e| anyhow!("Failed to serialize event data: {e:?}"))?; + match serde_json::to_string(&SkipNulls(data_value)) { + Ok(data_str) => { + let device_tracker = &crate::config::SystemConfig::get().device_tracker; + client + .send_message( + device_tracker.json_attributes_topic.0.as_str(), + data_str.as_bytes(), + QualityOfService::QoS2, + true, + ) + .await + .map_err(|e| anyhow!("Failed to send MQTT message: {e:?}"))?; + } + Err(e) => anyhow::bail!("Failed to serialize event data: {e:?}"), + } + } + } + Ok(()) + } + + pub async fn send_event(&self, event: Event) { + self.event_queue.send(event).await; + } +} + +#[derive(Debug)] +struct SkipNulls(serde_json::Value); +impl Serialize for SkipNulls { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer + ?Sized, + { + use serde::ser::{SerializeMap, SerializeSeq}; + + match &self.0 { + serde_json::Value::Object(map) => { + let map = map.iter().filter(|(_, v)| !v.is_null()); + let mut ser = serializer.serialize_map(None)?; + for (k, v) in map { + ser.serialize_entry(k, &SkipNulls(v.clone()))?; + } + ser.end() + } + serde_json::Value::Array(arr) => { + let arr = arr.iter().filter(|v| !v.is_null()); + let mut ser = serializer.serialize_seq(None)?; + for v in arr { + ser.serialize_element(&SkipNulls(v.clone()))?; + } + ser.end() + } + _ => self.0.serialize(serializer), + } + } +}