From f924e1e92b67e9edae102b48cc38b27d1cc86700 Mon Sep 17 00:00:00 2001 From: Mathias Koch Date: Wed, 25 Sep 2024 15:52:34 +0200 Subject: [PATCH] Feature/async shadows (#57) * Wip on rewriting shadows to async * Further work on async shadows. Still working on compile errors * Fix: Async shadow (#60) * fix asyunc shadow * renaming of handle message and some linting * shadows error fix and handle delta should wait for connected * fmt * Add const generic SUBS to shadows * Fix/async shadow (#61) * fix asyunc shadow * renaming of handle message and some linting * shadows error fix and handle delta should wait for connected * fmt * subscribe to get shadow and do not overwrite desired state * Get shadow should deserialize patchState * wait for accepted and rejected for delete and update as well * Make sure OTA job documents can be deserialized with no codesigning properties in the document (#62) * Dont blindly copy serde attrs in ShadowPatch derive, but rather introduce patch attr that specifies attrs to copy * Add skip_serializing_if none to all patchstate fields * Shadows: Check client token on all request/response pairs * Create initial shadow state, if dao read fails during getShadow operation * remove some client token checks * Fix not holding delta message across report call * handle delta on get shadow * Bump embedded-mqtt * Fix all tests * Allow reporting non-persisted shadows directly, through a report fn * Bump embedded-mqtt * Enhancement(async): Mutex shadow to borrow as immutable (#63) * Use mutex to borrow shadow as immutable * remove .git in embedded-mqtt dependency --------- Co-authored-by: Kenneth Knudsen Co-authored-by: Kenneth Knudsen <98805797+KennethKnudsen97@users.noreply.github.com> --- Cargo.toml | 17 +- documentation/stack.drawio | 8 +- rust-toolchain.toml | 2 +- shadow_derive/src/lib.rs | 26 +- src/jobs/data_types.rs | 15 +- src/lib.rs | 4 +- src/ota/control_interface/mqtt.rs | 116 +++- src/ota/data_interface/mqtt.rs | 51 +- src/ota/encoding/json.rs | 45 +- src/ota/encoding/mod.rs | 8 +- src/ota/error.rs | 5 +- src/provisioning/mod.rs | 228 ++++---- src/shadows/README.md | 2 +- src/shadows/dao.rs | 60 +-- src/shadows/data_types.rs | 12 +- src/shadows/error.rs | 84 +-- src/shadows/mod.rs | 845 +++++++++++++++--------------- src/shadows/topics.rs | 49 +- tests/common/file_handler.rs | 100 ++-- tests/common/network.rs | 4 +- tests/ota.rs | 62 +-- tests/provisioning.rs | 18 +- tests/shadows.rs | 135 +++-- 23 files changed, 979 insertions(+), 917 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c27aacc..42f4d3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,13 +4,13 @@ members = ["shadow_derive"] [package] name = "rustot" version = "0.5.0" -authors = ["Mathias Koch "] +authors = ["Factbird team "] description = "AWS IoT" readme = "README.md" keywords = ["iot", "no-std"] categories = ["embedded", "no-std"] license = "MIT OR Apache-2.0" -repository = "https://github.com/BlackbirdHQ/rustot" +repository = "https://github.com/FactbirdHQ/rustot" edition = "2021" documentation = "https://docs.rs/rustot" exclude = ["/documentation"] @@ -29,7 +29,7 @@ serde_cbor = { version = "0.11", default-features = false, optional = true } serde-json-core = { version = "0.5" } shadow-derive = { path = "shadow_derive", version = "0.2.1" } embedded-storage-async = "0.4" -embedded-mqtt = { git = "ssh://git@github.com/BlackbirdHQ/embedded-mqtt/", rev = "d766137" } +embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt", rev = "d2b7c02" } futures = { version = "0.3.28", default-features = false } @@ -46,6 +46,7 @@ embedded-nal-async = "0.7" env_logger = "0.11" sha2 = "0.10.1" static_cell = { version = "2", features = ["nightly"] } + tokio = { version = "1.33", default-features = false, features = [ "macros", "rt", @@ -73,5 +74,13 @@ ota_http_data = [] std = ["serde/std", "serde_cbor?/std"] -defmt = ["dep:defmt", "heapless/defmt-03", "embedded-mqtt/defmt"] +defmt = [ + "dep:defmt", + "heapless/defmt-03", + "embedded-mqtt/defmt", + "embassy-time/defmt", +] log = ["dep:log", "embedded-mqtt/log"] + +# [patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] +# embedded-mqtt = { path = "../embedded-mqtt" } diff --git a/documentation/stack.drawio b/documentation/stack.drawio index cdf864a..fd6bb68 100644 --- a/documentation/stack.drawio +++ b/documentation/stack.drawio @@ -1,13 +1,13 @@ - + - + - + @@ -22,7 +22,7 @@ - + diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 1368141..b6369b9 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.79" +channel = "nightly-2024-07-17" components = ["rust-src", "rustfmt", "llvm-tools"] targets = [ "x86_64-unknown-linux-gnu", diff --git a/shadow_derive/src/lib.rs b/shadow_derive/src/lib.rs index 7838e87..09cdef0 100644 --- a/shadow_derive/src/lib.rs +++ b/shadow_derive/src/lib.rs @@ -11,9 +11,9 @@ use syn::DeriveInput; use syn::Generics; use syn::Ident; use syn::Result; -use syn::{parenthesized, Attribute, Error, Field, LitStr}; +use syn::{parenthesized, Error, Field, LitStr}; -#[proc_macro_derive(ShadowState, attributes(shadow, static_shadow_field))] +#[proc_macro_derive(ShadowState, attributes(shadow, static_shadow_field, patch))] pub fn shadow_state(input: TokenStream) -> TokenStream { match parse_macro_input!(input as ParseInput) { ParseInput::Struct(input) => { @@ -32,7 +32,7 @@ pub fn shadow_state(input: TokenStream) -> TokenStream { } } -#[proc_macro_derive(ShadowPatch, attributes(static_shadow_field, serde))] +#[proc_macro_derive(ShadowPatch, attributes(static_shadow_field, patch))] pub fn shadow_patch(input: TokenStream) -> TokenStream { TokenStream::from(match parse_macro_input!(input as ParseInput) { ParseInput::Struct(input) => generate_shadow_patch_struct(&input), @@ -56,7 +56,7 @@ struct StructParseInput { pub ident: Ident, pub generics: Generics, pub shadow_fields: Vec, - pub copy_attrs: Vec, + pub copy_attrs: Vec, pub shadow_name: Option, } @@ -67,8 +67,6 @@ impl Parse for ParseInput { let mut shadow_name = None; let mut copy_attrs = vec![]; - let attrs_to_copy = ["serde"]; - // Parse valid container attributes for attr in derive_input.attrs { if attr.path.is_ident("shadow") { @@ -78,12 +76,14 @@ impl Parse for ParseInput { content.parse() } shadow_name = Some(shadow_arg.parse2(attr.tokens)?); - } else if attrs_to_copy - .iter() - .find(|a| attr.path.is_ident(a)) - .is_some() - { - copy_attrs.push(attr); + } else if attr.path.is_ident("patch") { + fn patch_arg(input: ParseStream) -> Result { + let content; + parenthesized!(content in input); + content.parse() + } + let args = patch_arg.parse2(attr.tokens)?; + copy_attrs.push(quote! { #[ #args ]}) } } @@ -161,7 +161,7 @@ fn create_optional_fields(fields: &Vec) -> Vec Some(if type_name_string.starts_with("Option<") { quote! { #(#attrs)* pub #field_name: Option::PatchState>> } } else { - quote! { #(#attrs)* pub #field_name: Option<<#type_name as rustot::shadows::ShadowPatch>::PatchState> } + quote! { #(#attrs)* #[serde(skip_serializing_if = "Option::is_none")] pub #field_name: Option<<#type_name as rustot::shadows::ShadowPatch>::PatchState> } }) } }) diff --git a/src/jobs/data_types.rs b/src/jobs/data_types.rs index 5910469..36449dd 100644 --- a/src/jobs/data_types.rs +++ b/src/jobs/data_types.rs @@ -22,7 +22,8 @@ pub enum JobStatus { Removed, } -#[derive(Debug, Clone, PartialEq, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum ErrorCode { /// The request was sent to a topic in the AWS IoT Jobs namespace that does /// not map to any API. @@ -89,7 +90,7 @@ pub struct GetPendingJobExecutionsResponse<'a> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Contains data about a job execution. @@ -211,7 +212,7 @@ pub struct StartNextPendingJobExecutionResponse<'a, J> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Topic (accepted): $aws/things/{thingName}/jobs/{jobId}/update/accepted \ @@ -232,7 +233,7 @@ pub struct UpdateJobExecutionResponse<'a, J> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Sent whenever a job execution is added to or removed from the list of @@ -289,7 +290,7 @@ pub struct Jobs { /// service operation. #[derive(Debug, PartialEq, Deserialize)] pub struct ErrorResponse<'a> { - code: ErrorCode, + pub code: ErrorCode, /// An error message string. message: &'a str, /// A client token used to correlate requests and responses. Enter an @@ -394,7 +395,7 @@ mod test { in_progress_jobs: Some(Vec::::new()), queued_jobs: None, timestamp: 1587381778, - client_token: "0:client_name", + client_token: Some("0:client_name"), } ); @@ -433,7 +434,7 @@ mod test { in_progress_jobs: Some(Vec::::new()), queued_jobs: Some(queued_jobs), timestamp: 1587381778, - client_token: "0:client_name", + client_token: Some("0:client_name"), } ); } diff --git a/src/lib.rs b/src/lib.rs index b121586..ca160a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ #![cfg_attr(not(any(test, feature = "std")), no_std)] #![allow(async_fn_in_trait)] +#![allow(incomplete_features)] +#![feature(generic_const_exprs)] // This mod MUST go first, so that the others see its macros. pub(crate) mod fmt; @@ -8,6 +10,6 @@ pub mod jobs; #[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))] pub mod ota; pub mod provisioning; -// pub mod shadows; +pub mod shadows; pub use serde_cbor; diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index 094ce16..4456428 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -1,18 +1,21 @@ use core::fmt::Write; +use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; -use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS}; +use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS, Subscribe, SubscribeTopic}; +use futures::StreamExt as _; use super::ControlInterface; -use crate::jobs::data_types::JobStatus; -use crate::jobs::{JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; +use crate::jobs::data_types::{ErrorResponse, JobStatus, UpdateJobExecutionResponse}; +use crate::jobs::{JobError, JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; use crate::ota::config::Config; use crate::ota::encoding::json::JobStatusReason; -use crate::ota::encoding::FileContext; +use crate::ota::encoding::{self, FileContext}; use crate::ota::error::OtaError; -impl<'a, M: RawMutex, const SUBS: usize> ControlInterface - for embedded_mqtt::MqttClient<'a, M, SUBS> +impl<'a, M: RawMutex, const SUBS: usize> ControlInterface for embedded_mqtt::MqttClient<'a, M, SUBS> +where + BitsImpl<{ SUBS }>: Bits, { /// Check for next available OTA job from the job service by publishing a /// "get next job" message to the job service. @@ -21,15 +24,12 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface let mut buf = [0u8; 512]; let (topic, payload_len) = Jobs::describe().topic_payload(self.client_id(), &mut buf)?; - self.publish(Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: &topic, - payload: &buf[..payload_len], - properties: embedded_mqtt::Properties::Slice(&[]), - }) + self.publish( + Publish::builder() + .topic_name(&topic) + .payload(&buf[..payload_len]) + .build(), + ) .await?; Ok(()) @@ -69,7 +69,7 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface } // Don't override the progress on succeeded, nor on self-test - // active. (Cases where progess counter is lost due to device + // active. (Cases where progress counter is lost due to device // restarts) if status != JobStatus::Succeeded && reason != JobStatusReason::SelfTestActive { let mut progress = heapless::String::new(); @@ -93,11 +93,39 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface } } + let mut sub = self + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + JobTopic::UpdateAccepted(file_ctx.job_name.as_str()) + .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( + self.client_id(), + )? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + JobTopic::UpdateRejected(file_ctx.job_name.as_str()) + .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( + self.client_id(), + )? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await?; + let topic = JobTopic::Update(file_ctx.job_name.as_str()) .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>(self.client_id())?; let payload = DeferredPayload::new( |buf| { Jobs::update(status) + .client_token(self.client_id()) .status_details(&file_ctx.status_details) .payload(buf) .map_err(|_| EncodingError::BufferSize) @@ -105,17 +133,53 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface 512, ); - self.publish(Publish { - dup: false, - qos, - retain: false, - pid: None, - topic_name: &topic, - payload, - properties: embedded_mqtt::Properties::Slice(&[]), - }) + self.publish( + Publish::builder() + .qos(qos) + .topic_name(&topic) + .payload(payload) + .build(), + ) .await?; - Ok(()) + loop { + let message = sub.next().await.ok_or(JobError::Encoding)?; + + // Check if topic is GetAccepted + match crate::jobs::Topic::from_str(message.topic_name()) { + Some(crate::jobs::Topic::UpdateAccepted(_)) => { + // Check client token + let (response, _) = serde_json_core::from_slice::< + UpdateJobExecutionResponse>, + >(message.payload()) + .map_err(|_| JobError::Encoding)?; + + if response.client_token != Some(self.client_id()) { + error!( + "Unexpected client token received: {}, expected: {}", + response.client_token.unwrap_or("None"), + self.client_id() + ); + continue; + } + + return Ok(()); + } + Some(crate::jobs::Topic::UpdateRejected(_)) => { + let (error_response, _) = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| JobError::Encoding)?; + + if error_response.client_token != Some(self.client_id()) { + continue; + } + + return Err(OtaError::UpdateRejected(error_response.code)); + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + } + } + } } } diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index f40f31a..17bbf48 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -2,10 +2,10 @@ use core::fmt::{Display, Write}; use core::ops::DerefMut; use core::str::FromStr; +use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ - DeferredPayload, EncodingError, MqttClient, Properties, Publish, RetainHandling, Subscribe, - SubscribeTopic, Subscription, + DeferredPayload, EncodingError, MqttClient, Publish, Subscribe, SubscribeTopic, Subscription, }; use futures::StreamExt; @@ -124,13 +124,19 @@ impl<'a> OtaTopic<'a> { } } -impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1> { +impl<'a, 'b, M: RawMutex, const SUBS: usize> BlockTransfer for Subscription<'a, 'b, M, SUBS, 1> +where + BitsImpl<{ SUBS }>: Bits, +{ async fn next_block(&mut self) -> Result>, OtaError> { Ok(self.next().await) } } -impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS> { +impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUBS> +where + BitsImpl<{ SUBS }>: Bits, +{ const PROTOCOL: Protocol = Protocol::Mqtt; type ActiveTransfer<'t> = Subscription<'a, 't, M, SUBS, 1> where Self: 't; @@ -143,17 +149,15 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB let topic_path = OtaTopic::Data(Encoding::Cbor, file_ctx.stream_name.as_str()) .format::<256>(self.client_id())?; - let topic = SubscribeTopic { - topic_path: topic_path.as_str(), - maximum_qos: embedded_mqtt::QoS::AtMostOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }; + let topics = [SubscribeTopic::builder() + .topic_path(topic_path.as_str()) + .build()]; debug!("Subscribing to: [{:?}]", &topic_path); - Ok(self.subscribe::<1>(Subscribe::new(&[topic])).await?) + Ok(self + .subscribe::<1>(Subscribe::builder().topics(&topics).build()) + .await?) } /// Request file block by publishing to the get stream topic @@ -189,17 +193,18 @@ impl<'a, M: RawMutex, const SUBS: usize> DataInterface for MqttClient<'a, M, SUB file_ctx.request_block_remaining ); - self.publish(Publish { - dup: false, - qos: embedded_mqtt::QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: OtaTopic::Get(Encoding::Cbor, file_ctx.stream_name.as_str()) - .format::<{ MAX_STREAM_ID_LEN + MAX_THING_NAME_LEN + 30 }>(self.client_id())? - .as_str(), - payload, - properties: Properties::Slice(&[]), - }) + self.publish( + Publish::builder() + .topic_name( + OtaTopic::Get(Encoding::Cbor, file_ctx.stream_name.as_str()) + .format::<{ MAX_STREAM_ID_LEN + MAX_THING_NAME_LEN + 30 }>( + self.client_id(), + )? + .as_str(), + ) + .payload(payload) + .build(), + ) .await?; Ok(()) diff --git a/src/ota/encoding/json.rs b/src/ota/encoding/json.rs index 45ea3f2..c258942 100644 --- a/src/ota/encoding/json.rs +++ b/src/ota/encoding/json.rs @@ -32,7 +32,8 @@ pub struct FileDescription<'a> { #[serde(rename = "fileid")] pub fileid: u8, #[serde(rename = "certfile")] - pub certfile: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub certfile: Option<&'a str>, #[serde(rename = "update_data_url")] #[serde(skip_serializing_if = "Option::is_none")] pub update_data_url: Option<&'a str>, @@ -59,20 +60,26 @@ pub struct FileDescription<'a> { } impl<'a> FileDescription<'a> { - pub fn signature(&self) -> Signature { + pub fn signature(&self) -> Option { if let Some(sig) = self.sha1_rsa { - return Signature::Sha1Rsa(heapless::String::try_from(sig).unwrap()); + return Some(Signature::Sha1Rsa(heapless::String::try_from(sig).unwrap())); } if let Some(sig) = self.sha256_rsa { - return Signature::Sha256Rsa(heapless::String::try_from(sig).unwrap()); + return Some(Signature::Sha256Rsa( + heapless::String::try_from(sig).unwrap(), + )); } if let Some(sig) = self.sha1_ecdsa { - return Signature::Sha1Ecdsa(heapless::String::try_from(sig).unwrap()); + return Some(Signature::Sha1Ecdsa( + heapless::String::try_from(sig).unwrap(), + )); } if let Some(sig) = self.sha256_ecdsa { - return Signature::Sha256Ecdsa(heapless::String::try_from(sig).unwrap()); + return Some(Signature::Sha256Ecdsa( + heapless::String::try_from(sig).unwrap(), + )); } - unreachable!() + None } } @@ -147,4 +154,28 @@ mod tests { ); } } + + #[test] + fn deserializ() { + let data = r#"{ + "protocols": [ + "MQTT" + ], + "streamname": "AFR_OTA-d11032e9-38d5-4dca-8c7c-1e6f24533ede", + "files": [ + { + "filepath": "3.8.4", + "filesize": 537600, + "fileid": 0, + "certfile": null, + "fileType": 0, + "update_data_url": null, + "auth_scheme": null, + "sig--": null + } + ] + }"#; + + serde_json_core::from_str::(&data).unwrap(); + } } diff --git a/src/ota/encoding/mod.rs b/src/ota/encoding/mod.rs index bc68c67..7c88700 100644 --- a/src/ota/encoding/mod.rs +++ b/src/ota/encoding/mod.rs @@ -60,10 +60,10 @@ pub struct FileContext { pub filepath: heapless::String<64>, pub filesize: usize, pub fileid: u8, - pub certfile: heapless::String<64>, + pub certfile: Option>, pub update_data_url: Option>, pub auth_scheme: Option>, - pub signature: Signature, + pub signature: Option, pub file_type: Option, pub protocols: heapless::Vec, @@ -110,7 +110,9 @@ impl FileContext { filesize: file_desc.filesize, protocols: job_data.ota_document.protocols, fileid: file_desc.fileid, - certfile: heapless::String::try_from(file_desc.certfile).unwrap(), + certfile: file_desc + .certfile + .map(|cert| heapless::String::try_from(cert).unwrap()), update_data_url: file_desc .update_data_url .map(|s| heapless::String::try_from(s).unwrap()), diff --git a/src/ota/error.rs b/src/ota/error.rs index 8c5744b..119cc20 100644 --- a/src/ota/error.rs +++ b/src/ota/error.rs @@ -1,4 +1,4 @@ -use crate::jobs::JobError; +use crate::jobs::{data_types::ErrorCode, JobError}; use super::pal::OtaPalError; @@ -6,7 +6,6 @@ use super::pal::OtaPalError; #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum OtaError { NoActiveJob, - SignalEventFailed, Momentum, MomentumAbort, InvalidInterface, @@ -14,7 +13,9 @@ pub enum OtaError { BlockOutOfRange, ZeroFileSize, Overflow, + UnexpectedTopic, InvalidFile, + UpdateRejected(ErrorCode), Write( #[cfg_attr(feature = "defmt", defmt(Debug2Format))] embedded_storage_async::nor_flash::NorFlashErrorKind, diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index dbeefe5..d86e2a7 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -4,10 +4,11 @@ pub mod topics; use core::future::Future; +use bitmaps::{Bits, BitsImpl}; use embassy_sync::blocking_mutex::raw::RawMutex; use embedded_mqtt::{ - DeferredPayload, EncodingError, Message, Publish, QoS, RetainHandling, Subscribe, - SubscribeTopic, Subscription, + BufferProvider, DeferredPayload, EncodingError, Message, Publish, Subscribe, SubscribeTopic, + Subscription, }; use futures::StreamExt; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -41,12 +42,13 @@ pub struct FleetProvisioner; impl FleetProvisioner { pub async fn provision<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, parameters: Option, credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -68,6 +70,7 @@ impl FleetProvisioner { credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -83,12 +86,13 @@ impl FleetProvisioner { #[cfg(feature = "provision_cbor")] pub async fn provision_cbor<'a, C, M: RawMutex, const SUBS: usize>( - mqtt: &'a embedded_mqtt::MqttClient<'a, M, SUBS>, + mqtt: &embedded_mqtt::MqttClient<'a, M, SUBS>, template_name: &str, parameters: Option, credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -111,6 +115,7 @@ impl FleetProvisioner { credential_handler: &mut impl CredentialHandler, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { Self::provision_inner( @@ -134,10 +139,12 @@ impl FleetProvisioner { payload_format: PayloadFormat, ) -> Result, Error> where + BitsImpl<{ SUBS }>: Bits, C: DeserializeOwned, { - let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?; + use embedded_mqtt::SliceBufferProvider; + let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?; let mut message = create_subscription .next() .await @@ -145,10 +152,11 @@ impl FleetProvisioner { let ownership_token = match Topic::from_str(message.topic_name()) { Some(Topic::CreateKeysAndCertificateAccepted(format)) => { - let response = Self::deserialize::( - format, - &mut message, - )?; + let response = Self::deserialize::< + CreateKeysAndCertificateResponse, + SliceBufferProvider<'a>, + SUBS, + >(format, &mut message)?; credential_handler .store_credentials(Credentials { @@ -162,10 +170,11 @@ impl FleetProvisioner { } Some(Topic::CreateCertificateFromCsrAccepted(format)) => { - let response = Self::deserialize::( - format, - &mut message, - )?; + let response = Self::deserialize::< + CreateCertificateFromCsrResponse, + SliceBufferProvider<'a>, + SUBS, + >(format, &mut message)?; credential_handler .store_credentials(Credentials { @@ -220,28 +229,29 @@ impl FleetProvisioner { debug!("Starting RegisterThing"); let mut register_subscription = mqtt - .subscribe::<1>(Subscribe::new(&[SubscribeTopic { - topic_path: Topic::RegisterThingAccepted(template_name, payload_format) - .format::<150>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }])) + .subscribe::<1>( + Subscribe::builder() + .topics(&[SubscribeTopic::builder() + .topic_path( + Topic::RegisterThingAccepted(template_name, payload_format) + .format::<150>()? + .as_str(), + ) + .build()]) + .build(), + ) .await?; - mqtt.publish(Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: Topic::RegisterThing(template_name, payload_format) - .format::<69>()? - .as_str(), - payload, - properties: embedded_mqtt::Properties::Slice(&[]), - }) + mqtt.publish( + Publish::builder() + .topic_name( + Topic::RegisterThing(template_name, payload_format) + .format::<69>()? + .as_str(), + ) + .payload(payload) + .build(), + ) .await?; drop(message); @@ -254,8 +264,11 @@ impl FleetProvisioner { match Topic::from_str(message.topic_name()) { Some(Topic::RegisterThingAccepted(_, format)) => { - let response = - Self::deserialize::, SUBS>(format, &mut message)?; + let response = Self::deserialize::< + RegisterThingResponse<'_, C>, + SliceBufferProvider<'a>, + SUBS, + >(format, &mut message)?; Ok(response.device_configuration) } @@ -277,29 +290,32 @@ impl FleetProvisioner { mqtt: &'b embedded_mqtt::MqttClient<'a, M, SUBS>, csr: Option<&str>, payload_format: PayloadFormat, - ) -> Result, Error> { + ) -> Result, Error> + where + BitsImpl<{ SUBS }>: Bits, + { if let Some(csr) = csr { let subscription = mqtt - .subscribe(Subscribe::new(&[ - SubscribeTopic { - topic_path: Topic::CreateCertificateFromCsrRejected(payload_format) - .format::<47>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - SubscribeTopic { - topic_path: Topic::CreateCertificateFromCsrAccepted(payload_format) - .format::<47>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - ])) + .subscribe( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + Topic::CreateCertificateFromCsrRejected(payload_format) + .format::<47>()? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + Topic::CreateCertificateFromCsrAccepted(payload_format) + .format::<47>()? + .as_str(), + ) + .build(), + ]) + .build(), + ) .await?; let request = CreateCertificateFromCsrRequest { @@ -326,65 +342,66 @@ impl FleetProvisioner { csr.len() + 32, ); - mqtt.publish(Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: Topic::CreateCertificateFromCsr(payload_format) - .format::<40>()? - .as_str(), - payload, - properties: embedded_mqtt::Properties::Slice(&[]), - }) + mqtt.publish( + Publish::builder() + .topic_name( + Topic::CreateCertificateFromCsr(payload_format) + .format::<40>()? + .as_str(), + ) + .payload(payload) + .build(), + ) .await?; Ok(subscription) } else { let subscription = mqtt - .subscribe(Subscribe::new(&[ - SubscribeTopic { - topic_path: Topic::CreateKeysAndCertificateAccepted(payload_format) - .format::<38>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - SubscribeTopic { - topic_path: Topic::CreateKeysAndCertificateRejected(payload_format) - .format::<38>()? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - ])) + .subscribe( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + Topic::CreateKeysAndCertificateAccepted(payload_format) + .format::<38>()? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + Topic::CreateKeysAndCertificateRejected(payload_format) + .format::<38>()? + .as_str(), + ) + .build(), + ]) + .build(), + ) .await?; - mqtt.publish(Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - pid: None, - topic_name: Topic::CreateKeysAndCertificate(payload_format) - .format::<29>()? - .as_str(), - payload: b"", - properties: embedded_mqtt::Properties::Slice(&[]), - }) + mqtt.publish( + Publish::builder() + .topic_name( + Topic::CreateKeysAndCertificate(payload_format) + .format::<29>()? + .as_str(), + ) + .payload(b"") + .build(), + ) .await?; Ok(subscription) } } - fn deserialize<'a, R: Deserialize<'a>, const SUBS: usize>( + fn deserialize<'a, R: Deserialize<'a>, B: BufferProvider, const SUBS: usize>( payload_format: PayloadFormat, - message: &'a mut Message<'_, SUBS>, - ) -> Result { + message: &'a mut Message<'_, B, SUBS>, + ) -> Result + where + BitsImpl<{ SUBS }>: Bits, + { trace!( "Accepted Topic {:?}. Payload len: {:?}", payload_format, @@ -398,10 +415,13 @@ impl FleetProvisioner { }) } - fn handle_error( + fn handle_error( format: PayloadFormat, - mut message: Message<'_, SUBS>, - ) -> Result<(), Error> { + mut message: Message<'_, B, SUBS>, + ) -> Result<(), Error> + where + BitsImpl<{ SUBS }>: Bits, + { error!(">> {:?}", message.topic_name()); let response = match format { diff --git a/src/shadows/README.md b/src/shadows/README.md index a1ec0b0..9ea3132 100644 --- a/src/shadows/README.md +++ b/src/shadows/README.md @@ -8,4 +8,4 @@ You can find an example of how to use this crate for iot shadow states in the `t pfx identity files can be created from a set of device certificate and private key using OpenSSL as: `openssl pkcs12 -export -out identity.pfx -inkey private.pem.key -in certificate.pem.crt -certfile root-ca.pem` -The example functions as a CI integration test, that is run against `Blackbirds` integration account on every PR. This test will run through a statemachine of shadow delete, updates and gets from both device & cloud side with assertions in between. +The example functions as a CI integration test, that is run against Factbirds integration account on every PR. This test will run through a statemachine of shadow delete, updates and gets from both device & cloud side with assertions in between. diff --git a/src/shadows/dao.rs b/src/shadows/dao.rs index 875c5c2..1435cd6 100644 --- a/src/shadows/dao.rs +++ b/src/shadows/dao.rs @@ -2,53 +2,23 @@ use serde::{de::DeserializeOwned, Serialize}; use super::{Error, ShadowState}; -pub trait ShadowDAO { - fn read(&mut self) -> Result; - fn write(&mut self, state: &S) -> Result<(), Error>; +pub trait ShadowDAO { + async fn read(&mut self) -> Result; + async fn write(&mut self, state: &S) -> Result<(), Error>; } -impl ShadowDAO for () { - fn read(&mut self) -> Result { - Err(Error::NoPersistance) - } - - fn write(&mut self, _state: &S) -> Result<(), Error> { - Err(Error::NoPersistance) - } -} - -pub struct EmbeddedStorageDAO(T); - -impl From for EmbeddedStorageDAO -where - T: embedded_storage::Storage, -{ - fn from(v: T) -> Self { - Self::new(v) - } -} - -impl EmbeddedStorageDAO -where - T: embedded_storage::Storage, -{ - pub fn new(storage: T) -> Self { - Self(storage) - } -} - -const U32_SIZE: usize = core::mem::size_of::(); +const U32_SIZE: usize = 4; -impl ShadowDAO for EmbeddedStorageDAO +impl ShadowDAO for T where S: ShadowState + DeserializeOwned, - T: embedded_storage::Storage, + T: embedded_storage_async::nor_flash::NorFlash, [(); S::MAX_PAYLOAD_SIZE + U32_SIZE]:, { - fn read(&mut self) -> Result { + async fn read(&mut self) -> Result { let buf = &mut [0u8; S::MAX_PAYLOAD_SIZE + U32_SIZE]; - self.0.read(OFFSET, buf).map_err(|_| Error::DaoRead)?; + self.read(0, buf).await.map_err(|_| Error::DaoRead)?; match buf[..U32_SIZE].try_into() { Ok(len_bytes) => { @@ -68,8 +38,8 @@ where } } - fn write(&mut self, state: &S) -> Result<(), Error> { - assert!(S::MAX_PAYLOAD_SIZE <= self.0.capacity() - OFFSET as usize); + async fn write(&mut self, state: &S) -> Result<(), Error> { + assert!(S::MAX_PAYLOAD_SIZE <= self.capacity()); let buf = &mut [0u8; S::MAX_PAYLOAD_SIZE + U32_SIZE]; @@ -88,11 +58,11 @@ where buf[..U32_SIZE].copy_from_slice(&(len as u32).to_le_bytes()); - self.0 - .write(OFFSET, &buf[..len + U32_SIZE]) + self.write(0, &buf[..len + U32_SIZE]) + .await .map_err(|_| Error::DaoWrite)?; - debug!("Wrote {} bytes to DAO @ {}", len + U32_SIZE, OFFSET); + debug!("Wrote {} bytes to DAO", len + U32_SIZE); Ok(()) } @@ -128,7 +98,7 @@ where T: std::io::Write + std::io::Read, [(); S::MAX_PAYLOAD_SIZE]:, { - fn read(&mut self) -> Result { + async fn read(&mut self) -> Result { let bytes = &mut [0u8; S::MAX_PAYLOAD_SIZE]; self.0.read(bytes).map_err(|_| Error::DaoRead)?; @@ -136,7 +106,7 @@ where Ok(shadow) } - fn write(&mut self, state: &S) -> Result<(), Error> { + async fn write(&mut self, state: &S) -> Result<(), Error> { let bytes = serde_json_core::to_vec::<_, { S::MAX_PAYLOAD_SIZE }>(state) .map_err(|_| Error::Overflow)?; diff --git a/src/shadows/data_types.rs b/src/shadows/data_types.rs index 7a25453..5449726 100644 --- a/src/shadows/data_types.rs +++ b/src/shadows/data_types.rs @@ -34,16 +34,22 @@ impl From for Patch { #[derive(Debug, Serialize, Deserialize)] pub struct State { + #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "desired")] pub desired: Option, + + #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "reported")] pub reported: Option, } #[derive(Debug, Serialize, Deserialize)] pub struct DeltaState { + #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "desired")] pub desired: Option, + + #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "reported")] pub reported: Option, #[serde(rename = "delta")] @@ -172,7 +178,7 @@ mod tests { exp_map .0 .insert( - heapless::String::from("1"), + heapless::String::try_from("1").unwrap(), Patch::Set(Test { field: true }), ) .unwrap(); @@ -189,7 +195,7 @@ mod tests { exp_map .0 .insert( - heapless::String::from("1"), + heapless::String::try_from("1").unwrap(), Patch::Set(Test { field: true }), ) .unwrap(); @@ -215,7 +221,7 @@ mod tests { let mut exp_map = TestMap(heapless::LinearMap::default()); exp_map .0 - .insert(heapless::String::from("1"), Patch::Unset) + .insert(heapless::String::try_from("1").unwrap(), Patch::Unset) .unwrap(); let (patch, _) = serde_json_core::from_str::(payload).unwrap(); diff --git a/src/shadows/error.rs b/src/shadows/error.rs index f08ff4d..54bd0b1 100644 --- a/src/shadows/error.rs +++ b/src/shadows/error.rs @@ -1,9 +1,4 @@ use core::convert::TryFrom; -use core::fmt::Display; -use core::str::FromStr; - -use heapless::String; -use mqttrust::MqttError; use super::data_types::ErrorResponse; @@ -11,21 +6,15 @@ use super::data_types::ErrorResponse; #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Error { Overflow, - NoPersistance, + NoPersistence, DaoRead, DaoWrite, InvalidPayload, WrongShadowName, - Mqtt(MqttError), + MqttError(embedded_mqtt::Error), ShadowError(ShadowError), } -impl From for Error { - fn from(e: MqttError) -> Self { - Self::Mqtt(e) - } -} - impl From for Error { fn from(e: ShadowError) -> Self { Self::ShadowError(e) @@ -47,7 +36,6 @@ pub enum ShadowError { Unauthorized, Forbidden, NotFound, - NoNamedShadow(String<64>), VersionConflict, PayloadTooLarge, UnsupportedEncoding, @@ -70,7 +58,7 @@ impl ShadowError { ShadowError::Unauthorized => 401, ShadowError::Forbidden => 403, - ShadowError::NotFound | ShadowError::NoNamedShadow(_) => 404, + ShadowError::NotFound => 404, ShadowError::VersionConflict => 409, ShadowError::PayloadTooLarge => 413, ShadowError::UnsupportedEncoding => 415, @@ -85,7 +73,7 @@ impl<'a> TryFrom> for ShadowError { fn try_from(e: ErrorResponse<'a>) -> Result { Ok(match e.code { - 400 | 404 => Self::from_str(e.message)?, + 400 | 404 => ShadowError::NotFound, 401 => ShadowError::Unauthorized, 403 => ShadowError::Forbidden, 409 => ShadowError::VersionConflict, @@ -97,67 +85,3 @@ impl<'a> TryFrom> for ShadowError { }) } } - -// impl Display for ShadowError { -// fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { -// match self { -// Self::InvalidJson => write!(f, "Invalid JSON"), -// Self::MissingState => write!(f, "Missing required node: state"), -// Self::MalformedState => write!(f, "State node must be an object"), -// Self::MalformedDesired => write!(f, "Desired node must be an object"), -// Self::MalformedReported => write!(f, "Reported node must be an object"), -// Self::InvalidVersion => write!(f, "Invalid version"), -// Self::InvalidClientToken => write!(f, "Invalid clientToken"), -// Self::JsonTooDeep => { -// write!(f, "JSON contains too many levels of nesting; maximum is 6") -// } -// Self::InvalidStateNode => write!(f, "State contains an invalid node"), -// Self::Unauthorized => write!(f, "Unauthorized"), -// Self::Forbidden => write!(f, "Forbidden"), -// Self::NotFound => write!(f, "Thing not found"), -// Self::NoNamedShadow(shadow_name) => { -// write!(f, "No shadow exists with name: {}", shadow_name) -// } -// Self::VersionConflict => write!(f, "Version conflict"), -// Self::PayloadTooLarge => write!(f, "The payload exceeds the maximum size allowed"), -// Self::UnsupportedEncoding => write!( -// f, -// "Unsupported documented encoding; supported encoding is UTF-8" -// ), -// Self::TooManyRequests => write!(f, "The Device Shadow service will generate this error message when there are more than 10 in-flight requests on a single connection"), -// Self::InternalServerError => write!(f, "Internal service failure"), -// } -// } -// } - -// // TODO: This seems like an extremely brittle way of doing this??! -// impl FromStr for ShadowError { -// type Err = (); - -// fn from_str(s: &str) -> Result { -// Ok(match s.trim() { -// "Invalid JSON" => Self::InvalidJson, -// "Missing required node: state" => Self::MissingState, -// "State node must be an object" => Self::MalformedState, -// "Desired node must be an object" => Self::MalformedDesired, -// "Reported node must be an object" => Self::MalformedReported, -// "Invalid version" => Self::InvalidVersion, -// "Invalid clientToken" => Self::InvalidClientToken, -// "JSON contains too many levels of nesting; maximum is 6" => Self::JsonTooDeep, -// "State contains an invalid node" => Self::InvalidStateNode, -// "Unauthorized" => Self::Unauthorized, -// "Forbidden" => Self::Forbidden, -// "Thing not found" => Self::NotFound, -// // TODO: -// "No shadow exists with name: " => Self::NoNamedShadow(String::new()), -// "Version conflict" => Self::VersionConflict, -// "The payload exceeds the maximum size allowed" => Self::PayloadTooLarge, -// "Unsupported documented encoding; supported encoding is UTF-8" => { -// Self::UnsupportedEncoding -// } -// "The Device Shadow service will generate this error message when there are more than 10 in-flight requests on a single connection" => Self::TooManyRequests, -// "Internal service failure" => Self::InternalServerError, -// _ => return Err(()), -// }) -// } -// } diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index 31ce599..2825fd5 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -4,18 +4,23 @@ mod error; mod shadow_diff; pub mod topics; -use core::marker::PhantomData; - -use mqttrust::{Mqtt, QoS}; +use core::{marker::PhantomData, ops::DerefMut}; +use bitmaps::{Bits, BitsImpl}; pub use data_types::Patch; +use embassy_sync::{ + blocking_mutex::raw::{NoopRawMutex, RawMutex}, + mutex::Mutex, +}; +use embedded_mqtt::{DeferredPayload, Publish, Subscribe, SubscribeTopic, ToPayload}; pub use error::Error; -use serde::de::DeserializeOwned; +use futures::StreamExt; +use serde::Serialize; pub use shadow_derive as derive; pub use shadow_diff::ShadowPatch; -use data_types::{AcceptedResponse, DeltaResponse, ErrorResponse}; -use topics::{Direction, Subscribe, Topic, Unsubscribe}; +use data_types::{AcceptedResponse, DeltaResponse, DeltaState, ErrorResponse}; +use topics::Topic; use self::dao::ShadowDAO; @@ -23,315 +28,441 @@ const MAX_TOPIC_LEN: usize = 128; const PARTIAL_REQUEST_OVERHEAD: usize = 64; const CLASSIC_SHADOW: &str = "Classic"; -pub trait ShadowState: ShadowPatch { +pub trait ShadowState: ShadowPatch + Default { const NAME: Option<&'static str>; const MAX_PAYLOAD_SIZE: usize = 512; } -struct ShadowHandler<'a, M: Mqtt, S: ShadowState> +struct ShadowHandler<'a, 'm, M: RawMutex, S: ShadowState, const SUBS: usize> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { - mqtt: &'a M, + mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>, + subscription: Mutex>>, _shadow: PhantomData, } -impl<'a, M: Mqtt, S: ShadowState> ShadowHandler<'a, M, S> +impl<'a, 'm, M: RawMutex, S: ShadowState, const SUBS: usize> ShadowHandler<'a, 'm, M, S, SUBS> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { - /// Subscribes to all the topics required for keeping a shadow in sync - pub fn subscribe(&self) -> Result<(), Error> { - Subscribe::<7>::new() - .topic(Topic::GetAccepted, QoS::AtLeastOnce) - .topic(Topic::GetRejected, QoS::AtLeastOnce) - .topic(Topic::DeleteAccepted, QoS::AtLeastOnce) - .topic(Topic::DeleteRejected, QoS::AtLeastOnce) - .topic(Topic::UpdateAccepted, QoS::AtLeastOnce) - .topic(Topic::UpdateRejected, QoS::AtLeastOnce) - .topic(Topic::UpdateDelta, QoS::AtLeastOnce) - .send(self.mqtt, S::NAME)?; + async fn handle_delta(&self) -> Result, Error> { + let mut sub_ref = self.subscription.lock().await; + + let delta_subscription = match sub_ref.deref_mut() { + Some(sub) => sub, + None => { + self.mqtt.wait_connected().await; + + let sub = self + .mqtt + .subscribe::<2>( + Subscribe::builder() + .topics(&[SubscribeTopic::builder() + .topic_path( + topics::Topic::UpdateDelta + .format::<64>(self.mqtt.client_id(), S::NAME)? + .as_str(), + ) + .build()]) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + sub_ref.insert(sub) + } + }; - Ok(()) - } + let delta_message = delta_subscription + .next() + .await + .ok_or(Error::InvalidPayload)?; - /// Unsubscribes from all the topics required for keeping a shadow in sync - pub fn unsubscribe(&self) -> Result<(), Error> { - Unsubscribe::<7>::new() - .topic(Topic::GetAccepted) - .topic(Topic::GetRejected) - .topic(Topic::DeleteAccepted) - .topic(Topic::DeleteRejected) - .topic(Topic::UpdateAccepted) - .topic(Topic::UpdateRejected) - .topic(Topic::UpdateDelta) - .send(self.mqtt, S::NAME)?; + // Update the device's state to match the desired state in the + // message body. + debug!( + "[{:?}] Received shadow delta event.", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); - Ok(()) - } + let (delta, _) = + serde_json_core::from_slice::>(delta_message.payload()) + .map_err(|_| Error::InvalidPayload)?; - /// Helper function to check whether a topic name is relevant for this - /// particular shadow. - pub fn should_handle_topic(&mut self, topic: &str) -> bool { - if let Some((_, thing_name, shadow_name)) = Topic::from_str(topic) { - return thing_name == self.mqtt.client_id() && shadow_name == S::NAME; + if let Some(client) = delta.client_token { + if client.eq(self.mqtt.client_id()) { + return Ok(None); + } } - false + + Ok(delta.state) } /// Internal helper function for applying a delta state to the actual shadow /// state, and update the cloud shadow. - fn change_shadow_value( - &mut self, - state: &mut S, - delta: Option, - update_desired: Option, - ) -> Result<(), Error> { - if let Some(ref delta) = delta { - state.apply_patch(delta.clone()); - } - + async fn report(&self, reported: &R) -> Result<(), Error> { debug!( - "[{:?}] Updating reported shadow value. Update_desired: {:?}", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - update_desired + "[{:?}] Updating reported shadow value.", + S::NAME.unwrap_or(CLASSIC_SHADOW), ); - if let Some(update_desired) = update_desired { - let desired = if update_desired { Some(&state) } else { None }; - - let request = data_types::Request { - state: data_types::State { - reported: Some(&state), - desired, - }, - client_token: None, - version: None, - }; - - let payload = serde_json_core::to_vec::< - _, - { S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD }, - >(&request) - .map_err(|_| Error::Overflow)?; - - let update_topic = - Topic::Update.format::(self.mqtt.client_id(), S::NAME)?; - self.mqtt - .publish(update_topic.as_str(), &payload, QoS::AtLeastOnce)?; - } + let request = data_types::Request { + state: data_types::State { + reported: Some(reported), + desired: None, + }, + client_token: Some(self.mqtt.client_id()), + version: None, + }; - Ok(()) + let payload = DeferredPayload::new( + |buf| { + serde_json_core::to_slice(&request, buf) + .map_err(|_| embedded_mqtt::EncodingError::BufferSize) + }, + S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD, + ); + + let mut sub = self.publish_and_subscribe(Topic::Update, payload).await?; + + //*** WAIT RESPONSE ***/ + debug!("Wait for Accepted or Rejected"); + loop { + let message = sub.next().await.ok_or(Error::InvalidPayload)?; + + // Check if topic is GetAccepted + match Topic::from_str(message.topic_name()) { + Some((Topic::UpdateAccepted, _, _)) => { + // Check client token + let (response, _) = serde_json_core::from_slice::< + AcceptedResponse, + >(message.payload()) + .map_err(|_| Error::InvalidPayload)?; + + if response.client_token != Some(self.mqtt.client_id()) { + error!( + "Unexpected client token received: {}, expected: {}", + response.client_token.unwrap_or("None"), + self.mqtt.client_id() + ); + continue; + } + + return Ok(()); + } + Some((Topic::UpdateRejected, _, _)) => { + let (error_response, _) = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?; + + if error_response.client_token != Some(self.mqtt.client_id()) { + continue; + } + + return Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )); + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + return Err(Error::WrongShadowName); + } + } + } } /// Initiate a `GetShadow` request, updating the local state from the cloud. - pub fn get_shadow(&self) -> Result<(), Error> { - let get_topic = Topic::Get.format::(self.mqtt.client_id(), S::NAME)?; - self.mqtt - .publish(get_topic.as_str(), b"", QoS::AtLeastOnce)?; - Ok(()) + async fn get_shadow(&self) -> Result, Error> { + //Wait for mqtt to connect + self.mqtt.wait_connected().await; + + let mut sub = self.publish_and_subscribe(Topic::Get, b"").await?; + + let get_message = sub.next().await.ok_or(Error::InvalidPayload)?; + + //Check if topic is GetAccepted + //Deserialize message + //Persist shadow and return new shadow + match Topic::from_str(get_message.topic_name()) { + Some((Topic::GetAccepted, _, _)) => { + let (response, _) = serde_json_core::from_slice::>( + get_message.payload(), + ) + .map_err(|_| Error::InvalidPayload)?; + + Ok(response.state) + } + Some((Topic::GetRejected, _, _)) => { + let (error_response, _) = + serde_json_core::from_slice::(get_message.payload()) + .map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?; + + if error_response.code == 404 { + debug!( + "[{:?}] Thing has no shadow document. Creating with defaults...", + S::NAME.unwrap_or(CLASSIC_SHADOW) + ); + self.create_shadow().await?; + } + + Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )) + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + Err(Error::WrongShadowName) + } + } + } + + pub async fn delete_shadow(&mut self) -> Result<(), Error> { + // Wait for mqtt to connect + self.mqtt.wait_connected().await; + + let mut sub = self + .publish_and_subscribe(topics::Topic::Delete, b"") + .await?; + + let message = sub.next().await.ok_or(Error::InvalidPayload)?; + + // Check if topic is DeleteAccepted + match Topic::from_str(message.topic_name()) { + Some((Topic::DeleteAccepted, _, _)) => Ok(()), + Some((Topic::DeleteRejected, _, _)) => { + let (error_response, _) = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?; + + Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )) + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + Err(Error::WrongShadowName) + } + } + } + + pub async fn create_shadow(&self) -> Result, Error> { + debug!( + "[{:?}] Creating initial shadow value.", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); + + let state = S::default(); + + let request = data_types::Request { + state: data_types::State { + reported: Some(&state), + desired: Some(&state), + }, + client_token: Some(self.mqtt.client_id()), + version: None, + }; + + // FIXME: Serialize directly into the publish payload through `DeferredPublish` API + let payload = serde_json_core::to_vec::< + _, + { S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD }, + >(&request) + .map_err(|_| Error::Overflow)?; + + let mut sub = self + .publish_and_subscribe(Topic::Update, payload.as_slice()) + .await?; + loop { + let message = sub.next().await.ok_or(Error::InvalidPayload)?; + + match Topic::from_str(message.topic_name()) { + Some((Topic::UpdateAccepted, _, _)) => { + let (response, _) = serde_json_core::from_slice::< + AcceptedResponse, + >(message.payload()) + .map_err(|_| Error::InvalidPayload)?; + + if response.client_token != Some(self.mqtt.client_id()) { + continue; + } + + return Ok(response.state); + } + Some((Topic::UpdateRejected, _, _)) => { + let (error_response, _) = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?; + + if error_response.client_token != Some(self.mqtt.client_id()) { + continue; + } + + return Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )); + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + return Err(Error::WrongShadowName); + } + } + } } - pub fn delete_shadow(&mut self) -> Result<(), Error> { - let delete_topic = Topic::Delete.format::(self.mqtt.client_id(), S::NAME)?; + ///This function will subscribe to accepted and rejected topics and then do a publish. + ///It will only return when something is accepted or rejected + ///Topic is the topic you want to publish to + ///The function will automatically subscribe to the accepted and rejected topic related to the publish topic + async fn publish_and_subscribe( + &self, + topic: topics::Topic, + payload: impl ToPayload, + ) -> Result, Error> { + let (accepted, rejected) = match topic { + Topic::Get => (Topic::GetAccepted, Topic::GetRejected), + Topic::Update => (Topic::UpdateAccepted, Topic::UpdateRejected), + Topic::Delete => (Topic::DeleteAccepted, Topic::DeleteRejected), + _ => return Err(Error::ShadowError(error::ShadowError::Forbidden)), + }; + + //*** SUBSCRIBE ***/ + let sub = self + .mqtt + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + accepted + .format::<64>(self.mqtt.client_id(), S::NAME)? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + rejected + .format::<64>(self.mqtt.client_id(), S::NAME)? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + //*** PUBLISH REQUEST ***/ + let topic_name = topic.format::(self.mqtt.client_id(), S::NAME)?; self.mqtt - .publish(delete_topic.as_str(), b"", QoS::AtLeastOnce)?; - Ok(()) + .publish( + Publish::builder() + .topic_name(topic_name.as_str()) + .payload(payload) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + Ok(sub) } } -pub struct PersistedShadow<'a, S: ShadowState + DeserializeOwned, M: Mqtt, D: ShadowDAO> +pub struct PersistedShadow<'a, 'm, S: ShadowState, M: RawMutex, D: ShadowDAO, const SUBS: usize> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { - handler: ShadowHandler<'a, M, S>, - pub(crate) dao: D, + handler: ShadowHandler<'a, 'm, M, S, SUBS>, + pub(crate) dao: Mutex, } -impl<'a, S, M, D> PersistedShadow<'a, S, M, D> +impl<'a, 'm, S, M, D, const SUBS: usize> PersistedShadow<'a, 'm, S, M, D, SUBS> where - S: ShadowState + DeserializeOwned, - M: Mqtt, + BitsImpl<{ SUBS }>: Bits, + S: ShadowState + Default, + M: RawMutex, D: ShadowDAO, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { /// Instantiate a new shadow that will be automatically persisted to NVM /// based on the passed `DAO`. - pub fn new( - initial_state: S, - mqtt: &'a M, - mut dao: D, - auto_subscribe: bool, - ) -> Result { - if dao.read().is_err() { - dao.write(&initial_state)?; - } - + pub fn new(mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>, dao: D) -> Self { let handler = ShadowHandler { mqtt, + subscription: Mutex::new(None), _shadow: PhantomData, }; - if auto_subscribe { - handler.subscribe()?; - } - Ok(Self { handler, dao }) - } - - /// Subscribes to all the topics required for keeping a shadow in sync - pub fn subscribe(&self) -> Result<(), Error> { - self.handler.subscribe() - } - /// Unsubscribes from all the topics required for keeping a shadow in sync - pub fn unsubscribe(&self) -> Result<(), Error> { - self.handler.unsubscribe() - } - - /// Helper function to check whether a topic name is relevant for this - /// particular shadow. - pub fn should_handle_topic(&mut self, topic: &str) -> bool { - self.handler.should_handle_topic(topic) + Self { + handler, + dao: Mutex::new(dao), + } } - /// Handle incomming publish messages from the cloud on any topics relevant - /// for this particular shadow. + /// Wait delta will subscribe if not already to Updatedelta and wait for changes /// - /// This function needs to be fed all relevant incoming MQTT payloads in - /// order for the shadow manager to work. - #[must_use] - pub fn handle_message( - &mut self, - topic: &str, - payload: &[u8], - ) -> Result<(S, Option), Error> { - let (topic, thing_name, shadow_name) = - Topic::from_str(topic).ok_or(Error::WrongShadowName)?; - - assert_eq!(thing_name, self.handler.mqtt.client_id()); - assert_eq!(topic.direction(), Direction::Incoming); - - if shadow_name != S::NAME { - return Err(Error::WrongShadowName); - } - - let mut state = self.dao.read()?; - - let delta = match topic { - Topic::GetAccepted => { - // The actions necessary to process the state document in the - // message body. - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(response, _)| { - if let Some(_) = response.state.delta { - debug!( - "[{:?}] Received delta state", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.handler.change_shadow_value( - &mut state, - response.state.delta.clone(), - Some(false), - )?; - } else if let Some(_) = response.state.reported { - self.handler.change_shadow_value( - &mut state, - response.state.reported, - None, - )?; - } - Ok(response.state.delta) - })? + pub async fn wait_delta(&self) -> Result<(S, Option), Error> { + let mut state = match self.dao.lock().await.read().await { + Ok(state) => state, + Err(_) => { + error!("Could not read state from flash writing default"); + self.dao.lock().await.write(&S::default()).await?; + S::default() } - Topic::GetRejected | Topic::UpdateRejected => { - // Respond to the error message in the message body. - if let Ok((error, _)) = serde_json_core::from_slice::(payload) { - if error.code == 404 && matches!(topic, Topic::GetRejected) { - debug!( - "[{:?}] Thing has no shadow document. Creating with defaults...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.report_shadow()?; - } else { - error!( - "{:?} request was rejected. code: {:?} message:'{:?}'", - if matches!(topic, Topic::GetRejected) { - "Get" - } else { - "Update" - }, - error.code, - error.message - ); - } - } - None - } - Topic::UpdateDelta => { - // Update the device's state to match the desired state in the - // message body. - debug!( - "[{:?}] Received shadow delta event.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(delta, _)| { - if let Some(_) = delta.state { - debug!( - "[{:?}] Delta reports new desired value. Changing local value...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - } - self.handler.change_shadow_value( - &mut state, - delta.state.clone(), - Some(false), - )?; - Ok(delta.state) - })? - } - Topic::UpdateAccepted => { - // Confirm the updated data in the message body matches the - // device state. - - debug!( - "[{:?}] Finished updating reported shadow value.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - - None - } - _ => None, }; + let delta = self.handler.handle_delta().await?; + // Something has changed as part of handling a message. Persist it // to NVM storage. - if delta.is_some() { - self.dao.write(&state)?; + if let Some(delta) = &delta { + debug!( + "[{:?}] Delta reports new desired value. Changing local value...", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); + + state.apply_patch(delta.clone()); + + self.handler.report(&state).await?; + + self.dao.lock().await.write(&state).await?; } Ok((state, delta)) } /// Get an immutable reference to the internal local state. - pub fn try_get(&mut self) -> Result { - self.dao.read() + pub async fn try_get(&mut self) -> Result { + self.dao.lock().await.read().await } /// Initiate a `GetShadow` request, updating the local state from the cloud. - pub fn get_shadow(&self) -> Result<(), Error> { - self.handler.get_shadow() - } + pub async fn get_shadow(&self) -> Result { + let delta_state = self.handler.get_shadow().await?; + + debug!("Persisting new state after get shadow request"); + let mut state = self.dao.lock().await.read().await.unwrap_or_default(); + if let Some(desired) = delta_state.desired { + state.apply_patch(desired); + self.dao.lock().await.write(&state).await?; + if delta_state.delta.is_some() { + self.handler.report(&state).await?; + } + } - /// Initiate an `UpdateShadow` request, reporting the local state to the cloud. - pub fn report_shadow(&mut self) -> Result<(), Error> { - let mut state = self.dao.read()?; - self.handler - .change_shadow_value(&mut state, None, Some(false))?; - Ok(()) + Ok(state) } /// Update the state of the shadow. @@ -340,179 +471,75 @@ where /// and depending on whether the state update is rejected or accepted, it /// will automatically update the local version after response /// - /// The returned `bool` from the update closure will determine wether the + /// The returned `bool` from the update closure will determine whether the /// update is persisted using the `DAO`, or just updated in the cloud. This /// can be handy for activity or status field updates that are not relevant - /// to store persistant on the device, but are required to be part of the + /// to store persistent on the device, but are required to be part of the /// same cloud shadow. - pub fn update bool>(&mut self, f: F) -> Result<(), Error> { + pub async fn update(&self, f: F) -> Result<(), Error> { let mut desired = S::PatchState::default(); - let mut state = self.dao.read()?; - let should_persist = f(&state, &mut desired); + let mut state = self.dao.lock().await.read().await?; + f(&state, &mut desired); - self.handler - .change_shadow_value(&mut state, Some(desired), Some(false))?; + self.handler.report(&desired).await?; - if should_persist { - self.dao.write(&state)?; - } + state.apply_patch(desired); + + // Always persist + self.dao.lock().await.write(&state).await?; Ok(()) } - pub fn delete_shadow(&mut self) -> Result<(), Error> { - self.handler.delete_shadow() + pub async fn delete_shadow(&mut self) -> Result<(), Error> { + self.handler.delete_shadow().await?; + self.dao.lock().await.write(&S::default()).await?; + Ok(()) } } -pub struct Shadow<'a, S: ShadowState, M: Mqtt> +pub struct Shadow<'a, 'm, S: ShadowState, M: RawMutex, const SUBS: usize> where + BitsImpl<{ SUBS }>: Bits, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { state: S, - handler: ShadowHandler<'a, M, S>, + handler: ShadowHandler<'a, 'm, M, S, SUBS>, } -impl<'a, S, M> Shadow<'a, S, M> +impl<'a, 'm, S, M, const SUBS: usize> Shadow<'a, 'm, S, M, SUBS> where + BitsImpl<{ SUBS }>: Bits, S: ShadowState, - M: Mqtt, + M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { /// Instantiate a new non-persisted shadow - pub fn new(state: S, mqtt: &'a M, auto_subscribe: bool) -> Result { + pub fn new(state: S, mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>) -> Self { let handler = ShadowHandler { mqtt, + subscription: Mutex::new(None), _shadow: PhantomData, }; - if auto_subscribe { - handler.subscribe()?; - } - Ok(Self { handler, state }) - } - - /// Subscribes to all the topics required for keeping a shadow in sync - pub fn subscribe(&self) -> Result<(), Error> { - self.handler.subscribe() - } - - /// Unsubscribes from all the topics required for keeping a shadow in sync - pub fn unsubscribe(&self) -> Result<(), Error> { - self.handler.unsubscribe() + Self { handler, state } } - /// Handle incomming publish messages from the cloud on any topics relevant + /// Handle incoming publish messages from the cloud on any topics relevant /// for this particular shadow. /// /// This function needs to be fed all relevant incoming MQTT payloads in /// order for the shadow manager to work. - #[must_use] - pub fn handle_message( - &mut self, - topic: &str, - payload: &[u8], - ) -> Result<(&S, Option), Error> { - let (topic, thing_name, shadow_name) = - Topic::from_str(topic).ok_or(Error::WrongShadowName)?; - - assert_eq!(thing_name, self.handler.mqtt.client_id()); - assert_eq!(topic.direction(), Direction::Incoming); - - if shadow_name != S::NAME { - return Err(Error::WrongShadowName); + pub async fn wait_delta(&mut self) -> Result<(&S, Option), Error> { + let delta = self.handler.handle_delta().await?; + if let Some(delta) = &delta { + debug!( + "[{:?}] Delta reports new desired value. Changing local value...", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); + self.handler.report(delta).await?; } - let delta = match topic { - Topic::GetAccepted => { - // The actions necessary to process the state document in the - // message body. - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(response, _)| { - if let Some(_) = response.state.delta { - debug!( - "[{:?}] Received delta state", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.handler.change_shadow_value( - &mut self.state, - response.state.delta.clone(), - Some(false), - )?; - } else if let Some(_) = response.state.reported { - self.handler.change_shadow_value( - &mut self.state, - response.state.reported, - None, - )?; - } - Ok(response.state.delta) - })? - } - Topic::GetRejected | Topic::UpdateRejected => { - // Respond to the error message in the message body. - if let Ok((error, _)) = serde_json_core::from_slice::(payload) { - if error.code == 404 && matches!(topic, Topic::GetRejected) { - debug!( - "[{:?}] Thing has no shadow document. Creating with defaults...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.report_shadow()?; - } else { - error!( - "{:?} request was rejected. code: {:?} message:'{:?}'", - if matches!(topic, Topic::GetRejected) { - "Get" - } else { - "Update" - }, - error.code, - error.message - ); - } - } - None - } - Topic::UpdateDelta => { - // Update the device's state to match the desired state in the - // message body. - debug!( - "[{:?}] Received shadow delta event.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(delta, _)| { - if let Some(_) = delta.state { - debug!( - "[{:?}] Delta reports new desired value. Changing local value...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - } - self.handler.change_shadow_value( - &mut self.state, - delta.state.clone(), - Some(false), - )?; - Ok(delta.state) - })? - } - Topic::UpdateAccepted => { - // Confirm the updated data in the message body matches the - // device state. - - debug!( - "[{:?}] Finished updating reported shadow value.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - - None - } - _ => None, - }; - - Ok((self.get(), delta)) + Ok((&self.state, delta)) } /// Get an immutable reference to the internal local state. @@ -520,10 +547,9 @@ where &self.state } - /// Initiate an `UpdateShadow` request, reporting the local state to the cloud. - pub fn report_shadow(&mut self) -> Result<(), Error> { - self.handler - .change_shadow_value(&mut self.state, None, Some(false))?; + /// Report the state of the shadow. + pub async fn report(&mut self) -> Result<(), Error> { + self.handler.report(&self.state).await?; Ok(()) } @@ -532,47 +558,59 @@ where /// This function will update the desired state of the shadow in the cloud, /// and depending on whether the state update is rejected or accepted, it /// will automatically update the local version after response - pub fn update(&mut self, f: F) -> Result<(), Error> { + pub async fn update(&mut self, f: F) -> Result<(), Error> { let mut desired = S::PatchState::default(); f(&self.state, &mut desired); - self.handler - .change_shadow_value(&mut self.state, Some(desired), Some(false))?; + self.handler.report(&desired).await?; + + self.state.apply_patch(desired); Ok(()) } /// Initiate a `GetShadow` request, updating the local state from the cloud. - pub fn get_shadow(&self) -> Result<(), Error> { - self.handler.get_shadow() + pub async fn get_shadow(&mut self) -> Result<&S, Error> { + let delta_state = self.handler.get_shadow().await?; + + if let Some(desired) = delta_state.desired { + self.state.apply_patch(desired); + if delta_state.delta.is_some() { + self.handler.report(&self.state).await?; + } + } + + Ok(&self.state) } - pub fn delete_shadow(&mut self) -> Result<(), Error> { - self.handler.delete_shadow() + pub async fn delete_shadow(&mut self) -> Result<(), Error> { + self.handler.delete_shadow().await } } -impl<'a, S, M> core::fmt::Debug for Shadow<'a, S, M> +impl<'a, 'm, S, M, const SUBS: usize> core::fmt::Debug for Shadow<'a, 'm, S, M, SUBS> where + BitsImpl<{ SUBS }>: Bits, S: ShadowState + core::fmt::Debug, - M: Mqtt, + M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!( f, "[{:?}] = {:?}", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), + S::NAME.unwrap_or(CLASSIC_SHADOW), self.get() ) } } #[cfg(feature = "defmt")] -impl<'a, S, M> defmt::Format for Shadow<'a, S, M> +impl<'a, 'm, S, M, const SUBS: usize> defmt::Format for Shadow<'a, 'm, S, M, SUBS> where + BitsImpl<{ SUBS }>: Bits, S: ShadowState + defmt::Format, - M: Mqtt, + M: RawMutex, [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { fn format(&self, fmt: defmt::Formatter) { @@ -585,17 +623,6 @@ where } } -impl<'a, S, M> Drop for Shadow<'a, S, M> -where - S: ShadowState, - M: Mqtt, - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, -{ - fn drop(&mut self) { - self.unsubscribe().ok(); - } -} - // #[cfg(test)] // mod tests { // use super::*; diff --git a/src/shadows/topics.rs b/src/shadows/topics.rs index c73e35a..34642d0 100644 --- a/src/shadows/topics.rs +++ b/src/shadows/topics.rs @@ -2,8 +2,8 @@ use core::fmt::Write; +use embedded_mqtt::QoS; use heapless::String; -use mqttrust::{Mqtt, QoS, SubscribeTopic}; use crate::jobs::MAX_THING_NAME_LEN; @@ -33,6 +33,7 @@ pub enum Topic { UpdateRejected, DeleteAccepted, DeleteRejected, + Any, } impl Topic { @@ -188,6 +189,14 @@ impl Topic { name_prefix, shadow_name )), + Self::Any => topic_path.write_fmt(format_args!( + "{}/{}/{}{}{}/#", + Self::PREFIX, + thing_name, + Self::SHADOW, + name_prefix, + shadow_name + )), } .map_err(|_| Error::Overflow)?; @@ -233,29 +242,6 @@ impl Subscribe { .map(|(topic, qos)| Ok((Topic::from(*topic).format(thing_name, shadow_name)?, *qos))) .collect() } - - pub fn send(self, mqtt: &M, shadow_name: Option<&'static str>) -> Result<(), Error> { - if self.topics.is_empty() { - return Ok(()); - } - - let topic_paths = self.topics(mqtt.client_id(), shadow_name)?; - - let topics: heapless::Vec<_, N> = topic_paths - .iter() - .map(|(s, qos)| SubscribeTopic { - topic_path: s.as_str(), - qos: *qos, - }) - .collect(); - - debug!("Subscribing!"); - - for t in topics.chunks(5) { - mqtt.subscribe(t)?; - } - Ok(()) - } } #[derive(Default)] @@ -295,19 +281,4 @@ impl Unsubscribe { .map(|topic| Topic::from(*topic).format(thing_name, shadow_name)) .collect() } - - pub fn send(self, mqtt: &M, shadow_name: Option<&'static str>) -> Result<(), Error> { - if self.topics.is_empty() { - return Ok(()); - } - - let topic_paths = self.topics(mqtt.client_id(), shadow_name)?; - let topics: heapless::Vec<_, N> = topic_paths.iter().map(|s| s.as_str()).collect(); - - for t in topics.chunks(5) { - mqtt.unsubscribe(t)?; - } - - Ok(()) - } } diff --git a/tests/common/file_handler.rs b/tests/common/file_handler.rs index 942a082..d90e995 100644 --- a/tests/common/file_handler.rs +++ b/tests/common/file_handler.rs @@ -1,14 +1,13 @@ use core::ops::Deref; -use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use embassy_sync::mutex::Mutex; +use embedded_storage_async::nor_flash::{ErrorType, NorFlash, ReadNorFlash}; use rustot::ota::{ - self, + encoding::json, pal::{OtaPal, OtaPalError, PalImageState}, }; use sha2::{Digest, Sha256}; use std::{ - fs::File, - io::{Cursor, Read, Write}, + convert::Infallible, + io::{Cursor, Write}, }; #[derive(Debug, PartialEq, Eq)] @@ -17,8 +16,44 @@ pub enum State { Boot, } +pub struct BlockFile { + filebuf: Cursor>, +} + +impl NorFlash for BlockFile { + const WRITE_SIZE: usize = 1; + + const ERASE_SIZE: usize = 1; + + async fn erase(&mut self, _from: u32, _to: u32) -> Result<(), Self::Error> { + Ok(()) + } + + async fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> { + self.filebuf.set_position(offset as u64); + self.filebuf.write_all(bytes).unwrap(); + Ok(()) + } +} + +impl ReadNorFlash for BlockFile { + const READ_SIZE: usize = 1; + + async fn read(&mut self, _offset: u32, _bytes: &mut [u8]) -> Result<(), Self::Error> { + todo!() + } + + fn capacity(&self) -> usize { + self.filebuf.get_ref().capacity() + } +} + +impl ErrorType for BlockFile { + type Error = Infallible; +} + pub struct FileHandler { - filebuf: Option>>, + filebuf: Option, compare_file_path: String, pub plateform_state: State, } @@ -34,6 +69,8 @@ impl FileHandler { } impl OtaPal for FileHandler { + type BlockWriter = BlockFile; + async fn abort( &mut self, _file: &rustot::ota::encoding::FileContext, @@ -44,9 +81,10 @@ impl OtaPal for FileHandler { async fn create_file_for_rx( &mut self, file: &rustot::ota::encoding::FileContext, - ) -> Result<(), OtaPalError> { - self.filebuf = Some(Cursor::new(Vec::with_capacity(file.filesize))); - Ok(()) + ) -> Result<&mut Self::BlockWriter, OtaPalError> { + Ok(self.filebuf.get_or_insert(BlockFile { + filebuf: Cursor::new(Vec::with_capacity(file.filesize)), + })) } async fn get_platform_image_state(&mut self) -> Result { @@ -78,12 +116,12 @@ impl OtaPal for FileHandler { if let Some(ref mut buf) = &mut self.filebuf { log::debug!( "Closing completed file. Len: {}/{} -> {}", - buf.get_ref().len(), + buf.filebuf.get_ref().len(), file.filesize, file.filepath.as_str() ); - let mut expected_data = std::fs::read(self.compare_file_path.as_str()).unwrap(); + let expected_data = std::fs::read(self.compare_file_path.as_str()).unwrap(); let mut expected_hasher = ::new(); expected_hasher.update(&expected_data); let expected_hash = expected_hasher.finalize(); @@ -93,27 +131,19 @@ impl OtaPal for FileHandler { self.compare_file_path, file.filepath.as_str() ); - assert_eq!(buf.get_ref().len(), file.filesize); + assert_eq!(buf.filebuf.get_ref().len(), file.filesize); let mut hasher = ::new(); - hasher.update(&buf.get_ref()); + hasher.update(&buf.filebuf.get_ref()); assert_eq!(hasher.finalize().deref(), expected_hash.deref()); // Check file signature - match &file.signature { - ota::encoding::json::Signature::Sha1Rsa(_) => { - panic!("Unexpected signature format: Sha1Rsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha256Rsa(_) => { - panic!("Unexpected signature format: Sha256Rsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha1Ecdsa(_) => { - panic!("Unexpected signature format: Sha1Ecdsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha256Ecdsa(sig) => { - assert_eq!(sig.as_str(), "This is my custom signature\\n") - } - } + let signature = match file.signature.as_ref() { + Some(json::Signature::Sha256Ecdsa(ref s)) => s.as_str(), + sig => panic!("Unexpected signature format! {:?}", sig), + }; + + assert_eq!(signature, "This is my custom signature\\n"); self.plateform_state = State::Swap; @@ -122,20 +152,4 @@ impl OtaPal for FileHandler { Err(OtaPalError::BadFileHandle) } } - - async fn write_block( - &mut self, - _file: &mut rustot::ota::encoding::FileContext, - block_offset: usize, - block_payload: &[u8], - ) -> Result { - if let Some(ref mut buf) = &mut self.filebuf { - buf.set_position(block_offset as u64); - buf.write(block_payload) - .map_err(|_e| OtaPalError::FileWriteFailed)?; - Ok(block_payload.len()) - } else { - Err(OtaPalError::BadFileHandle) - } - } } diff --git a/tests/common/network.rs b/tests/common/network.rs index dfbe27c..0cfe3db 100644 --- a/tests/common/network.rs +++ b/tests/common/network.rs @@ -40,7 +40,7 @@ impl Dns for Network { host: &str, addr_type: AddrType, ) -> Result { - for ip in tokio::net::lookup_host(host).await? { + for ip in tokio::net::lookup_host(format!("{}:0", host)).await? { match (&addr_type, ip) { (AddrType::IPv4 | AddrType::Either, SocketAddr::V4(ip)) => { return Ok(IpAddr::V4(Ipv4Addr::from(ip.ip().octets()))) @@ -114,7 +114,7 @@ impl Dns for TlsNetwork { addr_type: AddrType, ) -> Result { log::info!("Looking up {}", host); - for ip in tokio::net::lookup_host(host).await? { + for ip in tokio::net::lookup_host(format!("{}:0", host)).await? { log::info!("Found IP {}", ip); match (&addr_type, ip) { diff --git a/tests/ota.rs b/tests/ota.rs index ac82f18..da8c020 100644 --- a/tests/ota.rs +++ b/tests/ota.rs @@ -3,28 +3,21 @@ mod common; -use std::{net::ToSocketAddrs, process}; - use common::credentials; use common::file_handler::{FileHandler, State as FileHandlerState}; use common::network::TlsNetwork; use embassy_futures::select; -use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex}; -use embassy_time::Duration; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embedded_mqtt::transport::embedded_nal::NalTransport; -use embedded_mqtt::{ - Config, DomainBroker, IpBroker, Message, Publish, QoS, RetainHandling, State, Subscribe, - SubscribeTopic, -}; +use embedded_mqtt::{Config, DomainBroker, Message, State, Subscribe, SubscribeTopic}; use futures::StreamExt; -use serde::{Deserialize, Serialize}; -use static_cell::make_static; +use serde::Deserialize; +use static_cell::StaticCell; use rustot::{ jobs::{ self, data_types::{DescribeJobExecutionResponse, NextJobExecutionChanged}, - JobTopic, StatusDetails, }, ota::{ self, @@ -49,7 +42,7 @@ impl<'a> Jobs<'a> { } fn handle_ota<'a, const SUBS: usize>( - message: Message<'a, NoopRawMutex, SUBS>, + message: Message<'a, SUBS>, config: &ota::config::Config, ) -> Option { let job = match jobs::Topic::from_str(message.topic_name()) { @@ -99,8 +92,7 @@ async fn test_mqtt_ota() { // Create the MQTT stack let broker = DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); - let config = - Config::new(thing_name, broker).keepalive_interval(embassy_time::Duration::from_secs(50)); + let config = Config::new(thing_name).keepalive_interval(embassy_time::Duration::from_secs(50)); static STATE: StaticCell> = StaticCell::new(); let state = STATE.init(State::::new()); @@ -110,26 +102,26 @@ async fn test_mqtt_ota() { let ota_fut = async { let mut jobs_subscription = client - .subscribe::<2>(Subscribe::new(&[ - SubscribeTopic { - topic_path: jobs::JobTopic::NotifyNext - .format::<64>(thing_name)? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - SubscribeTopic { - topic_path: jobs::JobTopic::DescribeAccepted("$next") - .format::<64>(thing_name)? - .as_str(), - maximum_qos: QoS::AtLeastOnce, - no_local: false, - retain_as_published: false, - retain_handling: RetainHandling::SendAtSubscribeTime, - }, - ])) + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + jobs::JobTopic::NotifyNext + .format::<64>(thing_name)? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + jobs::JobTopic::DescribeAccepted("$next") + .format::<64>(thing_name)? + .as_str(), + ) + .build(), + ]) + .build(), + ) .await?; Updater::check_for_job(&client).await?; @@ -167,7 +159,7 @@ async fn test_mqtt_ota() { Ok::<_, ota::error::OtaError>(()) }; - let mut transport = NalTransport::new(network); + let mut transport = NalTransport::new(network, broker); match embassy_time::with_timeout( embassy_time::Duration::from_secs(25), diff --git a/tests/provisioning.rs b/tests/provisioning.rs index 804d2c5..e9eafd5 100644 --- a/tests/provisioning.rs +++ b/tests/provisioning.rs @@ -3,23 +3,16 @@ mod common; -use std::{net::ToSocketAddrs, process}; - use common::credentials; use common::network::TlsNetwork; use ecdsa::Signature; use embassy_futures::select; use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use embedded_mqtt::{ - transport::embedded_nal::NalTransport, Config, DomainBroker, IpBroker, Publish, State, - Subscribe, SubscribeTopic, -}; +use embedded_mqtt::{transport::embedded_nal::NalTransport, Config, DomainBroker, State}; use p256::{ecdsa::signature::Signer, NistP256}; -use rustot::provisioning::{ - topics::Topic, CredentialHandler, Credentials, Error, FleetProvisioner, -}; +use rustot::provisioning::{CredentialHandler, Credentials, Error, FleetProvisioner}; use serde::{Deserialize, Serialize}; -use static_cell::make_static; +use static_cell::StaticCell; pub struct OwnedCredentials { pub certificate_id: String, @@ -82,8 +75,7 @@ async fn test_provisioning() { // Create the MQTT stack let broker = DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); - let config = - Config::new(thing_name, broker).keepalive_interval(embassy_time::Duration::from_secs(50)); + let config = Config::new(thing_name).keepalive_interval(embassy_time::Duration::from_secs(50)); static STATE: StaticCell> = StaticCell::new(); let state = STATE.init(State::::new()); @@ -115,7 +107,7 @@ async fn test_provisioning() { &mut credential_handler, ); - let mut transport = NalTransport::new(network); + let mut transport = NalTransport::new(network, broker); let device_config = match embassy_time::with_timeout( embassy_time::Duration::from_secs(15), diff --git a/tests/shadows.rs b/tests/shadows.rs index cbd979e..af788dd 100644 --- a/tests/shadows.rs +++ b/tests/shadows.rs @@ -26,7 +26,17 @@ // use core::fmt::Write; -// use common::{clock::SysClock, credentials, network::Network}; +// use common::{ +// clock::SysClock, +// credentials, +// network::{Network, TlsNetwork}, +// }; +// use embassy_futures::select; +// use embassy_sync::blocking_mutex::raw::NoopRawMutex; +// use embedded_mqtt::{ +// transport::embedded_nal::{self, NalTransport}, +// DomainBroker, Properties, Publish, QoS, State, +// }; // use embedded_nal::Ipv4Addr; // use mqttrust::Mqtt; // use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification}; @@ -37,9 +47,7 @@ // use serde::{de::DeserializeOwned, Deserialize, Serialize}; // use smlang::statemachine; - -// const Q_SIZE: usize = 1024 * 6; -// static mut Q: BBBuffer = BBBuffer::new(); +// use static_cell::StaticCell; // #[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)] // pub struct ConfigId(pub u8); @@ -284,7 +292,7 @@ // pub fn spin( // &mut self, // notification: Notification, -// mqtt_client: &mqttrust_core::Client<'static, 'static, Q_SIZE>, +// mqtt_client: &embedded_mqtt::MqttClient<'a, M, 1>, // ) -> bool { // log::info!("State: {:?}", self.state()); // match (self.state(), notification) { @@ -294,15 +302,19 @@ // (&States::DeleteShadow, Notification::Suback(_)) => { // mqtt_client // .publish( -// &Topic::Update -// .format::<128>( -// mqtt_client.client_id(), -// ::NAME, +// Publish::builder() +// .topic_name( +// &Topic::Update +// .format::<128>( +// mqtt_client.client_id(), +// ::NAME, +// ) +// .unwrap(), // ) -// .unwrap(), -// b"{\"state\":{\"desired\":null,\"reported\":null}}", -// mqttrust::QoS::AtLeastOnce, +// .payload(b"{\"state\":{\"desired\":null,\"reported\":null}}") +// .build(), // ) +// .await // .unwrap(); // self.process_event(Events::Get).unwrap(); @@ -387,14 +399,17 @@ // mqtt_client // .publish( -// &Topic::Update -// .format::<128>( -// mqtt_client.client_id(), -// ::NAME, +// Publish::builder() +// .topic_name( +// &Topic::Update +// .format::<128>( +// mqtt_client.client_id(), +// ::NAME, +// ) +// .unwrap(), // ) -// .unwrap(), -// payload.as_bytes(), -// mqttrust::QoS::AtLeastOnce, +// .payload(payload.as_bytes()) +// .build(), // ) // .unwrap(); // self.process_event(Events::Ack).unwrap(); @@ -453,50 +468,66 @@ // } // } -// #[test] -// fn test_shadows() { +// #[tokio::test(flavor = "current_thread")] +// async fn test_shadows() { // env_logger::init(); -// let (p, c) = unsafe { Q.try_split_framed().unwrap() }; - // log::info!("Starting shadows test..."); -// let hostname = credentials::HOSTNAME.unwrap(); // let (thing_name, identity) = credentials::identity(); -// let connector = TlsConnector::builder() -// .identity(identity) -// .add_root_certificate(credentials::root_ca()) -// .build() -// .unwrap(); +// let hostname = credentials::HOSTNAME.unwrap(); -// let mut network = Network::new_tls(connector, std::string::String::from(hostname)); +// static NETWORK: StaticCell = StaticCell::new(); +// let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), identity)); -// let mut mqtt_eventloop = EventLoop::new( -// c, -// SysClock::new(), -// MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), -// ); +// // Create the MQTT stack +// let broker = +// DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), &network).unwrap(); +// let config = embedded_mqtt::Config::new(thing_name) +// .keepalive_interval(embassy_time::Duration::from_secs(50)); -// let mqtt_client = mqttrust_core::Client::new(p, thing_name); +// let mut state = State::::new(); +// let (mut stack, client) = embedded_mqtt::new(&mut state, config); -// let mut test_state = StateMachine::new(TestContext { -// shadow: Shadow::new(WifiConfig::default(), &mqtt_client, true).unwrap(), -// update_cnt: 0, -// }); +// let mqtt_client = client; -// loop { -// if nb::block!(mqtt_eventloop.connect(&mut network)).expect("to connect to mqtt") { -// log::info!("Successfully connected to broker"); -// } +// let shadow = Shadow::new(WifiConfig::default(), &mqtt_client).unwrap(); -// match mqtt_eventloop.yield_event(&mut network) { -// Ok(notification) => { -// if test_state.spin(notification, &mqtt_client) { -// break; -// } -// } -// Err(_) => {} +// // loop { +// // if nb::block!(mqtt_eventloop.connect(&mut network)).expect("to connect to mqtt") { +// // log::info!("Successfully connected to broker"); +// // } + +// // match mqtt_eventloop.yield_event(&mut network) { +// // Ok(notification) => { +// // if test_state.spin(notification, &mqtt_client) { +// // break; +// // } +// // } +// // Err(_) => {} +// // } +// // } + +// // cloud_updater(mqtt_client); + +// let shadows_fut = async { +// shadow.next_update().await; +// todo!() +// }; + +// let mut transport = NalTransport::new(network, broker); + +// match embassy_time::with_timeout( +// embassy_time::Duration::from_secs(25), +// select::select(stack.run(&mut transport), shadows_fut), +// ) +// .await +// .unwrap() +// { +// select::Either::First(_) => { +// unreachable!() // } -// } +// select::Either::Second(result) => result.unwrap(), +// }; // }