Skip to content

Commit

Permalink
Dont blindly copy serde attrs in ShadowPatch derive, but rather intro…
Browse files Browse the repository at this point in the history
…duce patch attr that specifies attrs to copy
  • Loading branch information
MathiasKoch committed Jul 31, 2024
1 parent c8667e5 commit 37469e4
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 82 deletions.
24 changes: 12 additions & 12 deletions shadow_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ use syn::DeriveInput;
use syn::Generics;
use syn::Ident;
use syn::Result;
use syn::{parenthesized, Attribute, Error, Field, LitStr};
use syn::{parenthesized, Error, Field, LitStr};

#[proc_macro_derive(ShadowState, attributes(shadow, static_shadow_field))]
#[proc_macro_derive(ShadowState, attributes(shadow, static_shadow_field, patch))]
pub fn shadow_state(input: TokenStream) -> TokenStream {
match parse_macro_input!(input as ParseInput) {
ParseInput::Struct(input) => {
Expand All @@ -32,7 +32,7 @@ pub fn shadow_state(input: TokenStream) -> TokenStream {
}
}

#[proc_macro_derive(ShadowPatch, attributes(static_shadow_field, serde))]
#[proc_macro_derive(ShadowPatch, attributes(static_shadow_field, patch))]
pub fn shadow_patch(input: TokenStream) -> TokenStream {
TokenStream::from(match parse_macro_input!(input as ParseInput) {
ParseInput::Struct(input) => generate_shadow_patch_struct(&input),
Expand All @@ -56,7 +56,7 @@ struct StructParseInput {
pub ident: Ident,
pub generics: Generics,
pub shadow_fields: Vec<Field>,
pub copy_attrs: Vec<Attribute>,
pub copy_attrs: Vec<proc_macro2::TokenStream>,
pub shadow_name: Option<LitStr>,
}

Expand All @@ -67,8 +67,6 @@ impl Parse for ParseInput {
let mut shadow_name = None;
let mut copy_attrs = vec![];

let attrs_to_copy = ["serde"];

// Parse valid container attributes
for attr in derive_input.attrs {
if attr.path.is_ident("shadow") {
Expand All @@ -78,12 +76,14 @@ impl Parse for ParseInput {
content.parse()
}
shadow_name = Some(shadow_arg.parse2(attr.tokens)?);
} else if attrs_to_copy
.iter()
.find(|a| attr.path.is_ident(a))
.is_some()
{
copy_attrs.push(attr);
} else if attr.path.is_ident("patch") {
fn patch_arg(input: ParseStream) -> Result<proc_macro2::TokenStream> {
let content;
parenthesized!(content in input);
content.parse()
}
let args = patch_arg.parse2(attr.tokens)?;
copy_attrs.push(quote! { #[ #args ]})
}
}

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![cfg_attr(not(any(test, feature = "std")), no_std)]
#![allow(async_fn_in_trait)]
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]

// This mod MUST go first, so that the others see its macros.
Expand Down
12 changes: 9 additions & 3 deletions src/ota/encoding/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,19 @@ impl<'a> FileDescription<'a> {
return Some(Signature::Sha1Rsa(heapless::String::try_from(sig).unwrap()));
}
if let Some(sig) = self.sha256_rsa {
return Some(Signature::Sha256Rsa(heapless::String::try_from(sig).unwrap()));
return Some(Signature::Sha256Rsa(
heapless::String::try_from(sig).unwrap(),
));
}
if let Some(sig) = self.sha1_ecdsa {
return Some(Signature::Sha1Ecdsa(heapless::String::try_from(sig).unwrap()));
return Some(Signature::Sha1Ecdsa(
heapless::String::try_from(sig).unwrap(),
));
}
if let Some(sig) = self.sha256_ecdsa {
return Some(Signature::Sha256Ecdsa(heapless::String::try_from(sig).unwrap()));
return Some(Signature::Sha256Ecdsa(
heapless::String::try_from(sig).unwrap(),
));
}
None
}
Expand Down
1 change: 0 additions & 1 deletion src/provisioning/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ impl FleetProvisioner {
where
C: DeserializeOwned,
{
use crate::provisioning::data_types::CreateCertificateFromCsrResponse;
let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?;
let mut message = create_subscription
.next()
Expand Down
165 changes: 99 additions & 66 deletions src/shadows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ pub use data_types::Patch;
use embassy_sync::blocking_mutex::raw::RawMutex;
use embedded_mqtt::{Publish, QoS, RetainHandling, Subscribe, SubscribeTopic};
pub use error::Error;
use serde::de::DeserializeOwned;
pub use shadow_derive as derive;
pub use shadow_diff::ShadowPatch;

Expand All @@ -23,7 +22,7 @@ const MAX_TOPIC_LEN: usize = 128;
const PARTIAL_REQUEST_OVERHEAD: usize = 64;
const CLASSIC_SHADOW: &str = "Classic";

pub trait ShadowState: ShadowPatch {
pub trait ShadowState: ShadowPatch + Default {
const NAME: Option<&'static str>;

const MAX_PAYLOAD_SIZE: usize = 512;
Expand Down Expand Up @@ -105,8 +104,6 @@ where
) -> Result<(), Error> {
if let Some(delta) = delta {
state.apply_patch(delta);
} else {
error!("Delta was NONE");
}

debug!(
Expand Down Expand Up @@ -138,20 +135,15 @@ where
match Topic::from_str(message.topic_name()) {
Some((Topic::UpdateAccepted, _, _)) => Ok(()),
Some((Topic::UpdateRejected, _, _)) => {
match serde_json_core::from_slice::<ErrorResponse>(message.payload()) {
//Try to return shadow error from message error code. Return NotFound otherwise
Ok((error_response, _)) => {
if let Ok(shadow_error) = error_response.try_into() {
Err(Error::ShadowError(shadow_error))
} else {
Err(Error::ShadowError(error::ShadowError::NotFound))
}
}
Err(_) => {
error!("Error deserializing GetRejected message");
Err(Error::ShadowError(error::ShadowError::NotFound))
}
}
let (error_response, _) =
serde_json_core::from_slice::<ErrorResponse>(message.payload())
.map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?;

Err(Error::ShadowError(
error_response
.try_into()
.unwrap_or(error::ShadowError::NotFound),
))
}
_ => {
error!("Expected Topic name GetRejected or GetAccepted but got something else");
Expand All @@ -171,38 +163,30 @@ where
//Deserialize message
//Persist shadow and return new shadow
match Topic::from_str(get_message.topic_name()) {
Some((Topic::GetAccepted, _, _)) => {
match serde_json_core::from_slice::<AcceptedResponse<S::PatchState>>(
get_message.payload(),
) {
Ok((response, _)) => match response.state.desired {
Some(desired) => Ok(desired),
None => {
error!("Shadow state was deserialized but desired was None");
Err(Error::InvalidPayload)
}
},
Err(_) => {
error!("Failed deserializing shadow payload");
Err(Error::InvalidPayload)
}
}
}
Some((Topic::GetAccepted, _, _)) => serde_json_core::from_slice::<
AcceptedResponse<S::PatchState>,
>(get_message.payload())
.ok()
.and_then(|(r, _)| r.state.desired)
.ok_or(Error::InvalidPayload),
Some((Topic::GetRejected, _, _)) => {
match serde_json_core::from_slice::<ErrorResponse>(get_message.payload()) {
//Try to return shadow error from message error code. Return NotFound otherwise
Ok((error_response, _)) => {
if let Ok(shadow_error) = error_response.try_into() {
Err(Error::ShadowError(shadow_error))
} else {
Err(Error::ShadowError(error::ShadowError::NotFound))
}
}
Err(_) => {
error!("Error deserializing GetRejected message");
Err(Error::ShadowError(error::ShadowError::NotFound))
}
let (error_response, _) =
serde_json_core::from_slice::<ErrorResponse>(get_message.payload())
.map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?;

if error_response.code == 404 {
debug!(
"[{:?}] Thing has no shadow document. Creating with defaults...",
S::NAME.unwrap_or_else(|| CLASSIC_SHADOW)
);
return self.create_shadow().await;
}

Err(Error::ShadowError(
error_response
.try_into()
.unwrap_or(error::ShadowError::NotFound),
))
}
_ => {
error!("Expected Topic name GetRejected or GetAccepted but got something else");
Expand All @@ -223,27 +207,76 @@ where
match Topic::from_str(message.topic_name()) {
Some((Topic::DeleteAccepted, _, _)) => Ok(()),
Some((Topic::DeleteRejected, _, _)) => {
match serde_json_core::from_slice::<ErrorResponse>(message.payload()) {
//Try to return shadow error from message error code. Return NotFound otherwise
Ok((error_response, _)) => {
if let Ok(shadow_error) = error_response.try_into() {
Err(Error::ShadowError(shadow_error))
} else {
Err(Error::ShadowError(error::ShadowError::NotFound))
}
}
Err(_) => {
error!("Error deserializing GetRejected message");
Err(Error::ShadowError(error::ShadowError::NotFound))
}
}
let (error_response, _) =
serde_json_core::from_slice::<ErrorResponse>(message.payload())
.map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?;

Err(Error::ShadowError(
error_response
.try_into()
.unwrap_or(error::ShadowError::NotFound),
))
}
_ => {
error!("Expected Topic name GetRejected or GetAccepted but got something else");
Err(Error::WrongShadowName)
}
}
}

pub async fn create_shadow(&mut self) -> Result<S::PatchState, Error> {
debug!(
"[{:?}] Creating initial shadow value.",
S::NAME.unwrap_or(CLASSIC_SHADOW),
);

let state = S::default();

let request = data_types::Request {
state: data_types::State {
reported: Some(&state),
desired: Some(&state),
},
client_token: None,
version: None,
};

// FIXME: Serialize directly into the publish payload through `DeferredPublish` API
let payload = serde_json_core::to_vec::<
_,
{ S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD },
>(&request)
.map_err(|_| Error::Overflow)?;

let message = self
.publish_and_subscribe(Topic::Update, payload.as_slice())
.await?;

match Topic::from_str(message.topic_name()) {
Some((Topic::UpdateAccepted, _, _)) => {
serde_json_core::from_slice::<AcceptedResponse<S::PatchState>>(message.payload())
.ok()
.and_then(|(r, _)| r.state.desired)
.ok_or(Error::InvalidPayload)
}
Some((Topic::UpdateRejected, _, _)) => {
let (error_response, _) =
serde_json_core::from_slice::<ErrorResponse>(message.payload())
.map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?;

Err(Error::ShadowError(
error_response
.try_into()
.unwrap_or(error::ShadowError::NotFound),
))
}
_ => {
error!("Expected Topic name GetRejected or GetAccepted but got something else");
Err(Error::WrongShadowName)
}
}
}

///This function will subscribe to accepted and rejected topics and then do a publish.
///It will only return when something is accepted or rejected
///Topic is the topic you want to publish to
Expand Down Expand Up @@ -325,14 +358,14 @@ where
{
/// Instantiate a new shadow that will be automatically persisted to NVM
/// based on the passed `DAO`.
pub fn new(mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>, dao: D) -> Result<Self, Error> {
pub fn new(mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>, dao: D) -> Self {
let handler = ShadowHandler {
mqtt,
subscription: None,
_shadow: PhantomData,
};

Ok(Self { handler, dao })
Self { handler, dao }
}

/// Wait delta will subscribe if not already to Updatedelta and wait for changes
Expand Down Expand Up @@ -427,13 +460,13 @@ where
[(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:,
{
/// Instantiate a new non-persisted shadow
pub fn new(state: S, mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>) -> Result<Self, Error> {
pub fn new(state: S, mqtt: &'m embedded_mqtt::MqttClient<'a, M, SUBS>) -> Self {
let handler = ShadowHandler {
mqtt,
subscription: None,
_shadow: PhantomData,
};
Ok(Self { handler, state })
Self { handler, state }
}

/// Handle incoming publish messages from the cloud on any topics relevant
Expand Down

0 comments on commit 37469e4

Please sign in to comment.