Skip to content

Commit

Permalink
Feature/async shadows (#57)
Browse files Browse the repository at this point in the history
* Wip on rewriting shadows to async

* Further work on async shadows. Still working on compile errors

* Fix: Async shadow (#60)

* fix asyunc shadow

* renaming of handle message and some linting

* shadows error fix and handle delta should wait for connected

* fmt

* Add const generic SUBS to shadows

* Fix/async shadow (#61)

* fix asyunc shadow

* renaming of handle message and some linting

* shadows error fix and handle delta should wait for connected

* fmt

* subscribe to get shadow and do not overwrite desired state

* Get shadow should deserialize patchState

* wait for accepted and rejected for delete and update as well

* Make sure OTA job documents can be deserialized with no codesigning properties in the document (#62)

* Dont blindly copy serde attrs in ShadowPatch derive, but rather introduce patch attr that specifies attrs to copy

* Add skip_serializing_if none to all patchstate fields

* Shadows: Check client token on all request/response pairs

* Create initial shadow state, if dao read fails during getShadow operation

* remove some client token checks

* Fix not holding delta message across report call

* handle delta on get shadow

* Bump embedded-mqtt

* Fix all tests

* Allow reporting non-persisted shadows directly, through a report fn

* Bump embedded-mqtt

* Enhancement(async): Mutex shadow to borrow as immutable (#63)

* Use mutex to borrow shadow as immutable

* remove .git in embedded-mqtt dependency

---------

Co-authored-by: Kenneth Knudsen <[email protected]>
Co-authored-by: Kenneth Knudsen <[email protected]>
  • Loading branch information
3 people authored Sep 25, 2024
1 parent f424608 commit f924e1e
Show file tree
Hide file tree
Showing 23 changed files with 979 additions and 917 deletions.
17 changes: 13 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ members = ["shadow_derive"]
[package]
name = "rustot"
version = "0.5.0"
authors = ["Mathias Koch <[email protected]>"]
authors = ["Factbird team <[email protected]>"]
description = "AWS IoT"
readme = "README.md"
keywords = ["iot", "no-std"]
categories = ["embedded", "no-std"]
license = "MIT OR Apache-2.0"
repository = "https://github.com/BlackbirdHQ/rustot"
repository = "https://github.com/FactbirdHQ/rustot"
edition = "2021"
documentation = "https://docs.rs/rustot"
exclude = ["/documentation"]
Expand All @@ -29,7 +29,7 @@ serde_cbor = { version = "0.11", default-features = false, optional = true }
serde-json-core = { version = "0.5" }
shadow-derive = { path = "shadow_derive", version = "0.2.1" }
embedded-storage-async = "0.4"
embedded-mqtt = { git = "ssh://[email protected]/BlackbirdHQ/embedded-mqtt/", rev = "d766137" }
embedded-mqtt = { git = "ssh://[email protected]/FactbirdHQ/embedded-mqtt", rev = "d2b7c02" }

futures = { version = "0.3.28", default-features = false }

Expand All @@ -46,6 +46,7 @@ embedded-nal-async = "0.7"
env_logger = "0.11"
sha2 = "0.10.1"
static_cell = { version = "2", features = ["nightly"] }

tokio = { version = "1.33", default-features = false, features = [
"macros",
"rt",
Expand Down Expand Up @@ -73,5 +74,13 @@ ota_http_data = []

std = ["serde/std", "serde_cbor?/std"]

defmt = ["dep:defmt", "heapless/defmt-03", "embedded-mqtt/defmt"]
defmt = [
"dep:defmt",
"heapless/defmt-03",
"embedded-mqtt/defmt",
"embassy-time/defmt",
]
log = ["dep:log", "embedded-mqtt/log"]

# [patch."ssh://[email protected]/FactbirdHQ/embedded-mqtt"]
# embedded-mqtt = { path = "../embedded-mqtt" }
8 changes: 4 additions & 4 deletions documentation/stack.drawio
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
<mxfile host="65bd71144e">
<diagram id="OCeliohVpZ0w719LyYzf" name="Page-1">
<mxGraphModel dx="492" dy="467" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="1169" pageHeight="827" math="0" shadow="0">
<mxGraphModel dx="581" dy="907" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="1169" pageHeight="827" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="2" value="embedded-nal" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxCell id="2" value="embedded-io" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxGeometry x="210" y="450" width="480" height="60" as="geometry"/>
</mxCell>
<mxCell id="3" value="mqttrust" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxCell id="3" value="embedded-mqtt" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxGeometry x="210" y="390" width="480" height="60" as="geometry"/>
</mxCell>
<mxCell id="5" value="rustot::jobs" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
Expand All @@ -22,7 +22,7 @@
<mxCell id="8" value="rustot::dev_defender" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxGeometry x="330" y="330" width="120" height="60" as="geometry"/>
</mxCell>
<mxCell id="9" value="rustot::provisioning" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
<mxCell id="9" value="rustot::provisioning" style="rounded=0;whiteSpace=wrap;html=1;" parent="1" vertex="1">
<mxGeometry x="210" y="330" width="120" height="60" as="geometry"/>
</mxCell>
</root>
Expand Down
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[toolchain]
channel = "1.79"
channel = "nightly-2024-07-17"
components = ["rust-src", "rustfmt", "llvm-tools"]
targets = [
"x86_64-unknown-linux-gnu",
Expand Down
26 changes: 13 additions & 13 deletions shadow_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ use syn::DeriveInput;
use syn::Generics;
use syn::Ident;
use syn::Result;
use syn::{parenthesized, Attribute, Error, Field, LitStr};
use syn::{parenthesized, Error, Field, LitStr};

#[proc_macro_derive(ShadowState, attributes(shadow, static_shadow_field))]
#[proc_macro_derive(ShadowState, attributes(shadow, static_shadow_field, patch))]
pub fn shadow_state(input: TokenStream) -> TokenStream {
match parse_macro_input!(input as ParseInput) {
ParseInput::Struct(input) => {
Expand All @@ -32,7 +32,7 @@ pub fn shadow_state(input: TokenStream) -> TokenStream {
}
}

#[proc_macro_derive(ShadowPatch, attributes(static_shadow_field, serde))]
#[proc_macro_derive(ShadowPatch, attributes(static_shadow_field, patch))]
pub fn shadow_patch(input: TokenStream) -> TokenStream {
TokenStream::from(match parse_macro_input!(input as ParseInput) {
ParseInput::Struct(input) => generate_shadow_patch_struct(&input),
Expand All @@ -56,7 +56,7 @@ struct StructParseInput {
pub ident: Ident,
pub generics: Generics,
pub shadow_fields: Vec<Field>,
pub copy_attrs: Vec<Attribute>,
pub copy_attrs: Vec<proc_macro2::TokenStream>,
pub shadow_name: Option<LitStr>,
}

Expand All @@ -67,8 +67,6 @@ impl Parse for ParseInput {
let mut shadow_name = None;
let mut copy_attrs = vec![];

let attrs_to_copy = ["serde"];

// Parse valid container attributes
for attr in derive_input.attrs {
if attr.path.is_ident("shadow") {
Expand All @@ -78,12 +76,14 @@ impl Parse for ParseInput {
content.parse()
}
shadow_name = Some(shadow_arg.parse2(attr.tokens)?);
} else if attrs_to_copy
.iter()
.find(|a| attr.path.is_ident(a))
.is_some()
{
copy_attrs.push(attr);
} else if attr.path.is_ident("patch") {
fn patch_arg(input: ParseStream) -> Result<proc_macro2::TokenStream> {
let content;
parenthesized!(content in input);
content.parse()
}
let args = patch_arg.parse2(attr.tokens)?;
copy_attrs.push(quote! { #[ #args ]})
}
}

Expand Down Expand Up @@ -161,7 +161,7 @@ fn create_optional_fields(fields: &Vec<Field>) -> Vec<proc_macro2::TokenStream>
Some(if type_name_string.starts_with("Option<") {
quote! { #(#attrs)* pub #field_name: Option<rustot::shadows::Patch<<#type_name as rustot::shadows::ShadowPatch>::PatchState>> }
} else {
quote! { #(#attrs)* pub #field_name: Option<<#type_name as rustot::shadows::ShadowPatch>::PatchState> }
quote! { #(#attrs)* #[serde(skip_serializing_if = "Option::is_none")] pub #field_name: Option<<#type_name as rustot::shadows::ShadowPatch>::PatchState> }
})
}
})
Expand Down
15 changes: 8 additions & 7 deletions src/jobs/data_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ pub enum JobStatus {
Removed,
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ErrorCode {
/// The request was sent to a topic in the AWS IoT Jobs namespace that does
/// not map to any API.
Expand Down Expand Up @@ -89,7 +90,7 @@ pub struct GetPendingJobExecutionsResponse<'a> {
/// A client token used to correlate requests and responses. Enter an
/// arbitrary value here and it is reflected in the response.
#[serde(rename = "clientToken")]
pub client_token: &'a str,
pub client_token: Option<&'a str>,
}

/// Contains data about a job execution.
Expand Down Expand Up @@ -211,7 +212,7 @@ pub struct StartNextPendingJobExecutionResponse<'a, J> {
/// A client token used to correlate requests and responses. Enter an
/// arbitrary value here and it is reflected in the response.
#[serde(rename = "clientToken")]
pub client_token: &'a str,
pub client_token: Option<&'a str>,
}

/// Topic (accepted): $aws/things/{thingName}/jobs/{jobId}/update/accepted \
Expand All @@ -232,7 +233,7 @@ pub struct UpdateJobExecutionResponse<'a, J> {
/// A client token used to correlate requests and responses. Enter an
/// arbitrary value here and it is reflected in the response.
#[serde(rename = "clientToken")]
pub client_token: &'a str,
pub client_token: Option<&'a str>,
}

/// Sent whenever a job execution is added to or removed from the list of
Expand Down Expand Up @@ -289,7 +290,7 @@ pub struct Jobs {
/// service operation.
#[derive(Debug, PartialEq, Deserialize)]
pub struct ErrorResponse<'a> {
code: ErrorCode,
pub code: ErrorCode,
/// An error message string.
message: &'a str,
/// A client token used to correlate requests and responses. Enter an
Expand Down Expand Up @@ -394,7 +395,7 @@ mod test {
in_progress_jobs: Some(Vec::<JobExecutionSummary, MAX_RUNNING_JOBS>::new()),
queued_jobs: None,
timestamp: 1587381778,
client_token: "0:client_name",
client_token: Some("0:client_name"),
}
);

Expand Down Expand Up @@ -433,7 +434,7 @@ mod test {
in_progress_jobs: Some(Vec::<JobExecutionSummary, MAX_RUNNING_JOBS>::new()),
queued_jobs: Some(queued_jobs),
timestamp: 1587381778,
client_token: "0:client_name",
client_token: Some("0:client_name"),
}
);
}
Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![cfg_attr(not(any(test, feature = "std")), no_std)]
#![allow(async_fn_in_trait)]
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]

// This mod MUST go first, so that the others see its macros.
pub(crate) mod fmt;
Expand All @@ -8,6 +10,6 @@ pub mod jobs;
#[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))]
pub mod ota;
pub mod provisioning;
// pub mod shadows;
pub mod shadows;

pub use serde_cbor;
116 changes: 90 additions & 26 deletions src/ota/control_interface/mqtt.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
use core::fmt::Write;

use bitmaps::{Bits, BitsImpl};
use embassy_sync::blocking_mutex::raw::RawMutex;
use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS};
use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS, Subscribe, SubscribeTopic};
use futures::StreamExt as _;

use super::ControlInterface;
use crate::jobs::data_types::JobStatus;
use crate::jobs::{JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN};
use crate::jobs::data_types::{ErrorResponse, JobStatus, UpdateJobExecutionResponse};
use crate::jobs::{JobError, JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN};
use crate::ota::config::Config;
use crate::ota::encoding::json::JobStatusReason;
use crate::ota::encoding::FileContext;
use crate::ota::encoding::{self, FileContext};
use crate::ota::error::OtaError;

impl<'a, M: RawMutex, const SUBS: usize> ControlInterface
for embedded_mqtt::MqttClient<'a, M, SUBS>
impl<'a, M: RawMutex, const SUBS: usize> ControlInterface for embedded_mqtt::MqttClient<'a, M, SUBS>
where
BitsImpl<{ SUBS }>: Bits,
{
/// Check for next available OTA job from the job service by publishing a
/// "get next job" message to the job service.
Expand All @@ -21,15 +24,12 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface
let mut buf = [0u8; 512];
let (topic, payload_len) = Jobs::describe().topic_payload(self.client_id(), &mut buf)?;

self.publish(Publish {
dup: false,
qos: QoS::AtLeastOnce,
retain: false,
pid: None,
topic_name: &topic,
payload: &buf[..payload_len],
properties: embedded_mqtt::Properties::Slice(&[]),
})
self.publish(
Publish::builder()
.topic_name(&topic)
.payload(&buf[..payload_len])
.build(),
)
.await?;

Ok(())
Expand Down Expand Up @@ -69,7 +69,7 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface
}

// Don't override the progress on succeeded, nor on self-test
// active. (Cases where progess counter is lost due to device
// active. (Cases where progress counter is lost due to device
// restarts)
if status != JobStatus::Succeeded && reason != JobStatusReason::SelfTestActive {
let mut progress = heapless::String::new();
Expand All @@ -93,29 +93,93 @@ impl<'a, M: RawMutex, const SUBS: usize> ControlInterface
}
}

let mut sub = self
.subscribe::<2>(
Subscribe::builder()
.topics(&[
SubscribeTopic::builder()
.topic_path(
JobTopic::UpdateAccepted(file_ctx.job_name.as_str())
.format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>(
self.client_id(),
)?
.as_str(),
)
.build(),
SubscribeTopic::builder()
.topic_path(
JobTopic::UpdateRejected(file_ctx.job_name.as_str())
.format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>(
self.client_id(),
)?
.as_str(),
)
.build(),
])
.build(),
)
.await?;

let topic = JobTopic::Update(file_ctx.job_name.as_str())
.format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>(self.client_id())?;
let payload = DeferredPayload::new(
|buf| {
Jobs::update(status)
.client_token(self.client_id())
.status_details(&file_ctx.status_details)
.payload(buf)
.map_err(|_| EncodingError::BufferSize)
},
512,
);

self.publish(Publish {
dup: false,
qos,
retain: false,
pid: None,
topic_name: &topic,
payload,
properties: embedded_mqtt::Properties::Slice(&[]),
})
self.publish(
Publish::builder()
.qos(qos)
.topic_name(&topic)
.payload(payload)
.build(),
)
.await?;

Ok(())
loop {
let message = sub.next().await.ok_or(JobError::Encoding)?;

// Check if topic is GetAccepted
match crate::jobs::Topic::from_str(message.topic_name()) {
Some(crate::jobs::Topic::UpdateAccepted(_)) => {
// Check client token
let (response, _) = serde_json_core::from_slice::<
UpdateJobExecutionResponse<encoding::json::OtaJob<'_>>,
>(message.payload())
.map_err(|_| JobError::Encoding)?;

if response.client_token != Some(self.client_id()) {
error!(
"Unexpected client token received: {}, expected: {}",
response.client_token.unwrap_or("None"),
self.client_id()
);
continue;
}

return Ok(());
}
Some(crate::jobs::Topic::UpdateRejected(_)) => {
let (error_response, _) =
serde_json_core::from_slice::<ErrorResponse>(message.payload())
.map_err(|_| JobError::Encoding)?;

if error_response.client_token != Some(self.client_id()) {
continue;
}

return Err(OtaError::UpdateRejected(error_response.code));
}
_ => {
error!("Expected Topic name GetRejected or GetAccepted but got something else");
}
}
}
}
}
Loading

0 comments on commit f924e1e

Please sign in to comment.