diff --git a/Cargo.toml b/Cargo.toml index f8dd8aa..cae2c1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,7 +73,7 @@ hex = { version = "0.4.3", features = ["alloc"] } [features] -default = ["ota_mqtt_data", "metric_cbor", "provision_cbor"] +default = ["ota_mqtt_data", "metric_cbor", "provision_cbor", "std"] metric_cbor = ["dep:minicbor", "dep:minicbor-serde"] diff --git a/rustot_derive/src/shadow/generation/generator.rs b/rustot_derive/src/shadow/generation/generator.rs index a1c16bd..827b421 100644 --- a/rustot_derive/src/shadow/generation/generator.rs +++ b/rustot_derive/src/shadow/generation/generator.rs @@ -1,11 +1,8 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{punctuated::Punctuated, spanned::Spanned as _, Data, DeriveInput, Token}; +use syn::{punctuated::Punctuated, Data, DeriveInput, Token}; -use crate::shadow::{ - generation::variant_or_field_visitor::{borrow_fields, get_attr, has_shadow_arg, is_primitive}, - CFG_ATTRIBUTE, DEFAULT_ATTRIBUTE, -}; +use crate::shadow::{generation::variant_or_field_visitor::get_attr, DEFAULT_ATTRIBUTE}; pub trait Generator { fn generate(&mut self, original: &DeriveInput, output: &DeriveInput) -> TokenStream; @@ -21,125 +18,6 @@ impl Generator for NewGenerator { } } -pub struct GenerateFromImpl; - -impl GenerateFromImpl { - fn variables_actions<'a>( - fields: impl Iterator, - ) -> ( - Punctuated, - Punctuated, - ) { - fields.enumerate().fold( - (Punctuated::new(), Punctuated::new()), - |(mut variables, mut actions), (i, field)| { - let var_ident = field.ident.clone().unwrap_or_else(|| { - syn::Ident::new(&format!("{}", (b'a' + i as u8) as char), field.span()) - }); - - let action = if is_primitive(&field.ty) || has_shadow_arg(&field.attrs, "leaf") { - quote! {Some(#var_ident)} - } else { - quote! {Some(#var_ident.into())} - }; - - actions.push( - field - .ident - .as_ref() - .map(|ident| quote! {#ident: #action}) - .unwrap_or(action), - ); - - variables.push(var_ident); - - (variables, actions) - }, - ) - } -} - -impl Generator for GenerateFromImpl { - fn generate(&mut self, original: &DeriveInput, output: &DeriveInput) -> TokenStream { - let (impl_generics, ty_generics, where_clause) = original.generics.split_for_impl(); - let orig_name = &original.ident; - let new_name = &output.ident; - - let from_impl = match (&original.data, &output.data) { - (Data::Struct(data_struct_old), Data::Struct(data_struct_new)) => { - let original_fields = borrow_fields(data_struct_old); - let new_fields = borrow_fields(data_struct_new); - - let from_fields = original_fields.iter().fold(quote! {}, |acc, field| { - let is_leaf = is_primitive(&field.ty) || has_shadow_arg(&field.attrs, "leaf"); - - let has_new_field = new_fields - .iter() - .find(|&f| f.ident == field.ident) - .is_some(); - - let cfg_attr = get_attr(&field.attrs, CFG_ATTRIBUTE); - - let ident = &field.ident; - if !has_new_field { - quote! { #acc #cfg_attr #ident: None, } - } else if is_leaf { - quote! { #acc #cfg_attr #ident: Some(v.#ident), } - } else { - quote! { #acc #cfg_attr #ident: Some(v.#ident.into()), } - } - }); - - quote! { - Self { - #from_fields - } - } - } - (Data::Enum(data_struct_old), Data::Enum(_)) => { - let match_arms = data_struct_old - .variants - .iter() - .fold(Punctuated::::new(), |mut acc, variant| { - let variant_ident = &variant.ident; - let cfg_attr = get_attr(&variant.attrs, CFG_ATTRIBUTE); - - acc.push(match &variant.fields { - syn::Fields::Named(fields_named) => { - let (variables, actions) = Self::variables_actions(fields_named.named.iter()); - quote! {#cfg_attr #orig_name::#variant_ident { #variables } => Self::#variant_ident { #actions }} - } - syn::Fields::Unnamed(fields_unnamed) => { - let (variables, actions) = Self::variables_actions(fields_unnamed.unnamed.iter()); - quote! {#cfg_attr #orig_name::#variant_ident ( #variables ) => Self::#variant_ident ( #actions )} - } - syn::Fields::Unit => { - quote! {#cfg_attr #orig_name::#variant_ident => Self::#variant_ident} - } - }); - - acc - }); - - quote! { - match v { - #match_arms - } - } - } - _ => panic!(), - }; - - quote! { - impl #impl_generics From<#orig_name #ty_generics> for #new_name #ty_generics #where_clause { - fn from(v: #orig_name #ty_generics) -> Self { - #from_impl - } - } - } - } -} - pub struct DefaultGenerator(pub bool); impl Generator for DefaultGenerator { diff --git a/rustot_derive/src/shadow/generation/mod.rs b/rustot_derive/src/shadow/generation/mod.rs index 24030a2..17bc1f9 100644 --- a/rustot_derive/src/shadow/generation/mod.rs +++ b/rustot_derive/src/shadow/generation/mod.rs @@ -6,11 +6,13 @@ use generator::Generator; use modifier::Modifier; use proc_macro2::TokenStream; use quote::{format_ident, quote}; -use syn::{punctuated::Punctuated, Data, DeriveInput, Path, Token}; +use syn::{punctuated::Punctuated, spanned::Spanned as _, Data, DeriveInput, Path, Token}; use variant_or_field_visitor::{ borrow_fields_mut, get_attr, has_shadow_arg, is_primitive, VariantOrFieldVisitor, }; +use crate::shadow::generation::variant_or_field_visitor::borrow_fields; + use super::CFG_ATTRIBUTE; pub struct ShadowGenerator { @@ -92,6 +94,40 @@ impl GenerateShadowPatchImplVisitor { apply_patch_impl: quote! {}, } } + + fn variables_actions<'a>( + fields: impl Iterator, + ) -> ( + Punctuated, + Punctuated, + ) { + fields.enumerate().fold( + (Punctuated::new(), Punctuated::new()), + |(mut variables, mut actions), (i, field)| { + let var_ident = field.ident.clone().unwrap_or_else(|| { + syn::Ident::new(&format!("{}", (b'a' + i as u8) as char), field.span()) + }); + + let action = if is_primitive(&field.ty) || has_shadow_arg(&field.attrs, "leaf") { + quote! {Some(#var_ident)} + } else { + quote! {Some(#var_ident.into_reported())} + }; + + actions.push( + field + .ident + .as_ref() + .map(|ident| quote! {#ident: #action}) + .unwrap_or(action), + ); + + variables.push(var_ident); + + (variables, actions) + }, + ) + } } impl Generator for GenerateShadowPatchImplVisitor { @@ -116,6 +152,71 @@ impl Generator for GenerateShadowPatchImplVisitor { } }; + let into_reported_impl = match (&original.data, &output.data) { + (Data::Struct(data_struct_old), Data::Struct(data_struct_new)) => { + let original_fields = borrow_fields(data_struct_old); + let new_fields = borrow_fields(data_struct_new); + + let from_fields = original_fields.iter().fold(quote! {}, |acc, field| { + let is_leaf = is_primitive(&field.ty) || has_shadow_arg(&field.attrs, "leaf"); + + let has_new_field = new_fields + .iter() + .find(|&f| f.ident == field.ident) + .is_some(); + + let cfg_attr = get_attr(&field.attrs, CFG_ATTRIBUTE); + + let ident = &field.ident; + if !has_new_field { + quote! { #acc #cfg_attr #ident: None, } + } else if is_leaf { + quote! { #acc #cfg_attr #ident: Some(self.#ident), } + } else { + quote! { #acc #cfg_attr #ident: Some(self.#ident.into_reported()), } + } + }); + + quote! { + Self::Reported { + #from_fields + } + } + } + (Data::Enum(data_struct_old), Data::Enum(_)) => { + let match_arms = data_struct_old + .variants + .iter() + .fold(Punctuated::::new(), |mut acc, variant| { + let variant_ident = &variant.ident; + let cfg_attr = get_attr(&variant.attrs, CFG_ATTRIBUTE); + + acc.push(match &variant.fields { + syn::Fields::Named(fields_named) => { + let (variables, actions) = Self::variables_actions(fields_named.named.iter()); + quote! {#cfg_attr #orig_name::#variant_ident { #variables } => Self::Reported::#variant_ident { #actions }} + } + syn::Fields::Unnamed(fields_unnamed) => { + let (variables, actions) = Self::variables_actions(fields_unnamed.unnamed.iter()); + quote! {#cfg_attr #orig_name::#variant_ident ( #variables ) => Self::Reported::#variant_ident ( #actions )} + } + syn::Fields::Unit => { + quote! {#cfg_attr #orig_name::#variant_ident => Self::Reported::#variant_ident} + } + }); + + acc + }); + + quote! { + match self { + #match_arms + } + } + } + _ => panic!(), + }; + quote! { impl #impl_generics rustot::shadows::ShadowPatch for #orig_name #ty_generics #where_clause { type Delta = #delta_name #ty_generics; @@ -124,6 +225,10 @@ impl Generator for GenerateShadowPatchImplVisitor { fn apply_patch(&mut self, delta: Self::Delta) { #apply_patch_impl } + + fn into_reported(self) -> Self::Reported { + #into_reported_impl + } } } } diff --git a/rustot_derive/src/shadow/shadow_patch.rs b/rustot_derive/src/shadow/shadow_patch.rs index 1718afd..7416f33 100644 --- a/rustot_derive/src/shadow/shadow_patch.rs +++ b/rustot_derive/src/shadow/shadow_patch.rs @@ -6,7 +6,7 @@ use syn::{ }; use crate::shadow::generation::{ - generator::{DefaultGenerator, GenerateFromImpl, NewGenerator}, + generator::{DefaultGenerator, NewGenerator}, modifier::{RenameModifier, ReportOnlyModifier, WithDerivesModifier}, variant_or_field_visitor::{ AddSerdeSkipAttribute, RemoveShadowAttributesVisitor, SetNewTypeVisitor, @@ -128,7 +128,6 @@ pub fn shadow_patch(attr: TokenStream, input: TokenStream) -> TokenStream { .variant_or_field_visitor(&mut RemoveShadowAttributesVisitor) .generator(&mut NewGenerator) .modifier(&mut ReportOnlyModifier) - .generator(&mut GenerateFromImpl) .finalize() }; diff --git a/rustot_derive/tests/shadow.rs b/rustot_derive/tests/shadow.rs index 4eb7d8a..daa5ac8 100644 --- a/rustot_derive/tests/shadow.rs +++ b/rustot_derive/tests/shadow.rs @@ -59,7 +59,7 @@ fn nested() { assert_eq!(Foo::MAX_PAYLOAD_SIZE, 256); assert_eq!( - ReportedFoo::from(foo.clone()), + foo.clone().into_reported(), ReportedFoo { bar: Some(56), baz: Some("HelloWorld".to_string()), @@ -151,7 +151,7 @@ fn simple_enum() { bar: Some(Either::B), }); - assert_eq!(ReportedFoo::from(desired), reported); + assert_eq!(desired.into_reported(), reported); } #[test] @@ -211,7 +211,7 @@ fn complex_enum() { bar: Either::D(InnerA { hello: 56 }, InnerB::default()) } ); - assert_eq!(ReportedFoo::from(desired), reported); + assert_eq!(desired.into_reported(), reported); } #[test] @@ -271,7 +271,7 @@ fn manual_reported() { Self { bar: Some(v.bar), baz: Some(v.baz), - inner: Some(v.inner.into()), + inner: Some(v.inner.into_reported()), ..Default::default() } } diff --git a/src/shadows/alloc_impl.rs b/src/shadows/alloc_impl.rs new file mode 100644 index 0000000..1c9464d --- /dev/null +++ b/src/shadows/alloc_impl.rs @@ -0,0 +1,197 @@ +use core::hash::Hash; + +use serde::{de::DeserializeOwned, Serialize}; + +use crate::shadows::ShadowPatch; + +impl ShadowPatch for std::string::String { + type Delta = Self; + + type Reported = Self; + + fn apply_patch(&mut self, delta: Self::Delta) { + *self = delta; + } + + fn into_reported(self) -> Self::Reported { + self + } +} + +impl ShadowPatch for std::vec::Vec { + type Delta = Self; + + type Reported = Self; + + fn apply_patch(&mut self, delta: Self::Delta) { + *self = delta; + } + + fn into_reported(self) -> Self::Reported { + self + } +} + +impl ShadowPatch for std::collections::HashMap +where + K: Clone + Serialize + DeserializeOwned + Eq + Hash, + V: ShadowPatch, +{ + type Delta = std::collections::HashMap::Delta>; + + type Reported = std::collections::HashMap::Reported>; + + fn apply_patch(&mut self, delta: Self::Delta) { + for (key, value) in delta.into_iter() { + if let Some(entry) = self.get_mut(&key) { + entry.apply_patch(value.clone()); + } else { + let mut entry = V::default(); + entry.apply_patch(value.clone()); + self.insert(key.clone(), entry); + } + } + } + + fn into_reported(self) -> Self::Reported { + self.into_iter() + .map(|(k, v)| (k, v.into_reported())) + .collect() + } +} + +#[cfg(test)] +mod tests { + use crate as rustot; + use crate::shadows::ShadowPatch; + use rustot_derive::shadow_patch; + use serde::{Deserialize, Serialize}; + + use std::collections::HashMap; + + #[test] + fn string_shadow_patch() { + let mut s = String::from("hello"); + + // Test apply_patch with Some + s.apply_patch(String::from("world")); + assert_eq!(s, "world"); + + // Test into_reported + assert_eq!(s.into_reported(), "world"); + } + + #[test] + fn vec_shadow_patch() { + let mut v = vec![1, 2, 3]; + + // Test apply_patch with Some + v.apply_patch(vec![4, 5]); + assert_eq!(v, vec![4, 5]); + + // Test into_reported + assert_eq!(v.into_reported(), vec![4, 5]); + } + + #[test] + fn hashmap_shadow_patch() { + let mut map = HashMap::new(); + map.insert("a".to_string(), String::from("alpha")); + map.insert("b".to_string(), String::from("beta")); + + let mut delta = HashMap::new(); + delta.insert("a".to_string(), String::from("updated")); + delta.insert("c".to_string(), String::from("gamma")); + + // Test apply_patch + map.apply_patch(delta); + assert_eq!(map.get("a"), Some(&String::from("updated"))); + assert_eq!(map.get("b"), Some(&String::from("beta"))); + assert_eq!(map.get("c"), Some(&String::from("gamma"))); + + // Test into_reported + let reported = map.into_reported(); + assert_eq!(reported.get("a"), Some(&String::from("updated"))); + assert_eq!(reported.get("b"), Some(&String::from("beta"))); + assert_eq!(reported.get("c"), Some(&String::from("gamma"))); + } + + #[shadow_patch] + #[derive(Default, Clone, Debug, PartialEq, Deserialize, Serialize)] + struct Device { + name: String, + temperature: f32, + } + + #[test] + fn hashmap_with_struct_values() { + let mut devices = HashMap::new(); + devices.insert( + "sensor1".to_string(), + Device { + name: "Temperature Sensor".to_string(), + temperature: 22.5, + }, + ); + devices.insert( + "sensor2".to_string(), + Device { + name: "Humidity Sensor".to_string(), + temperature: 18.0, + }, + ); + + let mut delta = HashMap::new(); + delta.insert( + "sensor1".to_string(), + DeltaDevice { + name: Some("Temperature Sensor".to_string()), + temperature: Some(25.0), + }, + ); + delta.insert( + "sensor3".to_string(), + DeltaDevice { + name: Some("New Sensor".to_string()), + temperature: Some(20.0), + }, + ); + + // Apply patch + devices.apply_patch(delta); + + // Verify updates + assert_eq!( + devices.get("sensor1"), + Some(&Device { + name: "Temperature Sensor".to_string(), + temperature: 25.0, + }) + ); + assert_eq!( + devices.get("sensor2"), + Some(&Device { + name: "Humidity Sensor".to_string(), + temperature: 18.0, + }) + ); + assert_eq!( + devices.get("sensor3"), + Some(&Device { + name: "New Sensor".to_string(), + temperature: 20.0, + }) + ); + + // Test into_reported + let reported = devices.into_reported(); + assert_eq!(reported.len(), 3); + assert_eq!( + reported.get("sensor1"), + Some(&ReportedDevice { + name: Some("Temperature Sensor".to_string()), + temperature: Some(25.0), + }) + ); + } +} diff --git a/src/shadows/data_types.rs b/src/shadows/data_types.rs index 18fd2a3..6dd0f9f 100644 --- a/src/shadows/data_types.rs +++ b/src/shadows/data_types.rs @@ -34,7 +34,7 @@ pub struct RequestState { pub reported: Option, } -#[derive(Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct DeltaState { pub desired: Option, diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index fc8cf87..448aeb5 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -3,6 +3,9 @@ pub mod data_types; pub mod error; pub mod topics; +#[cfg(feature = "std")] +mod alloc_impl; + pub use rustot_derive; use core::{marker::PhantomData, ops::DerefMut}; @@ -39,9 +42,11 @@ pub trait ShadowPatch: Default + Clone + Sized { type Delta: DeserializeOwned + Serialize + Clone + Default; // Contains all fields from `Delta` + additional optional fields - type Reported: From + Serialize + Default; + type Reported: Serialize + Default; fn apply_patch(&mut self, delta: Self::Delta); + + fn into_reported(self) -> Self::Reported; } struct ShadowHandler<'a, 'm, M: RawMutex, S> { @@ -391,7 +396,7 @@ where state.apply_patch(delta.clone()); self.handler - .update_shadow(None, Some(state.clone().into())) + .update_shadow(None, Some(state.clone().into_reported())) .await?; self.dao.lock().await.write(&state).await?; @@ -415,7 +420,7 @@ where state.apply_patch(delta.clone()); self.dao.lock().await.write(&state).await?; self.handler - .update_shadow(None, Some(state.clone().into())) + .update_shadow(None, Some(state.clone().into_reported())) .await?; } @@ -426,7 +431,9 @@ where pub async fn report(&self) -> Result<(), Error> { let state = self.dao.lock().await.read().await?; - self.handler.update_shadow(None, Some(state.into())).await?; + self.handler + .update_shadow(None, Some(state.into_reported())) + .await?; Ok(()) } @@ -435,12 +442,6 @@ where /// This function will update the desired state of the shadow in the cloud, /// and depending on whether the state update is rejected or accepted, it /// will automatically update the local version after response - /// - /// The returned `bool` from the update closure will determine whether the - /// update is persisted using the `DAO`, or just updated in the cloud. This - /// can be handy for activity or status field updates that are not relevant - /// to store persistent on the device, but are required to be part of the - /// same cloud shadow. pub async fn update(&self, f: F) -> Result<(), Error> { let mut update = S::Reported::default(); let mut state = self.dao.lock().await.read().await?; @@ -520,7 +521,7 @@ where self.state.apply_patch(delta.clone()); self.handler - .update_shadow(None, Some(self.state.clone().into())) + .update_shadow(None, Some(self.state.clone().into_reported())) .await?; } @@ -535,7 +536,7 @@ where /// Report the state of the shadow. pub async fn report(&mut self) -> Result<(), Error> { self.handler - .update_shadow(None, Some(self.state.clone().into())) + .update_shadow(None, Some(self.state.clone().into_reported())) .await?; Ok(()) } @@ -566,7 +567,7 @@ where if let Some(delta) = delta_state.delta { self.state.apply_patch(delta.clone()); self.handler - .update_shadow(None, Some(self.state.clone().into())) + .update_shadow(None, Some(self.state.clone().into_reported())) .await?; }