diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3f9320d..685e4aa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,68 +6,103 @@ on: - master pull_request: +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + env: ALL_FEATURES: "ota_mqtt_data,ota_http_data" jobs: - cancel_previous_runs: - name: Cancel previous runs + build: + name: Build runs-on: ubuntu-latest steps: - - uses: styfle/cancel-workflow-action@0.4.1 - with: - access_token: ${{ secrets.GITHUB_TOKEN }} - + - name: Checkout source code + uses: actions/checkout@v4 + + - uses: dsherret/rust-toolchain-file@v1 + + - name: Build (library) + run: cargo build --all --target thumbv7em-none-eabihf + test: - name: Build & Test + name: Test runs-on: ubuntu-latest steps: - name: Checkout source code - uses: actions/checkout@v3 + uses: actions/checkout@v4 + - uses: dsherret/rust-toolchain-file@v1 - - name: Build + + - name: Doc Tests uses: actions-rs/cargo@v1 with: - command: build - args: --all --target thumbv7em-none-eabihf --features ${{ env.ALL_FEATURES }} + command: test + args: --doc --features "std,log" + + - name: Macro Tests + uses: actions-rs/cargo@v1 + with: + command: test + args: -p rustot-derive - - name: Test + - name: Unit Tests uses: actions-rs/cargo@v1 with: command: test - args: --lib --features "ota_mqtt_data,log" - + args: --lib --features "std,log" + rustfmt: name: rustfmt runs-on: ubuntu-latest steps: - name: Checkout source code - uses: actions/checkout@v3 + + uses: actions/checkout@v4 - uses: dsherret/rust-toolchain-file@v1 - - name: Rustfmt - run: cargo fmt -- --check + + - name: Run rustfmt (library) + run: cargo fmt --all -- --check --verbose + + # - name: Run rustfmt (examples) + # run: | + # for EXAMPLE in $(ls examples); + # do + # (cd examples/$EXAMPLE && cargo fmt --all -- --check --verbose) + # done clippy: name: clippy runs-on: ubuntu-latest + env: + CLIPPY_PARAMS: -W clippy::all -W clippy::pedantic -W clippy::nursery -W clippy::cargo steps: - name: Checkout source code - uses: actions/checkout@v3 + + uses: actions/checkout@v4 - uses: dsherret/rust-toolchain-file@v1 - - name: Run clippy - uses: actions-rs/clippy-check@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - args: -- ${{ env.CLIPPY_PARAMS }} + + - name: Run clippy (library) + run: cargo clippy --features "log" -- ${{ env.CLIPPY_PARAMS }} + + # - name: Run clippy (examples) + # run: | + # for EXAMPLE in $(ls examples); + # do + # (cd examples/$EXAMPLE && cargo clippy -- ${{ env.CLIPPY_PARAMS }}) + # done integration-test: name: Integration Tests runs-on: ubuntu-latest - needs: ['test', 'rustfmt', 'clippy'] + needs: ["build", "test", "rustfmt", "clippy"] steps: - name: Checkout source code - uses: actions/checkout@v3 + uses: actions/checkout@v4 + - uses: dsherret/rust-toolchain-file@v1 + - name: Create OTA Job run: | ./scripts/create_ota.sh @@ -75,6 +110,7 @@ jobs: AWS_DEFAULT_REGION: ${{ secrets.MGMT_AWS_DEFAULT_REGION }} AWS_ACCESS_KEY_ID: ${{ secrets.MGMT_AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.MGMT_AWS_SECRET_ACCESS_KEY }} + - name: Integration Tests uses: actions-rs/cargo@v1 with: @@ -91,4 +127,69 @@ jobs: env: AWS_DEFAULT_REGION: ${{ secrets.MGMT_AWS_DEFAULT_REGION }} AWS_ACCESS_KEY_ID: ${{ secrets.MGMT_AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.MGMT_AWS_SECRET_ACCESS_KEY }} \ No newline at end of file + AWS_SECRET_ACCESS_KEY: ${{ secrets.MGMT_AWS_SECRET_ACCESS_KEY }} + + # device_advisor: + # name: AWS IoT Device Advisor + # runs-on: ubuntu-latest + # needs: test + # env: + # AWS_EC2_METADATA_DISABLED: true + # AWS_DEFAULT_REGION: ${{ secrets.MGMT_AWS_DEFAULT_REGION }} + # AWS_ACCESS_KEY_ID: ${{ secrets.MGMT_AWS_ACCESS_KEY_ID }} + # AWS_SECRET_ACCESS_KEY: ${{ secrets.MGMT_AWS_SECRET_ACCESS_KEY }} + # SUITE_ID: 1gaev57dq6i5 + # THING_ARN: arn:aws:iot:eu-west-1:411974994697:thing/embedded-mqtt + # steps: + # - name: Checkout source code + # uses: actions/checkout@v4 + + # - uses: dsherret/rust-toolchain-file@v1 + + # - name: Get AWS_HOSTNAME + # id: hostname + # run: | + # hostname=$(aws iotdeviceadvisor get-endpoint --thing-arn ${{ env.THING_ARN }} --output text --query endpoint) + # ret=$? + # echo "::set-output name=AWS_HOSTNAME::$hostname" + # exit $ret + + # - name: Build test binary + # env: + # AWS_HOSTNAME: ${{ steps.hostname.outputs.AWS_HOSTNAME }} + # run: cargo build --features=log --example aws_device_advisor --release + + # - name: Start test suite + # id: test_suite + # run: | + # suite_id=$(aws iotdeviceadvisor start-suite-run --suite-definition-id ${{ env.SUITE_ID }} --suite-run-configuration "primaryDevice={thingArn=${{ env.THING_ARN }}},parallelRun=true" --output text --query suiteRunId) + # ret=$? + # echo "::set-output name=SUITE_RUN_ID::$suite_id" + # exit $ret + + # - name: Execute test binary + # id: binary + # env: + # DEVICE_ADVISOR_PASSWORD: ${{ secrets.DEVICE_ADVISOR_PASSWORD }} + # RUST_LOG: trace + # run: | + # nohup ./target/release/examples/aws_device_advisor > device_advisor_integration.log & + # echo "::set-output name=PID::$!" + + # - name: Monitor test run + # run: | + # chmod +x ./scripts/da_monitor.sh + # echo ${{ env.SUITE_ID }} ${{ steps.test_suite.outputs.SUITE_RUN_ID }} ${{ steps.binary.outputs.PID }} + # ./scripts/da_monitor.sh ${{ env.SUITE_ID }} ${{ steps.test_suite.outputs.SUITE_RUN_ID }} ${{ steps.binary.outputs.PID }} + + # - name: Kill test binary process + # if: ${{ always() }} + # run: kill ${{ steps.binary.outputs.PID }} || true + + # - name: Log binary output + # if: ${{ always() }} + # run: cat device_advisor_integration.log + + # - name: Stop test suite + # if: ${{ failure() }} + # run: aws iotdeviceadvisor stop-suite-run --suite-definition-id ${{ env.SUITE_ID }} --suite-run-id ${{ steps.test_suite.outputs.SUITE_RUN_ID }} diff --git a/.vscode/settings.json b/.vscode/settings.json index 48fb5ea..4685b72 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,7 @@ { - "rust-analyzer.checkOnSave.allTargets": false, - "rust-analyzer.cargo.features": ["log"], - "rust-analyzer.cargo.target": "x86_64-unknown-linux-gnu" + "rust-analyzer.cargo.features": [ + "log", + "ota_mqtt_data" + ], + "rust-analyzer.cargo.target": "x86_64-unknown-linux-gnu", } \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 671dcd3..5cb6533 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,16 +1,16 @@ [workspace] -members = ["shadow_derive"] +members = ["rustot_derive"] [package] name = "rustot" -version = "0.4.1" -authors = ["Mathias Koch "] +version = "0.5.0" +authors = ["Factbird team "] description = "AWS IoT" readme = "README.md" keywords = ["iot", "no-std"] categories = ["embedded", "no-std"] license = "MIT OR Apache-2.0" -repository = "https://github.com/BlackbirdHQ/rustot" +repository = "https://github.com/FactbirdHQ/rustot" edition = "2021" documentation = "https://docs.rs/rustot" exclude = ["/documentation"] @@ -22,45 +22,77 @@ name = "rustot" maintenance = { status = "actively-developed" } [dependencies] -bitmaps = { version = "^3.1", default-features = false } -heapless = { version = "0.7.0", features = ["serde"] } -mqttrust = { version = "0.6" } -nb = "1" -serde = { version = "1.0.126", default-features = false, features = ["derive"] } -serde_cbor = { version = "^0.11", default-features = false, optional = true } -serde-json-core = { version = "0.4.0" } -smlang = "0.5.0" -fugit-timer = "0.1.2" -shadow-derive = { path = "shadow_derive", version = "0.2.1" } -embedded-storage = "0.3.0" - -log = { version = "^0.4", default-features = false, optional = true } -defmt = { version = "^0.3", optional = true } +bitmaps = { version = "3.1", default-features = false } +heapless = { version = "0.9", features = ["serde"] } +serde = { version = "1.0", default-features = false, features = ["derive"] } + +minicbor = { version = "0.25", optional = true } +minicbor-serde = { version = "0.3.2", optional = true } + + +serde-json-core = { version = "0.6" } +rustot-derive = { path = "rustot_derive", version = "0.2.1" } +embedded-storage-async = "0.4" +embedded-mqtt = { git = "ssh://git@github.com/FactbirdHQ/embedded-mqtt", rev = "f74610b" } + +futures = { version = "0.3.28", default-features = false } + +embassy-time = { version = "0.5.0" } +embassy-sync = "0.7.2" +embassy-futures = "0.1.2" + +log = { version = "0.4", default-features = false, optional = true } +defmt = { version = "0.3", optional = true } +bon = { version = "3.3.2", default-features = false } [dev-dependencies] -native-tls = { version = "^0.2" } -embedded-nal = "0.6.0" -no-std-net = { version = "^0.5", features = ["serde"] } -dns-lookup = "1.0.3" -mqttrust_core = { version = "0.6", features = ["log"] } -env_logger = "0.9.0" +native-tls = { version = "0.2" } +embedded-nal-async = "0.8" +env_logger = "0.11" sha2 = "0.10.1" -ecdsa = { version = "0.13.4", features = ["pkcs8"] } -p256 = "0.10.1" -pkcs8 = { version = "0.8", features = ["encryption", "pem"] } -timebomb = "0.1.2" -hex = "0.4.3" +static_cell = { version = "2", features = ["nightly"] } +log = { version = "0.4" } +serde_json = "1" + +tokio = { version = "1.33", default-features = false, features = [ + "macros", + "rt", + "net", + "time", + "io-std", +] } +tokio-native-tls = { version = "0.3.1" } +embassy-futures = { version = "0.1.2" } +embassy-time = { version = "0.5", features = ["log", "std", "generic-queue-8"] } +embedded-io-adapters = { version = "0.6.0", features = ["tokio-1"] } + +ecdsa = { version = "0.16", features = ["pkcs8", "pem"] } +p256 = "0.13" +pkcs8 = { version = "0.10", features = ["encryption", "pem"] } +hex = { version = "0.4.3", features = ["alloc"] } + [features] -default = ["ota_mqtt_data", "provision_cbor"] +default = ["ota_mqtt_data", "metric_cbor", "provision_cbor"] -provision_cbor = ["serde_cbor"] +metric_cbor = ["dep:minicbor", "dep:minicbor-serde"] + +provision_cbor = ["dep:minicbor", "dep:minicbor-serde"] + +ota_mqtt_data = ["dep:minicbor", "dep:minicbor-serde"] -ota_mqtt_data = ["serde_cbor"] ota_http_data = [] -std = ["serde/std", "serde_cbor?/std"] +std = ["serde/std", "minicbor-serde?/std"] + +defmt = [ + "dep:defmt", + "heapless/defmt", + "embedded-mqtt/defmt", + "embassy-time/defmt", +] +log = ["dep:log", "embedded-mqtt/log"] -defmt = ["dep:defmt", "mqttrust/defmt-impl", "heapless/defmt-impl"] -graphviz = ["smlang/graphviz"] +# [patch."ssh://git@github.com/FactbirdHQ/embedded-mqtt"] +# embedded-mqtt = { path = "../embedded-mqtt" } diff --git a/README.md b/README.md index c0d217a..ae1176c 100644 --- a/README.md +++ b/README.md @@ -1,41 +1,36 @@ -# Rust of things (rustot) +# Rust of Things (rustot) +> A `no_std`, `no_alloc` crate for interacting with AWS IoT services on embedded devices. -**Work in progress** +This crate aims to provide a pure-Rust implementation of essential AWS IoT features for embedded systems, inspired by the Amazon FreeRTOS AWS IoT Device SDK. -> no_std, no_alloc crate for AWS IoT Devices, implementing Jobs, OTA, Device Defender and IoT Shadows +## Features -This crates strives to implement the sum of: -- [AWS OTA](https://github.com/aws/ota-for-aws-iot-embedded-sdk) -- [AWS Device Defender](https://github.com/aws/Device-Defender-for-AWS-IoT-embedded-sdk) -- [AWS Jobs](https://github.com/aws/Jobs-for-AWS-IoT-embedded-sdk) -- [AWS Device Shadow](https://github.com/aws/Device-Shadow-for-AWS-IoT-embedded-sdk) -- [AWS IoT Fleet Provisioning](https://github.com/aws/Fleet-Provisioning-for-AWS-IoT-embedded-sdk) +- **OTA Updates:** ([`ota`] module) + - Download and apply firmware updates securely over MQTT or HTTP. + - Supports both CBOR and raw binary firmware formats. +- **Device Shadow:** ([`shadows`] module) + - Synchronize device state with the cloud using AWS IoT Device Shadow service. + - Get, update, and delete device shadows. +- **Jobs:** ([`jobs`] module) + - Receive and execute jobs remotely on your devices. + - Track job status and report progress to AWS IoT. +- **Device Defender:** ([`defender_metrics`] module) + - Implement security best practices and detect anomalies on your devices. +- **Fleet Provisioning:** ([`provisioning`] module) + - Securely provision and connect devices to AWS IoT at scale. +- **Lightweight and `no_std`:** Designed specifically for resource-constrained embedded devices. +## Contributing -![Test][test] -[![Code coverage][codecov-badge]][codecov] -![No Std][no-std-badge] -[![Crates.io Version][crates-io-badge]][crates-io] -[![Crates.io Downloads][crates-io-download-badge]][crates-io-download] - -Any contributions will be welcomed! Even if they are just suggestions, bugs or reviews! - -This is a port of the Amazon-FreeRTOS AWS IoT Device SDK (https://github.com/nguyenvuhung/amazon-freertos/tree/master/libraries/freertos_plus/aws/ota), written in pure Rust. - -It is written to work with [mqttrust](https://github.com/BlackbirdHQ/mqttrust), but should work with any other mqtt client, that implements the [Mqtt trait](https://github.com/BlackbirdHQ/mqttrust/blob/master/mqttrust/src/lib.rs) from mqttrust. - - -## Tests - -> The crate is covered by tests. These tests can be run by `cargo test --tests --all-features`, and are run by the CI on every push to master. +Contributions, suggestions, bug reports, and reviews are highly appreciated! ## License Licensed under either of - Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or - http://www.apache.org/licenses/LICENSE-2.0) + http://www.apache.org/licenses/LICENSE-2.0) - MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) at your option. @@ -45,13 +40,3 @@ at your option. Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. - - -[test]: https://github.com/BlackbirdHQ/rustot/workflows/Test/badge.svg -[no-std-badge]: https://img.shields.io/badge/no__std-yes-blue -[codecov-badge]: https://codecov.io/gh/BlackbirdHQ/rustot/branch/master/graph/badge.svg -[codecov]: https://codecov.io/gh/BlackbirdHQ/rustot -[crates-io]: https://crates.io/crates/rustot -[crates-io-badge]: https://img.shields.io/crates/v/rustot.svg?maxAge=3600 -[crates-io-download]: https://crates.io/crates/rustot -[crates-io-download-badge]: https://img.shields.io/crates/d/rustot.svg?maxAge=3600 diff --git a/documentation/ota_agent.svg b/documentation/ota_agent.svg deleted file mode 100644 index 58d04e8..0000000 --- a/documentation/ota_agent.svg +++ /dev/null @@ -1,396 +0,0 @@ - - - - - - -G - - - -s - -s - - - -Ready - -Ready - - - -s->Ready - - - - - -Ready->Ready - - -Shutdown - - - -Suspended - -Suspended - - - -Ready->Suspended - - -Suspend - - - -WaitingForJob - -WaitingForJob - - - -Ready->WaitingForJob - - -UserAbort - - - -RequestingJob - -RequestingJob - - - -Ready->RequestingJob - - -Start - - - -RequestingFileBlock - -RequestingFileBlock - - - -RequestingFileBlock->Ready - - -Shutdown - - - -WaitingForFileBlock - -WaitingForFileBlock - - - -RequestingFileBlock->WaitingForFileBlock - - -RequestTimer - - - -RequestingFileBlock->WaitingForFileBlock - - -RequestFileBlock - - - -RequestingFileBlock->Suspended - - -Suspend - - - -RequestingFileBlock->WaitingForJob - - -UserAbort - - - -WaitingForFileBlock->Ready - - -Shutdown - - - -WaitingForFileBlock->WaitingForFileBlock - - -ReceivedFileBlock - - - -WaitingForFileBlock->WaitingForFileBlock - - -RequestTimer - - - -WaitingForFileBlock->WaitingForFileBlock - - -RequestFileBlock - - - -WaitingForFileBlock->Suspended - - -Suspend - - - -WaitingForFileBlock->WaitingForJob - - -UserAbort - - - -WaitingForFileBlock->WaitingForJob - - -CloseFile - - - -WaitingForFileBlock->WaitingForJob - - -RequestJobDocument - - - -WaitingForFileBlock->RequestingJob - - -ReceivedJobDocument - - - -Suspended->RequestingJob - - -Resume - - - -WaitingForJob->Ready - - -Shutdown - - - -WaitingForJob->Suspended - - -Suspend - - - -WaitingForJob->WaitingForJob - - -UserAbort - - - -CreatingFile - -CreatingFile - - - -WaitingForJob->CreatingFile - - -ReceivedJobDocument - - - -CreatingFile->Ready - - -Shutdown - - - -CreatingFile->RequestingFileBlock - - -RequestTimer - - - -CreatingFile->RequestingFileBlock - - -CreateFile - - - -CreatingFile->Suspended - - -Suspend - - - -CreatingFile->WaitingForJob - - -UserAbort - - - -CreatingFile->WaitingForJob - - -StartSelfTest - - - -RequestingJob->Ready - - -Shutdown - - - -RequestingJob->Suspended - - -Suspend - - - -RequestingJob->WaitingForJob - - -UserAbort - - - -RequestingJob->WaitingForJob - - -RequestJobDocument - - - -RequestingJob->WaitingForJob - - -RequestTimer - - - -UserAbort - -UserAbort -[user_abort_handler] / _ - - - -RequestTimer - -RequestTimer -[request_data_handler] / _ - - - -Shutdown - -Shutdown -[shutdown_handler] / _ - - - -StartSelfTest - -StartSelfTest -[in_self_test_handler] / _ - - - -CreateFile - -CreateFile -[init_file_handler] / _ - - - -Suspend - -Suspend -[_] / _ - - - -RequestJobDocument - -RequestJobDocument -[request_job_handler] / _ - - - -RequestFileBlock - -RequestFileBlock -[request_data_handler] / _ - - - -CloseFile - -CloseFile -[close_file_handler] / _ - - - -ReceivedJobDocument - -ReceivedJobDocument -[process_job_handler] / _ - - - -ReceivedFileBlock - -ReceivedFileBlock -[process_data_handler] / _ - - - -Start - -Start -[start_handler] / _ - - - -Resume - -Resume -[resume_job_handler] / _ - - - diff --git a/documentation/provisioning.drawio b/documentation/provisioning.drawio index dbd6394..dea9dcc 100644 --- a/documentation/provisioning.drawio +++ b/documentation/provisioning.drawio @@ -1,6 +1,6 @@ - + @@ -96,8 +96,8 @@ - - + + @@ -117,8 +117,8 @@ - - + + @@ -135,8 +135,8 @@ - - + + diff --git a/documentation/shadows_test.svg b/documentation/shadows_test.svg deleted file mode 100644 index 668e518..0000000 --- a/documentation/shadows_test.svg +++ /dev/null @@ -1,185 +0,0 @@ - - - - - - -G - - - -s - -s - - - -Begin - -Begin - - - -s->Begin - - - - - -DeleteShadow - -DeleteShadow - - - -Begin->DeleteShadow - - -Delete - - - -UpdateFromDevice - -UpdateFromDevice - - - -Check - -Check - - - -UpdateFromDevice->Check - - -CheckState - - - -LoadShadow - -LoadShadow - - - -LoadShadow->UpdateFromDevice - - -UpdateStateFromDevice - - - -Done - -Done - - - -GetShadow - -GetShadow - - - -GetShadow->LoadShadow - - -Load - - - -Check->UpdateFromDevice - - -UpdateStateFromDevice - - - -Check->Done - - -Finish - - - -UpdateFromCloud - -UpdateFromCloud - - - -Check->UpdateFromCloud - - -UpdateStateFromCloud - - - -UpdateFromCloud->Check - - -CheckState - - - -DeleteShadow->GetShadow - - -Get - - - -CheckState - -CheckState -[_] / check - - - -UpdateStateFromDevice - -UpdateStateFromDevice -[_] / get_next_device - - - -Get - -Get -[_] / _ - - - -Finish - -Finish -[_] / _ - - - -UpdateStateFromCloud - -UpdateStateFromCloud -[_] / get_next_cloud - - - -Load - -Load -[_] / load_initial - - - -Delete - -Delete -[_] / _ - - - diff --git a/documentation/stack.drawio b/documentation/stack.drawio deleted file mode 100644 index cdf864a..0000000 --- a/documentation/stack.drawio +++ /dev/null @@ -1,31 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 3cd5460..2ca6b90 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,7 +1,8 @@ [toolchain] -channel = "nightly-2023-06-28" -components = [ "rust-src", "rustfmt", "llvm-tools-preview", "clippy" ] +channel = "nightly-2025-06-29" +components = ["rust-src", "rustfmt", "llvm-tools", "miri", "clippy"] targets = [ "x86_64-unknown-linux-gnu", - "thumbv7em-none-eabihf" + "thumbv6m-none-eabi", + "thumbv7em-none-eabihf", ] diff --git a/rustot_derive/Cargo.toml b/rustot_derive/Cargo.toml new file mode 100644 index 0000000..f59fcc5 --- /dev/null +++ b/rustot_derive/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "rustot-derive" +version = "0.2.1" +authors = ["Factbird team "] +description = "Procedual macros for rustot crate" +license = "MIT OR Apache-2.0" +repository = "https://github.com/FactbirdHQ/rustot" +edition = "2021" +readme = "../README.md" + +[lib] +proc-macro = true + +[dependencies] +syn = "2" +quote = "1" +proc-macro2 = "1" + +[dev-dependencies] +rustot = { path = ".." } +serde = "1" +heapless = "0.8" diff --git a/rustot_derive/src/lib.rs b/rustot_derive/src/lib.rs new file mode 100644 index 0000000..ec46eba --- /dev/null +++ b/rustot_derive/src/lib.rs @@ -0,0 +1,20 @@ +mod shadow; + +#[proc_macro_attribute] +pub fn shadow( + attr: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + proc_macro::TokenStream::from(shadow::shadow(attr.into(), input.into())) +} + +#[proc_macro_attribute] +pub fn shadow_patch( + attr: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + proc_macro::TokenStream::from(shadow::shadow_patch::shadow_patch( + attr.into(), + input.into(), + )) +} diff --git a/rustot_derive/src/shadow/generation/generator.rs b/rustot_derive/src/shadow/generation/generator.rs new file mode 100644 index 0000000..a1c16bd --- /dev/null +++ b/rustot_derive/src/shadow/generation/generator.rs @@ -0,0 +1,197 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::{punctuated::Punctuated, spanned::Spanned as _, Data, DeriveInput, Token}; + +use crate::shadow::{ + generation::variant_or_field_visitor::{borrow_fields, get_attr, has_shadow_arg, is_primitive}, + CFG_ATTRIBUTE, DEFAULT_ATTRIBUTE, +}; + +pub trait Generator { + fn generate(&mut self, original: &DeriveInput, output: &DeriveInput) -> TokenStream; +} + +pub struct NewGenerator; + +impl Generator for NewGenerator { + fn generate(&mut self, _original: &DeriveInput, output: &DeriveInput) -> TokenStream { + quote! { + #output + } + } +} + +pub struct GenerateFromImpl; + +impl GenerateFromImpl { + fn variables_actions<'a>( + fields: impl Iterator, + ) -> ( + Punctuated, + Punctuated, + ) { + fields.enumerate().fold( + (Punctuated::new(), Punctuated::new()), + |(mut variables, mut actions), (i, field)| { + let var_ident = field.ident.clone().unwrap_or_else(|| { + syn::Ident::new(&format!("{}", (b'a' + i as u8) as char), field.span()) + }); + + let action = if is_primitive(&field.ty) || has_shadow_arg(&field.attrs, "leaf") { + quote! {Some(#var_ident)} + } else { + quote! {Some(#var_ident.into())} + }; + + actions.push( + field + .ident + .as_ref() + .map(|ident| quote! {#ident: #action}) + .unwrap_or(action), + ); + + variables.push(var_ident); + + (variables, actions) + }, + ) + } +} + +impl Generator for GenerateFromImpl { + fn generate(&mut self, original: &DeriveInput, output: &DeriveInput) -> TokenStream { + let (impl_generics, ty_generics, where_clause) = original.generics.split_for_impl(); + let orig_name = &original.ident; + let new_name = &output.ident; + + let from_impl = match (&original.data, &output.data) { + (Data::Struct(data_struct_old), Data::Struct(data_struct_new)) => { + let original_fields = borrow_fields(data_struct_old); + let new_fields = borrow_fields(data_struct_new); + + let from_fields = original_fields.iter().fold(quote! {}, |acc, field| { + let is_leaf = is_primitive(&field.ty) || has_shadow_arg(&field.attrs, "leaf"); + + let has_new_field = new_fields + .iter() + .find(|&f| f.ident == field.ident) + .is_some(); + + let cfg_attr = get_attr(&field.attrs, CFG_ATTRIBUTE); + + let ident = &field.ident; + if !has_new_field { + quote! { #acc #cfg_attr #ident: None, } + } else if is_leaf { + quote! { #acc #cfg_attr #ident: Some(v.#ident), } + } else { + quote! { #acc #cfg_attr #ident: Some(v.#ident.into()), } + } + }); + + quote! { + Self { + #from_fields + } + } + } + (Data::Enum(data_struct_old), Data::Enum(_)) => { + let match_arms = data_struct_old + .variants + .iter() + .fold(Punctuated::::new(), |mut acc, variant| { + let variant_ident = &variant.ident; + let cfg_attr = get_attr(&variant.attrs, CFG_ATTRIBUTE); + + acc.push(match &variant.fields { + syn::Fields::Named(fields_named) => { + let (variables, actions) = Self::variables_actions(fields_named.named.iter()); + quote! {#cfg_attr #orig_name::#variant_ident { #variables } => Self::#variant_ident { #actions }} + } + syn::Fields::Unnamed(fields_unnamed) => { + let (variables, actions) = Self::variables_actions(fields_unnamed.unnamed.iter()); + quote! {#cfg_attr #orig_name::#variant_ident ( #variables ) => Self::#variant_ident ( #actions )} + } + syn::Fields::Unit => { + quote! {#cfg_attr #orig_name::#variant_ident => Self::#variant_ident} + } + }); + + acc + }); + + quote! { + match v { + #match_arms + } + } + } + _ => panic!(), + }; + + quote! { + impl #impl_generics From<#orig_name #ty_generics> for #new_name #ty_generics #where_clause { + fn from(v: #orig_name #ty_generics) -> Self { + #from_impl + } + } + } + } +} + +pub struct DefaultGenerator(pub bool); + +impl Generator for DefaultGenerator { + fn generate(&mut self, _original: &DeriveInput, output: &DeriveInput) -> TokenStream { + if self.0 { + return quote! {}; + } + + if let Data::Enum(enum_data) = &output.data { + let default_variant = enum_data + .variants + .iter() + .find(|variant| get_attr(&variant.attrs, DEFAULT_ATTRIBUTE).is_some()) + .expect(""); + + let default_ident = &default_variant.ident; + + let default_fields = match &default_variant.fields { + syn::Fields::Named(fields_named) => { + let assigners = fields_named.named.iter().fold( + Punctuated::::new(), + |mut acc, field| { + let ident = &field.ident; + acc.push(quote! { #ident: Default::default() }); + acc + }, + ); + quote! { { #assigners } } + } + syn::Fields::Unnamed(fields_unnamed) => { + let assigners = fields_unnamed.unnamed.iter().fold( + Punctuated::::new(), + |mut acc, _field| { + acc.push(quote! { Default::default() }); + acc + }, + ); + quote! { ( #assigners ) } + } + syn::Fields::Unit => quote! {}, + }; + + let ident = &output.ident; + + return quote! { + impl Default for #ident { + fn default() -> Self { + Self::#default_ident #default_fields + } + } + }; + } + quote! {} + } +} diff --git a/rustot_derive/src/shadow/generation/mod.rs b/rustot_derive/src/shadow/generation/mod.rs new file mode 100644 index 0000000..24030a2 --- /dev/null +++ b/rustot_derive/src/shadow/generation/mod.rs @@ -0,0 +1,315 @@ +pub mod generator; +pub mod modifier; +pub mod variant_or_field_visitor; + +use generator::Generator; +use modifier::Modifier; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{punctuated::Punctuated, Data, DeriveInput, Path, Token}; +use variant_or_field_visitor::{ + borrow_fields_mut, get_attr, has_shadow_arg, is_primitive, VariantOrFieldVisitor, +}; + +use super::CFG_ATTRIBUTE; + +pub struct ShadowGenerator { + input: DeriveInput, + output: DeriveInput, + generated: TokenStream, +} + +impl ShadowGenerator { + pub fn new(input: DeriveInput) -> Self { + let output = input.clone(); + Self { + input, + output, + generated: quote! {}, + } + } + + pub fn variant_or_field_visitor(mut self, visitor: &mut impl VariantOrFieldVisitor) -> Self { + match (&mut self.input.data, &mut self.output.data) { + (Data::Struct(data_struct_old), Data::Struct(data_struct_new)) => { + let old_fields = borrow_fields_mut(data_struct_old); + let new_fields = borrow_fields_mut(data_struct_new); + + for (old_field, new_field) in old_fields.iter_mut().zip(new_fields.iter_mut()) { + visitor.visit_field(old_field, new_field); + } + } + (Data::Enum(data_enum_old), Data::Enum(data_enum_new)) => { + for (old_variant, new_variant) in data_enum_old + .variants + .iter_mut() + .zip(data_enum_new.variants.iter_mut()) + { + visitor.visit_variant(old_variant, new_variant); + } + } + _ => {} + } + + self + } + + pub fn modifier(mut self, modifier: &mut impl Modifier) -> Self { + modifier.modify(&self.input, &mut self.output); + + self + } + + pub fn generator(mut self, generator: &mut impl Generator) -> Self { + let gen = generator.generate(&self.input, &self.output); + + let generated = self.generated; + self.generated = quote! { + #generated + + #gen + }; + + self + } + + pub fn finalize(self) -> TokenStream { + self.generated + } +} + +pub struct GenerateShadowPatchImplVisitor { + field_index: usize, + reported_ident: Path, + apply_patch_impl: TokenStream, +} + +impl GenerateShadowPatchImplVisitor { + pub fn new(reported_ident: Path) -> Self { + Self { + field_index: 0, + reported_ident, + apply_patch_impl: quote! {}, + } + } +} + +impl Generator for GenerateShadowPatchImplVisitor { + fn generate(&mut self, original: &DeriveInput, output: &DeriveInput) -> TokenStream { + let (impl_generics, ty_generics, where_clause) = original.generics.split_for_impl(); + let orig_name = &original.ident; + let delta_name = &output.ident; + let reported_name = &self.reported_ident; + + let apply_patch_impl = match original.data { + Data::Enum(_) => { + let arms = &self.apply_patch_impl; + quote! { + match (self, delta) { + #arms + } + } + } + _ => { + let implementation = &self.apply_patch_impl; + quote! { #implementation } + } + }; + + quote! { + impl #impl_generics rustot::shadows::ShadowPatch for #orig_name #ty_generics #where_clause { + type Delta = #delta_name #ty_generics; + type Reported = #reported_name #ty_generics; + + fn apply_patch(&mut self, delta: Self::Delta) { + #apply_patch_impl + } + } + } + } +} + +#[derive(Default)] +struct PatchImpl { + variables: Punctuated, + delta_variables: Punctuated, + delta_deconstructors: Punctuated, + defaults: TokenStream, + assigns: TokenStream, +} + +impl PatchImpl { + fn new<'a>(fields: impl Iterator) -> PatchImpl { + fields + .enumerate() + .fold(PatchImpl::default(), |mut patch_impl, (i, field)| { + let cfg_attr = get_attr(&field.attrs, CFG_ATTRIBUTE); + + let var_ident = field + .ident + .clone() + .unwrap_or_else(|| format_ident!("{}", (b'a' + i as u8) as char)); + + let delta_ident = format_ident!("d_{}", var_ident); + + let cfg_var = quote! { #cfg_attr #var_ident}; + let cfg_delta = quote! { #cfg_attr #delta_ident}; + + let assigns = patch_impl.assigns; + let defaults = patch_impl.defaults; + let field_ty = &field.ty; + + let (action, action_deref) = + if has_shadow_arg(&field.attrs, "leaf") || is_primitive(&field.ty) { + ( + quote! { + #cfg_attr + let #var_ident = #delta_ident.unwrap_or_default(); + }, + quote! { + #cfg_attr + if let Some(delta_var) = #delta_ident { + *#var_ident = delta_var; + } + }, + ) + } else { + ( + quote! { + #cfg_attr + let mut #var_ident = #field_ty ::default(); + + #cfg_attr + if let Some(delta_var) = #delta_ident { + #var_ident.apply_patch(delta_var); + } + }, + quote! { + #cfg_attr + if let Some(delta_var) = #delta_ident { + #var_ident.apply_patch(delta_var); + } + }, + ) + }; + + patch_impl.defaults = quote! { + #defaults + + #action + }; + + patch_impl.assigns = quote! { + #assigns + + #action_deref + }; + + patch_impl.variables.push(cfg_var); + patch_impl.delta_variables.push(cfg_delta); + patch_impl + .delta_deconstructors + .push(quote! { #var_ident: #delta_ident}); + + patch_impl + }) + } +} + +impl VariantOrFieldVisitor for GenerateShadowPatchImplVisitor { + fn visit_field(&mut self, old: &mut syn::Field, _new: &mut syn::Field) { + let field_index = self.field_index; + self.field_index += 1; + + let field_ident = old + .ident + .as_ref() + .map(|i| quote! { #i }) + .unwrap_or_else(|| { + let i = syn::Index::from(field_index); + quote! { #i } + }); + + let cfg_attr = get_attr(&old.attrs, CFG_ATTRIBUTE); + + let acc = &self.apply_patch_impl; + + self.apply_patch_impl = if has_shadow_arg(&old.attrs, "leaf") || is_primitive(&old.ty) { + quote! { + #acc + + #cfg_attr if let Some(inner) = delta.#field_ident { self.#field_ident = inner; } + } + } else { + quote! { + #acc + + #cfg_attr if let Some(inner) = delta.#field_ident { self.#field_ident.apply_patch(inner); } + } + }; + } + + fn visit_variant(&mut self, old: &mut syn::Variant, _new: &mut syn::Variant) { + let variant_ident = &old.ident; + let variant_cfg = get_attr(&old.attrs, CFG_ATTRIBUTE); + + let acc = &self.apply_patch_impl; + self.apply_patch_impl = match &old.fields { + syn::Fields::Named(fields_named) => { + let PatchImpl { + variables, + delta_deconstructors, + defaults, + assigns, + .. + } = PatchImpl::new(fields_named.named.iter()); + + quote! { + #acc + + #variant_cfg + (Self::#variant_ident { #variables }, Self::Delta::#variant_ident { #delta_deconstructors }) => { + #assigns + } + #variant_cfg + (this, Self::Delta::#variant_ident { #delta_deconstructors }) => { + #defaults + + *this = Self::#variant_ident { #variables }; + } + } + } + syn::Fields::Unnamed(fields_unnamed) => { + let PatchImpl { + variables, + delta_variables, + defaults, + assigns, + .. + } = PatchImpl::new(fields_unnamed.unnamed.iter()); + + quote! { + #acc + + #variant_cfg + (Self::#variant_ident ( #variables ), Self::Delta::#variant_ident ( #delta_variables )) => { + #assigns + } + #variant_cfg + (this, Self::Delta::#variant_ident( #delta_variables)) => { + #defaults + + *this = Self::#variant_ident ( #variables ); + } + } + } + syn::Fields::Unit => { + quote! { + #acc + + #variant_cfg (this, Self::Delta::#variant_ident) => *this = Self::#variant_ident, + } + } + }; + } +} diff --git a/rustot_derive/src/shadow/generation/modifier.rs b/rustot_derive/src/shadow/generation/modifier.rs new file mode 100644 index 0000000..f3bff89 --- /dev/null +++ b/rustot_derive/src/shadow/generation/modifier.rs @@ -0,0 +1,104 @@ +use quote::format_ident; +use syn::{parse_quote, punctuated::Punctuated, Data, DeriveInput, Field, Ident, Path, Token}; + +use super::variant_or_field_visitor::{ + borrow_fields, borrow_fields_mut, extract_type_from_option, has_shadow_arg, +}; + +pub trait Modifier { + fn modify(&mut self, original: &DeriveInput, output: &mut DeriveInput); +} + +/// Rename the `output` struct to `self.0`, directly in the AST. +pub struct RenameModifier(pub String); + +impl Modifier for RenameModifier { + fn modify(&mut self, _original: &DeriveInput, output: &mut DeriveInput) { + output.ident = Ident::new(&self.0, output.ident.span()); + } +} + +pub struct WithDerivesModifier(pub bool, pub Vec<&'static str>); + +impl Modifier for WithDerivesModifier { + fn modify(&mut self, _original: &DeriveInput, output: &mut DeriveInput) { + if !self.0 { + return; + } + + let mut all_derives = self + .1 + .iter() + .filter(|s| { + if matches!(output.data, Data::Enum(_)) { + return **s != "Default"; + } + true + }) + .map(|s| Path::from(format_ident!("{}", s))) + .collect::>(); + + output.attrs.retain(|attr| { + if attr.path().is_ident("derive") { + let derived = attr + .parse_args_with(Punctuated::::parse_terminated) + .unwrap_or_default(); + for derived_trait in derived.iter() { + if !all_derives.iter().any(|p| p == derived_trait) { + if !(matches!(output.data, Data::Enum(_)) + && derived_trait.is_ident(&format_ident!("Default"))) + { + all_derives.push(derived_trait.clone()); + } + } + } + return false; + } + true + }); + + // Make sure we get derives first + output.attrs.reverse(); + output.attrs.push(parse_quote! { #[derive(#all_derives)] }); + output.attrs.reverse(); + } +} + +/// Filter all fields annotated with `#[shadow_attr(report_only)]` from the +/// `output` AST. +pub struct ReportOnlyModifier; + +impl ReportOnlyModifier { + fn filter_report_only(field: &Field) -> bool { + if has_shadow_arg(&field.attrs, "report_only") { + return false; + } + + if extract_type_from_option(&field.ty).is_some() { + panic!("Optionals are only allowed in `report_only` fields"); + } + + true + } +} + +impl Modifier for ReportOnlyModifier { + fn modify(&mut self, original: &DeriveInput, output: &mut DeriveInput) { + // This modifier only modifies structs + match (&original.data, &mut output.data) { + (Data::Struct(data_struct_old), Data::Struct(data_struct_new)) => { + let old_fields = borrow_fields(data_struct_old); + let new_fields = borrow_fields_mut(data_struct_new); + + *new_fields = old_fields + .iter() + .zip(new_fields.iter_mut()) + .filter_map(|(old_field, new_field)| { + Self::filter_report_only(old_field).then_some(new_field.clone()) + }) + .collect::>(); + } + _ => {} + } + } +} diff --git a/rustot_derive/src/shadow/generation/variant_or_field_visitor.rs b/rustot_derive/src/shadow/generation/variant_or_field_visitor.rs new file mode 100644 index 0000000..c8e281b --- /dev/null +++ b/rustot_derive/src/shadow/generation/variant_or_field_visitor.rs @@ -0,0 +1,213 @@ +use quote::quote; +use syn::{ + parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute, DataStruct, Field, Fields, + Ident, Token, Type, Visibility, +}; + +use crate::shadow::{DEFAULT_ATTRIBUTE, SHADOW_ATTRIBUTE}; + +pub trait VariantOrFieldVisitor { + fn visit_field(&mut self, old: &mut syn::Field, new: &mut syn::Field); + + // Default to running the field visitor for each field in the variant + fn visit_variant(&mut self, old: &mut syn::Variant, new: &mut syn::Variant) { + for (old_field, new_field) in old.fields.iter_mut().zip(new.fields.iter_mut()) { + self.visit_field(old_field, new_field); + } + } +} + +pub struct SetNewVisibilityVisitor(pub bool); + +impl VariantOrFieldVisitor for SetNewVisibilityVisitor { + fn visit_field(&mut self, _old: &mut syn::Field, new: &mut syn::Field) { + if self.0 { + new.vis = Visibility::Public(syn::token::Pub(new.vis.span())); + } + } + + // Do nothing for enums, as field & variant visibility always follows the visibility of the enum itself + fn visit_variant(&mut self, _old: &mut syn::Variant, _new: &mut syn::Variant) {} +} + +pub struct SetNewTypeVisitor(pub Ident); + +impl VariantOrFieldVisitor for SetNewTypeVisitor { + fn visit_field(&mut self, old: &mut syn::Field, new: &mut syn::Field) { + let newtype_ident = &self.0; + + let (is_primitive, inner_type, is_base_opt) = match extract_type_from_option(&old.ty) { + Some(inner_ty) if is_primitive(&inner_ty) || has_shadow_arg(&old.attrs, "leaf") => { + (true, &old.ty, true) + } + Some(inner_ty) => (false, inner_ty, true), + None => (is_primitive(&old.ty), &old.ty, false), + }; + + let new_type = if is_primitive || has_shadow_arg(&old.attrs, "leaf") { + quote! {Option<#inner_type>} + } else if is_base_opt { + quote! {Option::#newtype_ident>>} + } else { + quote! {Option<<#inner_type as rustot::shadows::ShadowPatch>::#newtype_ident>} + }; + + new.ty = syn::parse2(new_type).unwrap(); + } +} + +pub struct AddSerdeSkipAttribute; + +impl VariantOrFieldVisitor for AddSerdeSkipAttribute { + fn visit_field(&mut self, _old: &mut syn::Field, new: &mut syn::Field) { + let attribute: Attribute = + parse_quote! { #[serde(skip_serializing_if = "Option::is_none")] }; + new.attrs.push(attribute); + } +} + +pub struct RemoveShadowAttributesVisitor; + +impl VariantOrFieldVisitor for RemoveShadowAttributesVisitor { + fn visit_field(&mut self, old: &mut syn::Field, new: &mut syn::Field) { + let indexes_to_remove = old + .attrs + .iter() + .enumerate() + .filter_map(|(i, a)| { + if a.path().is_ident(SHADOW_ATTRIBUTE) { + Some(i) + } else { + None + } + }) + .collect::>(); + + // Don't forget to reverse so the indices are removed without being shifted! + for i in indexes_to_remove.into_iter().rev() { + new.attrs.swap_remove(i); + } + } + + fn visit_variant(&mut self, old: &mut syn::Variant, new: &mut syn::Variant) { + let indexes_to_remove = old + .attrs + .iter() + .enumerate() + .filter_map(|(i, a)| { + if a.path().is_ident(DEFAULT_ATTRIBUTE) || a.path().is_ident(SHADOW_ATTRIBUTE) { + Some(i) + } else { + None + } + }) + .collect::>(); + + // Don't forget to reverse so the indices are removed without being shifted! + for i in indexes_to_remove.into_iter().rev() { + new.attrs.swap_remove(i); + } + + for (old_field, new_field) in old.fields.iter_mut().zip(new.fields.iter_mut()) { + self.visit_field(old_field, new_field); + } + } +} + +/// Rust primitive types: https://doc.rust-lang.org/reference/types.html +pub fn is_primitive(t: &Type) -> bool { + match &t { + Type::Path(type_path) => type_path + .path + .segments + .last() + .map(|ps| { + [ + "bool", "u8", "u16", "u32", "u64", "u128", "usize", "i8", "i16", "i32", "i64", + "i128", "isize", "f32", "f64", "char", + ] + .iter() + .any(|s| ps.ident == *s) + }) + .unwrap_or(false), + Type::Paren(type_paren) => is_primitive(&type_paren.elem), + Type::Reference(_) | Type::Array(_) | Type::Tuple(_) => false, + t => panic!("Unsupported type: {:?}", quote! { #t }), + } +} + +pub fn borrow_fields(data_struct: &DataStruct) -> &Punctuated { + match &data_struct.fields { + Fields::Unnamed(f) => &f.unnamed, + Fields::Named(f) => &f.named, + Fields::Unit => unreachable!("Unit structs are not supported"), + } +} + +pub fn borrow_fields_mut(data_struct: &mut DataStruct) -> &mut Punctuated { + match &mut data_struct.fields { + Fields::Unnamed(f) => &mut f.unnamed, + Fields::Named(f) => &mut f.named, + Fields::Unit => unreachable!("Unit structs are not supported"), + } +} + +pub fn get_attr(attrs: &Vec, attr: &str) -> Option { + attrs.iter().find(|a| a.path().is_ident(attr)).cloned() +} + +pub fn has_shadow_arg(attrs: &Vec, arg: &str) -> bool { + if let Some(a) = get_attr(&attrs, SHADOW_ATTRIBUTE) { + let shadow_args = a + .parse_args_with(Punctuated::::parse_terminated) + .unwrap_or_default(); + + return shadow_args.iter().any(|i| i.to_string() == arg); + } + + false +} + +pub fn extract_type_from_option(ty: &syn::Type) -> Option<&syn::Type> { + use syn::{GenericArgument, Path, PathArguments, PathSegment}; + + fn extract_type_path(ty: &syn::Type) -> Option<&Path> { + match *ty { + syn::Type::Path(ref typepath) if typepath.qself.is_none() => Some(&typepath.path), + _ => None, + } + } + + // TODO store (with lazy static) the vec of string + // TODO maybe optimization, reverse the order of segments + fn extract_option_segment(path: &Path) -> Option<&PathSegment> { + let idents_of_path = path + .segments + .iter() + .into_iter() + .fold(String::new(), |mut acc, v| { + acc.push_str(&v.ident.to_string()); + acc.push('|'); + acc + }); + vec!["Option|", "std|option|Option|", "core|option|Option|"] + .into_iter() + .find(|s| &idents_of_path == *s) + .and_then(|_| path.segments.last()) + } + + extract_type_path(ty) + .and_then(|path| extract_option_segment(path)) + .and_then(|path_seg| { + let type_params = &path_seg.arguments; + // It should have only on angle-bracketed param (""): + match *type_params { + PathArguments::AngleBracketed(ref params) => params.args.first(), + _ => None, + } + }) + .and_then(|generic_arg| match *generic_arg { + GenericArgument::Type(ref ty) => Some(ty), + _ => None, + }) +} diff --git a/rustot_derive/src/shadow/mod.rs b/rustot_derive/src/shadow/mod.rs new file mode 100644 index 0000000..48a87c4 --- /dev/null +++ b/rustot_derive/src/shadow/mod.rs @@ -0,0 +1,99 @@ +mod generation; +pub mod shadow_patch; + +use proc_macro2::TokenStream; +use quote::quote; +use syn::parse::{Parse, ParseStream}; +use syn::{DeriveInput, Expr, ExprLit, Lit, LitInt, LitStr}; + +pub const SHADOW_ATTRIBUTE: &str = "shadow_attr"; +// pub const SHADOW_ALL: &str = "shadow_attr"; +// pub const SHADOW_DELTA: &str = "shadow_delta"; +// pub const SHADOW_REPORTED: &str = "shadow_reported"; + +pub const DEFAULT_ATTRIBUTE: &str = "default"; +pub const CFG_ATTRIBUTE: &str = "cfg"; + +#[derive(Default)] +struct MacroParameters { + shadow_name: Option, + topic_prefix: Option, + max_payload_size: Option, +} + +impl Parse for MacroParameters { + fn parse(input: ParseStream) -> syn::Result { + let mut out = MacroParameters::default(); + + while let Ok(optional) = input.parse::() { + match (optional.path.get_ident(), optional.value) { + ( + Some(ident), + Expr::Lit(ExprLit { + lit: Lit::Str(v), .. + }), + ) if ident == "name" => { + out.shadow_name = Some(v); + } + ( + Some(ident), + Expr::Lit(ExprLit { + lit: Lit::Str(v), .. + }), + ) if ident == "topic_prefix" => { + out.topic_prefix = Some(v); + } + ( + Some(ident), + Expr::Lit(ExprLit { + lit: Lit::Int(v), .. + }), + ) if ident == "max_payload_size" => { + out.max_payload_size = Some(v); + } + _ => {} + } + + if input.parse::().is_err() { + break; + } + } + + Ok(out) + } +} + +pub fn shadow(attr: TokenStream, input: TokenStream) -> TokenStream { + let shadow_patch = shadow_patch::shadow_patch(attr.clone(), input.clone()); + + let derive_input = syn::parse2::(input).unwrap(); + let macro_params = syn::parse2::(attr).unwrap(); + + let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl(); + let ident = &derive_input.ident; + + let name = match macro_params.shadow_name { + Some(name) => quote! { Some(#name) }, + None => quote! { None }, + }; + + let topic_prefix = match macro_params.topic_prefix { + Some(prefix) => quote! { #prefix }, + None => quote! { "$aws" }, + }; + + let max_payload_size = macro_params + .max_payload_size + .map_or(quote! { 512 }, |m| quote! { #m }); + + quote! { + #[automatically_derived] + impl #impl_generics rustot::shadows::ShadowState for #ident #ty_generics #where_clause { + const NAME: Option<&'static str> = #name; + const PREFIX: &'static str = #topic_prefix; + const MAX_PAYLOAD_SIZE: usize = #max_payload_size; + } + + #shadow_patch + } +} diff --git a/rustot_derive/src/shadow/shadow_patch.rs b/rustot_derive/src/shadow/shadow_patch.rs new file mode 100644 index 0000000..1718afd --- /dev/null +++ b/rustot_derive/src/shadow/shadow_patch.rs @@ -0,0 +1,142 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{ + parse::{Parse, ParseStream}, + DeriveInput, Path, +}; + +use crate::shadow::generation::{ + generator::{DefaultGenerator, GenerateFromImpl, NewGenerator}, + modifier::{RenameModifier, ReportOnlyModifier, WithDerivesModifier}, + variant_or_field_visitor::{ + AddSerdeSkipAttribute, RemoveShadowAttributesVisitor, SetNewTypeVisitor, + SetNewVisibilityVisitor, + }, + GenerateShadowPatchImplVisitor, ShadowGenerator, +}; + +#[derive(Default)] +struct MacroParameters { + auto_derive: Option, + no_default: Option, +} + +impl Parse for MacroParameters { + fn parse(input: ParseStream) -> syn::Result { + let mut out = MacroParameters::default(); + + while let Ok(optional) = input.parse::() { + match (optional.path.get_ident(), optional.value) { + ( + Some(ident), + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Bool(v), + .. + }), + ) if ident == "auto_derive" => { + out.auto_derive = Some(v.value); + } + ( + Some(ident), + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Bool(v), + .. + }), + ) if ident == "no_default" => { + out.no_default = Some(v.value); + } + _ => {} + } + + if input.parse::().is_err() { + break; + } + } + + Ok(out) + } +} + +pub fn shadow_patch(attr: TokenStream, input: TokenStream) -> TokenStream { + let full_input = syn::parse2::(input).unwrap(); + let macro_params = syn::parse2::(attr).unwrap(); + + let original_ident = full_input.ident.to_string(); + + let delta = format_ident!("Delta"); + let reported = format_ident!("Reported"); + let reported_ident = format_ident!("Reported{}", &original_ident); + + let mut auto_derives = vec!["Serialize", "Deserialize", "Clone"]; + if !macro_params.no_default.unwrap_or_default() { + auto_derives.push("Default"); + } + + let desired_tokens = { + ShadowGenerator::new(full_input.clone()) + .modifier(&mut ReportOnlyModifier) + .modifier(&mut RenameModifier(original_ident.clone())) + .modifier(&mut WithDerivesModifier( + macro_params.auto_derive.unwrap_or(true), + auto_derives.clone(), + )) + .generator(&mut DefaultGenerator( + macro_params.no_default.unwrap_or_default(), + )) + .variant_or_field_visitor(&mut RemoveShadowAttributesVisitor) + .generator(&mut NewGenerator) + .finalize() + }; + + let delta_tokens = { + let mut shadowpatch_impl_generator = + GenerateShadowPatchImplVisitor::new(Path::from(reported_ident.clone())); + + ShadowGenerator::new(full_input.clone()) + .modifier(&mut ReportOnlyModifier) + .variant_or_field_visitor(&mut SetNewVisibilityVisitor(true)) + .variant_or_field_visitor(&mut SetNewTypeVisitor(delta.clone())) + .variant_or_field_visitor(&mut shadowpatch_impl_generator) + .modifier(&mut RenameModifier(format!( + "{}{}", + &delta, &original_ident + ))) + .modifier(&mut WithDerivesModifier( + macro_params.auto_derive.unwrap_or(true), + auto_derives, + )) + .generator(&mut DefaultGenerator( + macro_params.no_default.unwrap_or_default(), + )) + .generator(&mut shadowpatch_impl_generator) + .variant_or_field_visitor(&mut RemoveShadowAttributesVisitor) + .generator(&mut NewGenerator) + .finalize() + }; + + let reported_tokens = { + ShadowGenerator::new(full_input) + .variant_or_field_visitor(&mut SetNewVisibilityVisitor(true)) + .variant_or_field_visitor(&mut SetNewTypeVisitor(reported)) + .variant_or_field_visitor(&mut AddSerdeSkipAttribute) + .modifier(&mut RenameModifier(reported_ident.to_string())) + .modifier(&mut WithDerivesModifier( + macro_params.auto_derive.unwrap_or(true), + vec!["Serialize", "Default"], + )) + .generator(&mut DefaultGenerator(false)) + .variant_or_field_visitor(&mut RemoveShadowAttributesVisitor) + .generator(&mut NewGenerator) + .modifier(&mut ReportOnlyModifier) + .generator(&mut GenerateFromImpl) + .finalize() + }; + + quote! { + #desired_tokens + + #delta_tokens + + #reported_tokens + } +} diff --git a/rustot_derive/tests/shadow.rs b/rustot_derive/tests/shadow.rs new file mode 100644 index 0000000..4eb7d8a --- /dev/null +++ b/rustot_derive/tests/shadow.rs @@ -0,0 +1,381 @@ +use rustot::shadows::{ShadowPatch, ShadowState}; +use rustot_derive::{shadow, shadow_patch}; +use serde::{Deserialize, Serialize}; + +#[test] +fn nested() { + #[shadow(name = "test", max_payload_size = 256)] + #[derive(Debug, PartialEq)] + // #[shadow_all(derive(Clone, Debug, PartialEq))] + // #[shadow_all(serde(rename_all = "lowercase"))] + // #[shadow_delta(derive(Deserialize))] + // #[shadow_reported(derive(Serialize, Default))] + struct Foo { + pub bar: u8, + + #[shadow_attr(leaf)] + #[serde(rename = "desired_rename")] + pub baz: String, + + pub inner: Inner, + + #[shadow_attr(report_only)] + #[serde(rename = "report_only_rename")] + pub report_only: u8, + } + + #[shadow_patch] + #[derive(Debug, PartialEq)] + struct Inner { + hello: u16, + + #[shadow_attr(report_only, leaf)] + inner_report: String, + } + + let mut foo = Foo { + bar: 56, + baz: "HelloWorld".to_string(), + inner: Inner { hello: 1337 }, + }; + + ReportedFoo { + bar: Some(56), + baz: Some("HelloWorld".to_string()), + inner: Some(ReportedInner { + hello: Some(1337), + inner_report: None, + }), + report_only: None, + }; + + let delta = DeltaFoo { + bar: Some(66), + baz: None, + inner: Some(DeltaInner { hello: None }), + }; + + assert_eq!(Foo::NAME, Some("test")); + assert_eq!(Foo::MAX_PAYLOAD_SIZE, 256); + + assert_eq!( + ReportedFoo::from(foo.clone()), + ReportedFoo { + bar: Some(56), + baz: Some("HelloWorld".to_string()), + inner: Some(ReportedInner { + hello: Some(1337), + inner_report: None, + }), + report_only: None, + } + ); + + foo.apply_patch(delta); + + assert_eq!( + foo, + Foo { + bar: 66, + baz: "HelloWorld".to_string(), + inner: Inner { hello: 1337 } + } + ); +} + +#[test] +fn optionals() { + #[shadow] + #[derive(Debug, PartialEq)] + struct Foo { + pub bar: u8, + + #[shadow_attr(report_only)] + pub report_only: Option, + + #[shadow_attr(report_only)] + pub report_only_nested: Option, + } + + #[shadow_patch] + #[derive(Debug, PartialEq)] + struct Inner { + hello: u16, + + #[shadow_attr(report_only, leaf)] + inner_report: String, + } + + assert_eq!(Foo::NAME, None); + assert_eq!(Foo::MAX_PAYLOAD_SIZE, 512); + + let mut desired = Foo { bar: 123 }; + + desired.apply_patch(DeltaFoo { bar: Some(78) }); + + assert_eq!(desired, Foo { bar: 78 }); + + let _reported = ReportedFoo { + bar: Some(56), + report_only: Some(Some(14)), + report_only_nested: Some(Some(ReportedInner { + hello: Some(1337), + inner_report: None, + })), + }; +} + +#[test] +fn simple_enum() { + #[shadow] + #[derive(Debug, PartialEq)] + struct Foo { + #[shadow_attr(leaf)] + pub bar: Either, + } + + #[derive(Debug, Default, PartialEq, Serialize, Deserialize, Clone)] + enum Either { + #[default] + A, + B, + } + + let mut desired = Foo { bar: Either::A }; + + let reported = ReportedFoo { + bar: Some(Either::B), + }; + + desired.apply_patch(DeltaFoo { + bar: Some(Either::B), + }); + + assert_eq!(ReportedFoo::from(desired), reported); +} + +#[test] +fn complex_enum() { + #[shadow(topic_prefix = "test")] + #[derive(Debug, PartialEq)] + struct Foo { + pub bar: Either, + } + + #[shadow_patch] + #[derive(Debug, Default, PartialEq)] + pub enum Either { + #[default] + A(InnerA), + B(u32), + C, + D(InnerA, InnerB), + E { + field1: InnerA, + field2: InnerB, + }, + } + + #[shadow_patch] + #[derive(Debug, PartialEq)] + struct InnerA { + hello: u16, + } + + #[shadow_patch] + #[derive(Debug, PartialEq)] + struct InnerB { + baz: i32, + } + + assert_eq!(Foo::PREFIX, "test"); + + let mut desired = Foo { + bar: Either::A(InnerA { hello: 1337 }), + }; + + let reported = ReportedFoo { + bar: Some(ReportedEither::D( + Some(ReportedInnerA { hello: Some(56) }), + Some(ReportedInnerB { baz: Some(0) }), + )), + }; + + desired.apply_patch(DeltaFoo { + bar: Some(DeltaEither::D(Some(DeltaInnerA { hello: Some(56) }), None)), + }); + + assert_eq!( + desired, + Foo { + bar: Either::D(InnerA { hello: 56 }, InnerB::default()) + } + ); + assert_eq!(ReportedFoo::from(desired), reported); +} + +#[test] +fn static_str() { + #[shadow] + #[derive(Debug, PartialEq)] + struct Foo { + // fails: &'static str, + #[shadow_attr(report_only, leaf)] + pub bar: &'static str, + + #[shadow_attr(report_only, leaf)] + pub baz: Option<&'static str>, + } + + let _foo = Foo {}; + + let _reported = ReportedFoo { + bar: Some("Hello"), + baz: Some(Some("HelloBaz")), + }; +} + +#[test] +fn manual_reported() { + #[shadow(name = "manual", reported = ManualReportedFoo)] + #[derive(Debug, PartialEq)] + struct Foo { + pub bar: u8, + + #[shadow_attr(leaf)] + #[serde(rename = "desired_rename")] + pub baz: String, + + pub inner: Inner, + + #[shadow_attr(report_only)] + #[serde(rename = "report_only_rename")] + pub report_only: u8, + } + #[derive(Serialize, Default, Debug, PartialEq)] + struct ManualReportedFoo { + #[serde(skip_serializing_if = "Option::is_none")] + pub bar: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "desired_rename")] + pub baz: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub inner: Option<::Reported>, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "report_only_rename")] + pub report_only: Option, + } + + impl From for ManualReportedFoo { + fn from(v: Foo) -> Self { + Self { + bar: Some(v.bar), + baz: Some(v.baz), + inner: Some(v.inner.into()), + ..Default::default() + } + } + } + + #[shadow_patch] + #[derive(Debug, PartialEq)] + struct Inner { + hello: u16, + + #[shadow_attr(report_only, leaf)] + inner_report: String, + } + + let mut foo = Foo { + bar: 56, + baz: "HelloWorld".to_string(), + inner: Inner { hello: 1337 }, + }; + + ManualReportedFoo { + bar: Some(56), + baz: Some("HelloWorld".to_string()), + inner: Some(ReportedInner { + hello: Some(1337), + inner_report: None, + }), + report_only: None, + }; + + let delta = DeltaFoo { + bar: Some(66), + baz: None, + inner: Some(DeltaInner { hello: None }), + }; + + assert_eq!(Foo::NAME, Some("manual")); + + assert_eq!( + ManualReportedFoo::from(foo.clone()), + ManualReportedFoo { + bar: Some(56), + baz: Some("HelloWorld".to_string()), + inner: Some(ReportedInner { + hello: Some(1337), + inner_report: None, + }), + report_only: None, + } + ); + + foo.apply_patch(delta); + + assert_eq!( + foo, + Foo { + bar: 66, + baz: "HelloWorld".to_string(), + inner: Inner { hello: 1337 } + } + ); +} + +#[test] +fn enum_leaf() { + use heapless::String; + + #[shadow_patch] + #[derive(Debug, Clone, Default, PartialEq, Eq)] + pub enum LeafField { + #[default] + None, + + Inner(#[shadow_attr(leaf)] String<64>), + } + + // #[shadow_patch] + // #[derive(Debug, Clone, Default, PartialEq, Eq)] + // pub enum LeafVariant { + // #[default] + // None, + + // #[shadow_attr(leaf)] + // Inner(String<64>), + // } +} + +// #[test] +// fn generics() { +// use heapless::String; + +// #[shadow_patch] +// #[derive(Debug, Clone)] +// pub struct Foo { +// #[shadow_attr(leaf)] +// pub ssid: String<64>, + +// pub generic: Inner, +// } + +// #[shadow_patch] +// #[derive(Debug, Clone)] +// pub struct Inner { +// #[shadow_attr(leaf)] +// a: A, +// } +// } diff --git a/scripts/register.sh b/scripts/register.sh index 95cc849..c0e0662 100755 --- a/scripts/register.sh +++ b/scripts/register.sh @@ -1,14 +1,14 @@ #!/usr/bin/env bash -# Registers the device in Blackbird's DynamoDB containing whitelisted devices to +# Registers the device in Factbird's DynamoDB containing whitelisted devices to # be provisioned. -# +# # This script will populate `tests/secrets` with `claim_certificate.pem.crt` & # `claim_private.pem.key`, as well as combine them into `claim_identity.pfx`, -# which is password protected with `env:DEVICE_ADVISOR_PASSWORD` +# which is password protected with `env:IDENTITY_PASSWORD` -if [[ -z "${DEVICE_ADVISOR_PASSWORD}" ]]; then - echo "DEVICE_ADVISOR_PASSWORD environment variable is required!" +if [[ -z "${IDENTITY_PASSWORD}" ]]; then + echo "IDENTITY_PASSWORD environment variable is required!" exit 1 fi @@ -27,6 +27,6 @@ jq -r '.certificatePem' response.json > $SECRETS_DIR/claim_certificate.pem.crt jq -r '.privateKey' response.json > $SECRETS_DIR/claim_private.pem.key rm response.json -openssl pkcs12 -export -out $SECRETS_DIR/claim_identity.pfx -inkey $SECRETS_DIR/claim_private.pem.key -in $SECRETS_DIR/claim_certificate.pem.crt -certfile $SECRETS_DIR/root-ca.pem -passout pass:$DEVICE_ADVISOR_PASSWORD +openssl pkcs12 -export -out $SECRETS_DIR/claim_identity.pfx -inkey $SECRETS_DIR/claim_private.pem.key -in $SECRETS_DIR/claim_certificate.pem.crt -certfile $SECRETS_DIR/root-ca.pem -passout pass:$IDENTITY_PASSWORD rm $SECRETS_DIR/claim_certificate.pem.crt -rm $SECRETS_DIR/claim_private.pem.key \ No newline at end of file +rm $SECRETS_DIR/claim_private.pem.key diff --git a/shadow_derive/Cargo.toml b/shadow_derive/Cargo.toml deleted file mode 100644 index 3dccf1a..0000000 --- a/shadow_derive/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "shadow-derive" -version = "0.2.1" -authors = ["Mathias Koch "] -description = "Procedual macros for rustot crate shadow state" -license = "MIT OR Apache-2.0" -repository = "https://github.com/BlackbirdHQ/rustot" -edition = "2021" -readme = "../README.md" - -[lib] -proc-macro = true - -[dependencies] -syn = "^1" -quote = "^1" -proc-macro2 = "^1" diff --git a/shadow_derive/src/lib.rs b/shadow_derive/src/lib.rs deleted file mode 100644 index 7838e87..0000000 --- a/shadow_derive/src/lib.rs +++ /dev/null @@ -1,251 +0,0 @@ -extern crate proc_macro; -extern crate syn; -#[macro_use] -extern crate quote; - -use proc_macro::TokenStream; -use proc_macro2::Span; -use syn::parse::{Parse, ParseStream, Parser}; -use syn::parse_macro_input; -use syn::DeriveInput; -use syn::Generics; -use syn::Ident; -use syn::Result; -use syn::{parenthesized, Attribute, Error, Field, LitStr}; - -#[proc_macro_derive(ShadowState, attributes(shadow, static_shadow_field))] -pub fn shadow_state(input: TokenStream) -> TokenStream { - match parse_macro_input!(input as ParseInput) { - ParseInput::Struct(input) => { - let shadow_patch = generate_shadow_patch_struct(&input); - let shadow_state = generate_shadow_state(&input); - let implementation = quote! { - #shadow_patch - - #shadow_state - }; - TokenStream::from(implementation) - } - _ => { - todo!() - } - } -} - -#[proc_macro_derive(ShadowPatch, attributes(static_shadow_field, serde))] -pub fn shadow_patch(input: TokenStream) -> TokenStream { - TokenStream::from(match parse_macro_input!(input as ParseInput) { - ParseInput::Struct(input) => generate_shadow_patch_struct(&input), - ParseInput::Enum(input) => generate_shadow_patch_enum(&input), - }) -} - -enum ParseInput { - Struct(StructParseInput), - Enum(EnumParseInput), -} - -#[derive(Clone)] -struct EnumParseInput { - pub ident: Ident, - pub generics: Generics, -} - -#[derive(Clone)] -struct StructParseInput { - pub ident: Ident, - pub generics: Generics, - pub shadow_fields: Vec, - pub copy_attrs: Vec, - pub shadow_name: Option, -} - -impl Parse for ParseInput { - fn parse(input: ParseStream) -> Result { - let derive_input = DeriveInput::parse(input)?; - - let mut shadow_name = None; - let mut copy_attrs = vec![]; - - let attrs_to_copy = ["serde"]; - - // Parse valid container attributes - for attr in derive_input.attrs { - if attr.path.is_ident("shadow") { - fn shadow_arg(input: ParseStream) -> Result { - let content; - parenthesized!(content in input); - content.parse() - } - shadow_name = Some(shadow_arg.parse2(attr.tokens)?); - } else if attrs_to_copy - .iter() - .find(|a| attr.path.is_ident(a)) - .is_some() - { - copy_attrs.push(attr); - } - } - - match derive_input.data { - syn::Data::Struct(syn::DataStruct { fields, .. }) => { - Ok(Self::Struct(StructParseInput { - ident: derive_input.ident, - generics: derive_input.generics, - shadow_fields: fields.into_iter().collect::>(), - copy_attrs, - shadow_name, - })) - } - syn::Data::Enum(syn::DataEnum { .. }) => Ok(Self::Enum(EnumParseInput { - ident: derive_input.ident, - generics: derive_input.generics, - })), - _ => Err(Error::new( - Span::call_site(), - "ShadowState & ShadowPatch can only be derived for non-tuple structs & enums", - )), - } - } -} - -fn create_assigners(fields: &Vec) -> Vec { - fields - .iter() - .filter_map(|field| { - let field_name = &field.ident.clone().unwrap(); - - if field - .attrs - .iter() - .find(|a| a.path.is_ident("static_shadow_field")) - .is_some() - { - None - } else { - Some(quote! { - if let Some(attribute) = opt.#field_name { - self.#field_name.apply_patch(attribute); - } - }) - } - }) - .collect::>() -} - -fn create_optional_fields(fields: &Vec) -> Vec { - fields - .iter() - .filter_map(|field| { - let type_name = &field.ty; - let attrs = field - .attrs - .iter() - .filter(|a| { - !a.path.is_ident("static_shadow_field") - }) - .collect::>(); - let field_name = &field.ident.clone().unwrap(); - - let type_name_string = quote! {#type_name}.to_string(); - let type_name_string: String = type_name_string.chars().filter(|&c| c != ' ').collect(); - - if field - .attrs - .iter() - .find(|a| a.path.is_ident("static_shadow_field")) - .is_some() - { - None - } else { - Some(if type_name_string.starts_with("Option<") { - quote! { #(#attrs)* pub #field_name: Option::PatchState>> } - } else { - quote! { #(#attrs)* pub #field_name: Option<<#type_name as rustot::shadows::ShadowPatch>::PatchState> } - }) - } - }) - .collect::>() -} - -fn generate_shadow_state(input: &StructParseInput) -> proc_macro2::TokenStream { - let StructParseInput { - ident, - generics, - shadow_name, - .. - } = input; - - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let name = match shadow_name { - Some(name) => quote! { Some(#name) }, - None => quote! { None }, - }; - - return quote! { - #[automatically_derived] - impl #impl_generics rustot::shadows::ShadowState for #ident #ty_generics #where_clause { - const NAME: Option<&'static str> = #name; - // const MAX_PAYLOAD_SIZE: usize = 512; - } - }; -} - -fn generate_shadow_patch_struct(input: &StructParseInput) -> proc_macro2::TokenStream { - let StructParseInput { - ident, - generics, - shadow_fields, - copy_attrs, - .. - } = input; - - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let optional_ident = format_ident!("Patch{}", ident); - - let assigners = create_assigners(&shadow_fields); - let optional_fields = create_optional_fields(&shadow_fields); - - return quote! { - #[automatically_derived] - #[derive(Default, Clone, ::serde::Deserialize, ::serde::Serialize)] - #(#copy_attrs)* - pub struct #optional_ident #generics { - #( - #optional_fields - ),* - } - - #[automatically_derived] - impl #impl_generics rustot::shadows::ShadowPatch for #ident #ty_generics #where_clause { - type PatchState = #optional_ident; - - fn apply_patch(&mut self, opt: Self::PatchState) { - #( - #assigners - )* - } - } - }; -} - -fn generate_shadow_patch_enum(input: &EnumParseInput) -> proc_macro2::TokenStream { - let EnumParseInput { - ident, generics, .. - } = input; - - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - return quote! { - #[automatically_derived] - impl #impl_generics rustot::shadows::ShadowPatch for #ident #ty_generics #where_clause { - type PatchState = #ident #ty_generics; - - fn apply_patch(&mut self, opt: Self::PatchState) { - *self = opt; - } - } - }; -} diff --git a/src/defender_metrics/aws_types.rs b/src/defender_metrics/aws_types.rs new file mode 100644 index 0000000..484a5a0 --- /dev/null +++ b/src/defender_metrics/aws_types.rs @@ -0,0 +1,81 @@ +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct TcpConnections<'a> { + #[serde(rename = "ec")] + pub established_connections: Option<&'a EstablishedConnections<'a>>, +} + +#[derive(Debug, Serialize)] +pub struct EstablishedConnections<'a> { + #[serde(rename = "cs")] + pub connections: Option<&'a [&'a Connection<'a>]>, + + #[serde(rename = "t")] + pub total: Option, +} + +#[derive(Debug, Serialize)] +pub struct Connection<'a> { + #[serde(rename = "rad")] + pub remote_addr: &'a str, + + /// Port number, must be >= 0 + #[serde(rename = "lp")] + pub local_port: Option, + + /// Interface name + #[serde(rename = "li")] + pub local_interface: Option<&'a str>, +} + +#[derive(Debug, Serialize)] +pub struct ListeningTcpPorts<'a> { + #[serde(rename = "pts")] + pub ports: Option<&'a [&'a TcpPort<'a>]>, + + #[serde(rename = "t")] + pub total: Option, +} + +#[derive(Debug, Serialize)] +pub struct TcpPort<'a> { + #[serde(rename = "pt")] + pub port: u16, + + #[serde(rename = "if")] + pub interface: Option<&'a str>, +} + +#[derive(Debug, Serialize)] +pub struct ListeningUdpPorts<'a> { + #[serde(rename = "pts")] + pub ports: Option<&'a [&'a UdpPort<'a>]>, + + #[serde(rename = "t")] + pub total: Option, +} + +#[derive(Debug, Serialize)] +pub struct UdpPort<'a> { + #[serde(rename = "pt")] + pub port: u16, + + #[serde(rename = "if")] + pub interface: Option<&'a str>, +} + +#[derive(Debug, Serialize)] +pub struct NetworkStats { + #[serde(rename = "bi")] + pub bytes_in: Option, + + #[serde(rename = "bo")] + pub bytes_out: Option, + + #[serde(rename = "pi")] + pub packets_in: Option, + + #[serde(rename = "po")] + pub packets_out: Option, +} diff --git a/src/defender_metrics/data_types.rs b/src/defender_metrics/data_types.rs new file mode 100644 index 0000000..4b9eb36 --- /dev/null +++ b/src/defender_metrics/data_types.rs @@ -0,0 +1,84 @@ +use core::fmt::{Display, Write}; + +use bon::Builder; +use embassy_time::Instant; +use serde::{Deserialize, Serialize}; + +use super::aws_types::{ListeningTcpPorts, ListeningUdpPorts, NetworkStats, TcpConnections}; + +#[derive(Debug, Serialize, Builder)] +pub struct Metric<'a, C: Serialize> { + #[serde(rename = "hed")] + pub header: Header, + + #[serde(rename = "met")] + pub metrics: Option>, + + #[serde(rename = "cmet")] + pub custom_metrics: Option, +} + +#[derive(Debug, Serialize)] +pub struct Metrics<'a> { + listening_tcp_ports: Option>, + listening_udp_ports: Option>, + network_stats: Option, + tcp_connections: Option>, +} + +#[derive(Debug, Serialize)] +pub struct Header { + /// Monotonically increasing value. Epoch timestamp recommended. + #[serde(rename = "rid")] + pub report_id: i64, + + /// Version in Major.Minor format. + #[serde(rename = "v")] + pub version: Version, +} + +impl Default for Header { + fn default() -> Self { + Self { + report_id: Instant::now().as_millis() as i64, + version: Default::default(), + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum CustomMetric<'a> { + Number(i64), + NumberList(&'a [u64]), + StringList(&'a [&'a str]), + IpList(&'a [&'a str]), +} + +/// Format is `Version(Major, Minor)` +#[derive(Debug, PartialEq, Deserialize)] +pub struct Version(pub u8, pub u8); + +impl Serialize for Version { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut st: heapless::String<7> = heapless::String::new(); + st.write_fmt(format_args!("{}.{}", self.0, self.1)).ok(); + + serializer.serialize_str(&st) + } +} + +impl Display for Version { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}.{}", self.0, self.1,) + } +} + +impl Default for Version { + fn default() -> Self { + Self(1, 0) + } +} diff --git a/src/defender_metrics/errors.rs b/src/defender_metrics/errors.rs new file mode 100644 index 0000000..3bd8d5b --- /dev/null +++ b/src/defender_metrics/errors.rs @@ -0,0 +1,31 @@ +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct ErrorResponse<'a> { + #[serde(rename = "thingName")] + pub thing_name: &'a str, + pub status: &'a str, + #[serde(rename = "statusDetails")] + pub status_details: StatusDetails<'a>, + pub timestamp: i64, +} +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct StatusDetails<'a> { + #[serde(rename = "ErrorCode")] + pub error_code: MetricError, + #[serde(rename = "ErrorMessage")] + pub error_message: Option<&'a str>, +} +#[derive(Debug, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum MetricError { + Malformed, + InvalidPayload, + Throttled, + MissingHeader, + ErrorResponseDeserialize, + PublishSubscribe, + Other, +} diff --git a/src/defender_metrics/mod.rs b/src/defender_metrics/mod.rs new file mode 100644 index 0000000..9e8f09b --- /dev/null +++ b/src/defender_metrics/mod.rs @@ -0,0 +1,458 @@ +use crate::shadows::Error; +use data_types::Metric; +use embassy_sync::blocking_mutex::raw::RawMutex; +use embedded_mqtt::{DeferredPayload, Publish, Subscribe, SubscribeTopic, ToPayload}; +use errors::{ErrorResponse, MetricError}; +use serde::{Deserialize, Serialize}; +use topics::Topic; + +// pub mod aws_types; +pub mod aws_types; +pub mod data_types; +pub mod errors; +pub mod topics; + +pub struct MetricHandler<'a, 'm, M: RawMutex> { + mqtt: &'m embedded_mqtt::MqttClient<'a, M>, +} + +impl<'a, 'm, M: RawMutex> MetricHandler<'a, 'm, M> { + pub fn new(mqtt: &'m embedded_mqtt::MqttClient<'a, M>) -> Self { + Self { mqtt } + } + + pub async fn publish_metric<'c, C: Serialize>( + &self, + metric: Metric<'c, C>, + max_payload_size: usize, + ) -> Result<(), MetricError> { + //Wait for mqtt to connect + self.mqtt.wait_connected().await; + + let payload = DeferredPayload::new( + |buf: &mut [u8]| { + #[cfg(feature = "metric_cbor")] + { + let mut serializer = minicbor_serde::Serializer::new( + minicbor::encode::write::Cursor::new(&mut *buf), + ); + + match metric.serialize(&mut serializer) { + Ok(_) => {} + Err(_) => { + error!("An error happened when serializing metric with cbor"); + return Err(embedded_mqtt::EncodingError::BufferSize); + } + }; + + Ok(serializer.into_encoder().writer().position()) + } + + #[cfg(not(feature = "metric_cbor"))] + { + serde_json_core::to_slice(&metric, buf) + .map_err(|_| embedded_mqtt::EncodingError::BufferSize) + } + }, + max_payload_size, + ); + + let mut subscription = self + .publish_and_subscribe(payload) + .await + .map_err(|_| MetricError::PublishSubscribe)?; + + loop { + let message = subscription + .next_message() + .await + .ok_or(MetricError::Malformed)?; + + match Topic::from_str(message.topic_name()) { + Some(Topic::Accepted) => return Ok(()), + Some(Topic::Rejected) => { + #[cfg(not(feature = "metric_cbor"))] + { + let error_response = + serde_json_core::from_slice::(message.payload()) + .map_err(|_| MetricError::ErrorResponseDeserialize)?; + + return Err(error_response.0.status_details.error_code); + } + + #[cfg(feature = "metric_cbor")] + { + let mut de = minicbor_serde::Deserializer::new(message.payload()); + let error_response = ErrorResponse::deserialize(&mut de) + .map_err(|_| MetricError::ErrorResponseDeserialize)?; + + return Err(error_response.status_details.error_code); + } + } + + _ => (), + }; + } + } + async fn publish_and_subscribe( + &self, + payload: impl ToPayload, + ) -> Result, Error> { + let sub = self + .mqtt + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + Topic::Accepted + .format::<64>(self.mqtt.client_id())? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + Topic::Rejected + .format::<64>(self.mqtt.client_id())? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + //*** PUBLISH REQUEST ***/ + let topic_name = Topic::Publish.format::<64>(self.mqtt.client_id())?; + + match self + .mqtt + .publish( + Publish::builder() + .topic_name(topic_name.as_str()) + .payload(payload) + .build(), + ) + .await + .map_err(Error::MqttError) + { + Ok(_) => {} + Err(_) => { + error!("ERROR PUBLISHING PAYLOAD"); + return Err(Error::MqttError(embedded_mqtt::Error::BadTopicFilter)); + } + }; + + Ok(sub) + } +} + +#[cfg(test)] +mod tests { + use core::str::FromStr; + + use super::data_types::*; + + use heapless::{LinearMap, String}; + use serde::{ser::SerializeStruct, Serialize}; + + #[test] + fn serialize_version_json() { + let test_cases = [ + (Version(2, 0), "\"2.0\""), + (Version(0, 0), "\"0.0\""), + (Version(0, 1), "\"0.1\""), + (Version(255, 200), "\"255.200\""), + ]; + + for (version, expected) in test_cases.iter() { + let string: String<100> = serde_json_core::to_string(version).unwrap(); + assert_eq!( + string, *expected, + "Serialization failed for Version({}, {}): expected {}, got {}", + version.0, version.1, expected, string + ); + } + } + #[test] + fn serialize_version_cbor() { + let test_cases: [(Version, [u8; 8]); 4] = [ + (Version(2, 0), [99, 50, 46, 48, 0, 0, 0, 0]), + (Version(0, 0), [99, 48, 46, 48, 0, 0, 0, 0]), + (Version(0, 1), [99, 48, 46, 49, 0, 0, 0, 0]), + (Version(255, 200), [103, 50, 53, 53, 46, 50, 48, 48]), + ]; + + for (version, expected) in test_cases.iter() { + let mut buf = [0u8; 200]; + + let mut serializer = + minicbor_serde::Serializer::new(minicbor::encode::write::Cursor::new(&mut buf[..])); + + version.serialize(&mut serializer).unwrap(); + + let len = serializer.into_encoder().writer().position(); + + assert_eq!( + &buf[..len], + &expected[..len], + "Serialization failed for Version({}, {}): expected {:?}, got {:?}", + version.0, + version.1, + expected, + &buf[..len], + ); + } + } + + #[test] + fn custom_serialization_cbor() { + #[derive(Debug)] + struct WifiMetric { + signal_strength: u8, + } + + impl Serialize for WifiMetric { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut outer = serializer.serialize_struct("WifiMetricWrapper", 1)?; + + // Define the type we want to wrap our signal_strength field in + #[derive(Serialize)] + struct Number { + number: u8, + } + + let number = Number { + number: self.signal_strength, + }; + + // Serialize number and wrap in array + outer.serialize_field("MyMetricOfType_Number", &[number])?; + outer.end() + } + } + + let custom_metrics: WifiMetric = WifiMetric { + signal_strength: 23, + }; + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let mut buf = [255u8; 1000]; + + let mut serializer = + minicbor_serde::Serializer::new(minicbor::encode::write::Cursor::new(&mut buf[..])); + + metric.serialize(&mut serializer).unwrap(); + + let len = serializer.into_encoder().writer().position(); + + assert_eq!( + &buf[..len], + [ + 163, 99, 104, 101, 100, 162, 99, 114, 105, 100, 0, 97, 118, 99, 49, 46, 48, 99, + 109, 101, 116, 246, 100, 99, 109, 101, 116, 161, 117, 77, 121, 77, 101, 116, 114, + 105, 99, 79, 102, 84, 121, 112, 101, 95, 78, 117, 109, 98, 101, 114, 129, 161, 102, + 110, 117, 109, 98, 101, 114, 23 + ] + ) + } + + #[test] + fn custom_serialization() { + #[derive(Debug)] + struct WifiMetric { + signal_strength: u8, + } + + impl Serialize for WifiMetric { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut outer = serializer.serialize_struct("WifiMetricWrapper", 1)?; + + // Define the type we want to wrap our signal_strength field in + #[derive(Serialize)] + struct Number { + number: u8, + } + + let number = Number { + number: self.signal_strength, + }; + + // Serialize number and wrap in array + outer.serialize_field("MyMetricOfType_Number", &[number])?; + outer.end() + } + } + + let custom_metrics: WifiMetric = WifiMetric { + signal_strength: 23, + }; + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"MyMetricOfType_Number\":[{\"number\":23}]}}", payload.as_str()) + } + #[test] + fn custom_serialization_string_list() { + #[derive(Debug)] + struct CellType { + cell_type: String, + } + + impl Serialize for CellType { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut outer = serializer.serialize_struct("CellType", 1)?; + + // Define the type we want to wrap our signal_strength field in + #[derive(Serialize)] + struct StringList<'a> { + string_list: &'a [&'a str], + } + + let list = StringList { + string_list: &[self.cell_type.as_str()], + }; + + // Serialize number and wrap in array + outer.serialize_field("cell_type", &[list])?; + outer.end() + } + } + + let custom_metrics: CellType<4> = CellType { + cell_type: String::from_str("gsm").unwrap(), + }; + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"cell_type\":[{\"string_list\":[\"gsm\"]}]}}", payload.as_str()) + } + #[test] + fn number() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 16> = LinearMap::new(); + + let name_of_metric = String::from_str("myMetric").unwrap(); + + custom_metrics + .insert(name_of_metric, [CustomMetric::Number(23)]) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"myMetric\":[{\"number\":23}]}}", payload.as_str()) + } + + #[test] + fn number_list() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 16> = LinearMap::new(); + + // NUMBER LIST + let my_number_list = String::from_str("my_number_list").unwrap(); + + custom_metrics + .insert(my_number_list, [CustomMetric::NumberList(&[123, 456, 789])]) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"my_number_list\":[{\"number_list\":[123,456,789]}]}}", payload.as_str()) + } + + #[test] + fn string_list() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 16> = LinearMap::new(); + + // STRING LIST + let my_string_list = String::from_str("my_string_list").unwrap(); + + custom_metrics + .insert( + my_string_list, + [CustomMetric::StringList(&["value_1", "value_2"])], + ) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"my_string_list\":[{\"string_list\":[\"value_1\",\"value_2\"]}]}}", payload.as_str()) + } + + #[test] + fn all_types() { + let mut custom_metrics: LinearMap, [CustomMetric; 1], 4> = LinearMap::new(); + + let my_number = String::from_str("MyMetricOfType_Number").unwrap(); + custom_metrics + .insert(my_number, [CustomMetric::Number(1)]) + .unwrap(); + + let my_number_list = String::from_str("MyMetricOfType_NumberList").unwrap(); + custom_metrics + .insert(my_number_list, [CustomMetric::NumberList(&[1, 2, 3])]) + .unwrap(); + + let my_string_list = String::from_str("MyMetricOfType_StringList").unwrap(); + custom_metrics + .insert( + my_string_list, + [CustomMetric::StringList(&["value_1", "value_2"])], + ) + .unwrap(); + + let my_ip_list = String::from_str("MyMetricOfType_IpList").unwrap(); + custom_metrics + .insert( + my_ip_list, + [CustomMetric::IpList(&["172.0.0.0", "172.0.0.10"])], + ) + .unwrap(); + + let metric = Metric::builder() + .header(Default::default()) + .custom_metrics(custom_metrics) + .build(); + + let payload: String<4000> = serde_json_core::to_string(&metric).unwrap(); + + assert_eq!("{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"MyMetricOfType_Number\":[{\"number\":1}],\"MyMetricOfType_NumberList\":[{\"number_list\":[1,2,3]}],\"MyMetricOfType_StringList\":[{\"string_list\":[\"value_1\",\"value_2\"]}],\"MyMetricOfType_IpList\":[{\"ip_list\":[\"172.0.0.0\",\"172.0.0.10\"]}]}}", payload.as_str()) + } +} diff --git a/src/defender_metrics/topics.rs b/src/defender_metrics/topics.rs new file mode 100644 index 0000000..eba97f7 --- /dev/null +++ b/src/defender_metrics/topics.rs @@ -0,0 +1,78 @@ +#![allow(dead_code)] +use core::fmt::Write; + +use heapless::String; + +use crate::shadows::Error; + +pub enum PayloadFormat { + #[cfg(feature = "metric_cbor")] + Cbor, + #[cfg(not(feature = "metric_cbor"))] + Json, +} + +#[derive(PartialEq, Eq, Clone, Copy)] +pub enum Topic { + Accepted, + Rejected, + Publish, +} + +impl Topic { + const PREFIX: &'static str = "$aws/things"; + const NAME: &'static str = "defender/metrics"; + + #[cfg(feature = "metric_cbor")] + const PAYLOAD_FORMAT: &'static str = "cbor"; + + #[cfg(not(feature = "metric_cbor"))] + const PAYLOAD_FORMAT: &'static str = "json"; + + pub fn format(&self, thing_name: &str) -> Result, Error> { + let mut topic_path = String::new(); + + match self { + Self::Accepted => topic_path.write_fmt(format_args!( + "{}/{}/{}/{}/accepted", + Self::PREFIX, + thing_name, + Self::NAME, + Self::PAYLOAD_FORMAT, + )), + Self::Rejected => topic_path.write_fmt(format_args!( + "{}/{}/{}/{}/rejected", + Self::PREFIX, + thing_name, + Self::NAME, + Self::PAYLOAD_FORMAT, + )), + Self::Publish => topic_path.write_fmt(format_args!( + "{}/{}/{}/{}", + Self::PREFIX, + thing_name, + Self::NAME, + Self::PAYLOAD_FORMAT, + )), + } + .map_err(|_| Error::Overflow)?; + + Ok(topic_path) + } + + pub fn from_str(s: &str) -> Option { + let tt = s.splitn(7, '/').collect::>(); + match (tt.first(), tt.get(1), tt.get(3), tt.get(4)) { + (Some(&"$aws"), Some(&"things"), Some(&"defender"), Some(&"metrics")) => { + // This is a defender metric topic, now figure out which one. + + match tt.get(6) { + Some(&"accepted") => Some(Topic::Accepted), + Some(&"rejected") => Some(Topic::Rejected), + _ => None, + } + } + _ => None, + } + } +} diff --git a/src/fmt.rs b/src/fmt.rs index c06793e..45bfe1f 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -1,31 +1,12 @@ -// MIT License - -// Copyright (c) 2020 Dario Nieuwenhuis - -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: - -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. - -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - #![macro_use] -#![allow(unused_macros)] +#![allow(unused)] + +use core::fmt::{Debug, Display, LowerHex}; #[cfg(all(feature = "defmt", feature = "log"))] compile_error!("You may not enable both `defmt` and `log` features."); +#[collapse_debuginfo(yes)] macro_rules! assert { ($($x:tt)*) => { { @@ -37,6 +18,7 @@ macro_rules! assert { }; } +#[collapse_debuginfo(yes)] macro_rules! assert_eq { ($($x:tt)*) => { { @@ -48,6 +30,7 @@ macro_rules! assert_eq { }; } +#[collapse_debuginfo(yes)] macro_rules! assert_ne { ($($x:tt)*) => { { @@ -59,6 +42,7 @@ macro_rules! assert_ne { }; } +#[collapse_debuginfo(yes)] macro_rules! debug_assert { ($($x:tt)*) => { { @@ -70,6 +54,7 @@ macro_rules! debug_assert { }; } +#[collapse_debuginfo(yes)] macro_rules! debug_assert_eq { ($($x:tt)*) => { { @@ -81,6 +66,7 @@ macro_rules! debug_assert_eq { }; } +#[collapse_debuginfo(yes)] macro_rules! debug_assert_ne { ($($x:tt)*) => { { @@ -92,6 +78,7 @@ macro_rules! debug_assert_ne { }; } +#[collapse_debuginfo(yes)] macro_rules! todo { ($($x:tt)*) => { { @@ -103,17 +90,23 @@ macro_rules! todo { }; } +#[cfg(not(feature = "defmt"))] +#[collapse_debuginfo(yes)] macro_rules! unreachable { ($($x:tt)*) => { - { - #[cfg(not(feature = "defmt"))] - ::core::unreachable!($($x)*); - #[cfg(feature = "defmt")] - ::defmt::unreachable!($($x)*); - } + ::core::unreachable!($($x)*) }; } +#[cfg(feature = "defmt")] +#[collapse_debuginfo(yes)] +macro_rules! unreachable { + ($($x:tt)*) => { + ::defmt::unreachable!($($x)*) + }; +} + +#[collapse_debuginfo(yes)] macro_rules! panic { ($($x:tt)*) => { { @@ -125,6 +118,7 @@ macro_rules! panic { }; } +#[collapse_debuginfo(yes)] macro_rules! trace { ($s:literal $(, $x:expr)* $(,)?) => { { @@ -138,6 +132,7 @@ macro_rules! trace { }; } +#[collapse_debuginfo(yes)] macro_rules! debug { ($s:literal $(, $x:expr)* $(,)?) => { { @@ -151,6 +146,7 @@ macro_rules! debug { }; } +#[collapse_debuginfo(yes)] macro_rules! info { ($s:literal $(, $x:expr)* $(,)?) => { { @@ -164,6 +160,7 @@ macro_rules! info { }; } +#[collapse_debuginfo(yes)] macro_rules! warn { ($s:literal $(, $x:expr)* $(,)?) => { { @@ -177,6 +174,7 @@ macro_rules! warn { }; } +#[collapse_debuginfo(yes)] macro_rules! error { ($s:literal $(, $x:expr)* $(,)?) => { { @@ -191,6 +189,7 @@ macro_rules! error { } #[cfg(feature = "defmt")] +#[collapse_debuginfo(yes)] macro_rules! unwrap { ($($x:tt)*) => { ::defmt::unwrap!($($x)*) @@ -198,6 +197,7 @@ macro_rules! unwrap { } #[cfg(not(feature = "defmt"))] +#[collapse_debuginfo(yes)] macro_rules! unwrap { ($arg:expr) => { match $crate::fmt::Try::into_result($arg) { @@ -245,3 +245,30 @@ impl Try for Result { self } } + +pub(crate) struct Bytes<'a>(pub &'a [u8]); + +impl Debug for Bytes<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{:#02x?}", self.0) + } +} + +impl Display for Bytes<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{:#02x?}", self.0) + } +} + +impl LowerHex for Bytes<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{:#02x?}", self.0) + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for Bytes<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "{:02x}", self.0) + } +} diff --git a/src/jobs/data_types.rs b/src/jobs/data_types.rs index 23101eb..fb6100e 100644 --- a/src/jobs/data_types.rs +++ b/src/jobs/data_types.rs @@ -22,7 +22,8 @@ pub enum JobStatus { Removed, } -#[derive(Debug, Clone, PartialEq, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum ErrorCode { /// The request was sent to a topic in the AWS IoT Jobs namespace that does /// not map to any API. @@ -89,7 +90,7 @@ pub struct GetPendingJobExecutionsResponse<'a> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Contains data about a job execution. @@ -211,7 +212,7 @@ pub struct StartNextPendingJobExecutionResponse<'a, J> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Topic (accepted): $aws/things/{thingName}/jobs/{jobId}/update/accepted \ @@ -232,7 +233,7 @@ pub struct UpdateJobExecutionResponse<'a, J> { /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] - pub client_token: &'a str, + pub client_token: Option<&'a str>, } /// Sent whenever a job execution is added to or removed from the list of @@ -289,9 +290,9 @@ pub struct Jobs { /// service operation. #[derive(Debug, PartialEq, Deserialize)] pub struct ErrorResponse<'a> { - code: ErrorCode, + pub code: ErrorCode, /// An error message string. - message: &'a str, + pub message: &'a str, /// A client token used to correlate requests and responses. Enter an /// arbitrary value here and it is reflected in the response. #[serde(rename = "clientToken")] @@ -394,7 +395,7 @@ mod test { in_progress_jobs: Some(Vec::::new()), queued_jobs: None, timestamp: 1587381778, - client_token: "0:client_name", + client_token: Some("0:client_name"), } ); @@ -417,7 +418,7 @@ mod test { queued_jobs .push(JobExecutionSummary { execution_number: Some(1), - job_id: Some(String::from("test")), + job_id: Some(String::try_from("test").unwrap()), last_updated_at: Some(1587036256), queued_at: Some(1587036256), started_at: None, @@ -433,7 +434,7 @@ mod test { in_progress_jobs: Some(Vec::::new()), queued_jobs: Some(queued_jobs), timestamp: 1587381778, - client_token: "0:client_name", + client_token: Some("0:client_name"), } ); } diff --git a/src/jobs/describe.rs b/src/jobs/describe.rs index 846181f..81579c7 100644 --- a/src/jobs/describe.rs +++ b/src/jobs/describe.rs @@ -1,4 +1,3 @@ -use mqttrust::{Mqtt, QoS}; use serde::Serialize; use crate::jobs::JobTopic; @@ -79,18 +78,22 @@ impl<'a> Describe<'a> { pub fn topic_payload( self, client_id: &str, + buf: &mut [u8], ) -> Result< ( heapless::String<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 22 }>, - heapless::Vec, + usize, ), JobError, > { - let payload = serde_json_core::to_vec(&DescribeJobExecutionRequest { - execution_number: self.execution_number, - include_job_document: self.include_job_document.then(|| true), - client_token: self.client_token, - }) + let payload_len = serde_json_core::to_slice( + &DescribeJobExecutionRequest { + execution_number: self.execution_number, + include_job_document: self.include_job_document.then_some(true), + client_token: self.client_token, + }, + buf, + ) .map_err(|_| JobError::Encoding)?; Ok(( @@ -98,17 +101,9 @@ impl<'a> Describe<'a> { .map(JobTopic::Get) .unwrap_or(JobTopic::GetNext) .format(client_id)?, - payload, + payload_len, )) } - - pub fn send(self, mqtt: &M, qos: QoS) -> Result<(), JobError> { - let (topic, payload) = self.topic_payload(mqtt.client_id())?; - - mqtt.publish(topic.as_str(), &payload, qos)?; - - Ok(()) - } } #[cfg(test)] @@ -131,15 +126,16 @@ mod test { #[test] fn topic_payload() { - let (topic, payload) = Describe::new() + let mut buf = [0u8; 512]; + let (topic, payload_len) = Describe::new() .include_job_document() .execution_number(1) .client_token("test_client:token") - .topic_payload("test_client") + .topic_payload("test_client", &mut buf) .unwrap(); assert_eq!( - payload, + &buf[..payload_len], br#"{"executionNumber":1,"includeJobDocument":true,"clientToken":"test_client:token"}"# ); @@ -148,16 +144,17 @@ mod test { #[test] fn topic_job_id() { - let (topic, payload) = Describe::new() + let mut buf = [0u8; 512]; + let (topic, payload_len) = Describe::new() .include_job_document() .execution_number(1) .job_id("test_job_id") .client_token("test_client:token") - .topic_payload("test_client") + .topic_payload("test_client", &mut buf) .unwrap(); assert_eq!( - payload, + &buf[..payload_len], br#"{"executionNumber":1,"includeJobDocument":true,"clientToken":"test_client:token"}"# ); diff --git a/src/jobs/get_pending.rs b/src/jobs/get_pending.rs index d44f2f0..4c185ca 100644 --- a/src/jobs/get_pending.rs +++ b/src/jobs/get_pending.rs @@ -1,4 +1,3 @@ -use mqttrust::{Mqtt, QoS}; use serde::Serialize; use crate::jobs::JobTopic; @@ -38,27 +37,17 @@ impl<'a> GetPending<'a> { pub fn topic_payload( self, client_id: &str, - ) -> Result< - ( - heapless::String<{ MAX_THING_NAME_LEN + 21 }>, - heapless::Vec, - ), - JobError, - > { - let payload = serde_json_core::to_vec(&&GetPendingJobExecutionsRequest { - client_token: self.client_token, - }) + buf: &mut [u8], + ) -> Result<(heapless::String<{ MAX_THING_NAME_LEN + 21 }>, usize), JobError> { + let payload_len = serde_json_core::to_slice( + &&GetPendingJobExecutionsRequest { + client_token: self.client_token, + }, + buf, + ) .map_err(|_| JobError::Encoding)?; - Ok((JobTopic::GetPending.format(client_id)?, payload)) - } - - pub fn send(self, mqtt: &M, qos: QoS) -> Result<(), JobError> { - let (topic, payload) = self.topic_payload(mqtt.client_id())?; - - mqtt.publish(topic.as_str(), &payload, qos)?; - - Ok(()) + Ok((JobTopic::GetPending.format(client_id)?, payload_len)) } } @@ -80,12 +69,16 @@ mod test { #[test] fn topic_payload() { - let (topic, payload) = GetPending::new() + let mut buf = [0u8; 512]; + let (topic, payload_len) = GetPending::new() .client_token("test_client:token_pending") - .topic_payload("test_client") + .topic_payload("test_client", &mut buf) .unwrap(); - assert_eq!(payload, br#"{"clientToken":"test_client:token_pending"}"#); + assert_eq!( + &buf[..payload_len], + br#"{"clientToken":"test_client:token_pending"}"# + ); assert_eq!(topic.as_str(), "$aws/things/test_client/jobs/get"); } diff --git a/src/jobs/mod.rs b/src/jobs/mod.rs index 1481087..8f35c71 100644 --- a/src/jobs/mod.rs +++ b/src/jobs/mod.rs @@ -102,14 +102,13 @@ pub mod describe; pub mod get_pending; pub mod start_next; pub mod subscribe; -pub mod unsubscribe; pub mod update; use core::fmt::Write; use self::{ data_types::JobStatus, describe::Describe, get_pending::GetPending, start_next::StartNext, - subscribe::Subscribe, unsubscribe::Unsubscribe, update::Update, + update::Update, }; pub use subscribe::Topic; @@ -124,17 +123,11 @@ pub const MAX_RUNNING_JOBS: usize = 1; pub type StatusDetails<'a> = heapless::LinearMap<&'a str, &'a str, 4>; pub type StatusDetailsOwned = heapless::LinearMap, heapless::String<11>, 4>; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum JobError { Overflow, Encoding, - Mqtt(mqttrust::MqttError), -} - -impl From for JobError { - fn from(e: mqttrust::MqttError) -> Self { - Self::Mqtt(e) - } + Mqtt(embedded_mqtt::Error), } #[derive(Debug, Clone, PartialEq)] @@ -269,15 +262,7 @@ impl Jobs { Describe::new() } - pub fn update(job_id: &str, status: JobStatus) -> Update { - Update::new(job_id, status) - } - - pub fn subscribe<'a, const N: usize>() -> Subscribe<'a, N> { - Subscribe::new() - } - - pub fn unsubscribe<'a, const N: usize>() -> Unsubscribe<'a, N> { - Unsubscribe::new() + pub fn update<'a>(status: JobStatus) -> Update<'a> { + Update::new(status) } } diff --git a/src/jobs/start_next.rs b/src/jobs/start_next.rs index 05e568a..9a4f83f 100644 --- a/src/jobs/start_next.rs +++ b/src/jobs/start_next.rs @@ -1,4 +1,3 @@ -use mqttrust::{Mqtt, QoS}; use serde::Serialize; use crate::jobs::JobTopic; @@ -84,28 +83,18 @@ impl<'a> StartNext<'a> { pub fn topic_payload( self, client_id: &str, - ) -> Result< - ( - heapless::String<{ MAX_THING_NAME_LEN + 28 }>, - heapless::Vec, - ), - JobError, - > { - let payload = serde_json_core::to_vec(&StartNextPendingJobExecutionRequest { - step_timeout_in_minutes: self.step_timeout_in_minutes, - client_token: self.client_token, - }) + buf: &mut [u8], + ) -> Result<(heapless::String<{ MAX_THING_NAME_LEN + 28 }>, usize), JobError> { + let payload_len = serde_json_core::to_slice( + &StartNextPendingJobExecutionRequest { + step_timeout_in_minutes: self.step_timeout_in_minutes, + client_token: self.client_token, + }, + buf, + ) .map_err(|_| JobError::Encoding)?; - Ok((JobTopic::StartNext.format(client_id)?, payload)) - } - - pub fn send(self, mqtt: &M, qos: QoS) -> Result<(), JobError> { - let (topic, payload) = self.topic_payload(mqtt.client_id())?; - - mqtt.publish(topic.as_str(), &payload, qos)?; - - Ok(()) + Ok((JobTopic::StartNext.format(client_id)?, payload_len)) } } @@ -136,14 +125,15 @@ mod test { #[test] fn topic_payload() { - let (topic, payload) = StartNext::new() + let mut buf = [0u8; 512]; + let (topic, payload_len) = StartNext::new() .client_token("test_client:token_next_pending") .step_timeout_in_minutes(43) - .topic_payload("test_client") + .topic_payload("test_client", &mut buf) .unwrap(); assert_eq!( - payload, + &buf[..payload_len], br#"{"stepTimeoutInMinutes":43,"clientToken":"test_client:token_next_pending"}"# ); diff --git a/src/jobs/subscribe.rs b/src/jobs/subscribe.rs index 1f740cc..4fc604b 100644 --- a/src/jobs/subscribe.rs +++ b/src/jobs/subscribe.rs @@ -1,8 +1,4 @@ -use mqttrust::{Mqtt, QoS, SubscribeTopic}; - -use crate::jobs::JobError; - -use super::{JobTopic, MAX_JOB_ID_LEN}; +use super::JobTopic; #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -22,7 +18,7 @@ pub enum Topic<'a> { impl<'a> Topic<'a> { pub fn from_str(s: &'a str) -> Option { let tt = s.splitn(8, '/').collect::>(); - Some(match (tt.get(0), tt.get(1), tt.get(2), tt.get(3)) { + Some(match (tt.first(), tt.get(1), tt.get(2), tt.get(3)) { (Some(&"$aws"), Some(&"things"), _, Some(&"jobs")) => { // This is a job topic! Figure out which match (tt.get(4), tt.get(5), tt.get(6), tt.get(7)) { @@ -72,163 +68,3 @@ impl<'a> From<&Topic<'a>> for JobTopic<'a> { } } } - -#[derive(Default)] -pub struct Subscribe<'a, const N: usize> { - topics: heapless::Vec<(Topic<'a>, QoS), N>, -} - -impl<'a, const N: usize> Subscribe<'a, N> { - pub fn new() -> Self { - Self::default() - } - - pub fn topic(self, topic: Topic<'a>, qos: QoS) -> Self { - match topic { - Topic::DescribeAccepted(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::DescribeRejected(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::UpdateAccepted(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::UpdateRejected(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - _ => {} - } - - if self.topics.iter().any(|(t, _)| t == &topic) { - return self; - } - - let mut topics = self.topics; - topics.push((topic, qos)).ok(); - - Self { topics } - } - - pub fn topics( - self, - client_id: &str, - ) -> Result, QoS), N>, JobError> { - // assert!(client_id.len() <= super::MAX_THING_NAME_LEN); - self.topics - .iter() - .map(|(topic, qos)| Ok((JobTopic::from(topic).format(client_id)?, *qos))) - .collect() - } - - pub fn send(self, mqtt: &M) -> Result<(), JobError> { - if self.topics.is_empty() { - return Ok(()); - } - let topic_paths = self.topics(mqtt.client_id())?; - - let topics: heapless::Vec<_, N> = topic_paths - .iter() - .map(|(s, qos)| SubscribeTopic { - topic_path: s.as_str(), - qos: *qos, - }) - .collect(); - - debug!("Subscribing!"); - - for t in topics.chunks(5) { - mqtt.subscribe(t)?; - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use mqttrust::{encoding::v4::decode_slice, Packet, QoS, SubscribeTopic}; - - use super::*; - - use crate::test::MockMqtt; - - #[test] - fn splits_subscribe_all() { - let mqtt = &MockMqtt::new(); - - Subscribe::<10>::new() - .topic(Topic::Notify, QoS::AtLeastOnce) - .topic(Topic::NotifyNext, QoS::AtLeastOnce) - .topic(Topic::GetAccepted, QoS::AtLeastOnce) - .topic(Topic::GetRejected, QoS::AtLeastOnce) - .topic(Topic::StartNextAccepted, QoS::AtLeastOnce) - .topic(Topic::StartNextRejected, QoS::AtLeastOnce) - .topic(Topic::DescribeAccepted("test_job"), QoS::AtLeastOnce) - .topic(Topic::DescribeRejected("test_job"), QoS::AtLeastOnce) - .topic(Topic::UpdateAccepted("test_job"), QoS::AtLeastOnce) - .topic(Topic::UpdateRejected("test_job"), QoS::AtLeastOnce) - .send(mqtt) - .unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 2); - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - let packet = decode_slice(bytes.as_slice()).unwrap(); - - let topics = match packet { - Some(Packet::Subscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec![ - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/notify", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/notify-next", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/get/accepted", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/get/rejected", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/start-next/accepted", - qos: QoS::AtLeastOnce - } - ] - ); - - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - let packet = decode_slice(bytes.as_slice()).unwrap(); - - let topics = match packet { - Some(Packet::Subscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec![ - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/start-next/rejected", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/test_job/get/accepted", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/test_job/get/rejected", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/test_job/update/accepted", - qos: QoS::AtLeastOnce - }, - SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/test_job/update/rejected", - qos: QoS::AtLeastOnce - } - ] - ); - } -} diff --git a/src/jobs/unsubscribe.rs b/src/jobs/unsubscribe.rs deleted file mode 100644 index 79009ac..0000000 --- a/src/jobs/unsubscribe.rs +++ /dev/null @@ -1,124 +0,0 @@ -use mqttrust::Mqtt; - -use crate::jobs::JobTopic; - -use super::{subscribe::Topic, JobError, MAX_JOB_ID_LEN}; - -#[derive(Default)] -pub struct Unsubscribe<'a, const N: usize> { - topics: heapless::Vec, N>, -} - -impl<'a, const N: usize> Unsubscribe<'a, N> { - pub fn new() -> Self { - Self::default() - } - - pub fn topic(self, topic: Topic<'a>) -> Self { - match topic { - Topic::DescribeAccepted(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::DescribeRejected(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::UpdateAccepted(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - Topic::UpdateRejected(job_id) => assert!(job_id.len() <= MAX_JOB_ID_LEN), - _ => {} - } - - if self.topics.iter().any(|t| t == &topic) { - return self; - } - - let mut topics = self.topics; - topics.push(topic).ok(); - Self { topics } - } - - pub fn topics( - self, - client_id: &str, - ) -> Result, N>, JobError> { - // assert!(client_id.len() <= super::MAX_THING_NAME_LEN); - - self.topics - .iter() - .map(|topic| JobTopic::from(topic).format(client_id)) - .collect() - } - - pub fn send(self, mqtt: &M) -> Result<(), JobError> { - if self.topics.is_empty() { - return Ok(()); - } - let topic_paths = self.topics(mqtt.client_id())?; - let topics: heapless::Vec<_, N> = topic_paths.iter().map(|s| s.as_str()).collect(); - - for t in topics.chunks(5) { - mqtt.unsubscribe(t)?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use mqttrust::{encoding::v4::decode_slice, Packet}; - - use super::*; - use crate::test::MockMqtt; - - #[test] - fn splits_unsubscribe_all() { - let mqtt = &MockMqtt::new(); - - Unsubscribe::<10>::new() - .topic(Topic::Notify) - .topic(Topic::NotifyNext) - .topic(Topic::GetAccepted) - .topic(Topic::GetRejected) - .topic(Topic::StartNextAccepted) - .topic(Topic::StartNextRejected) - .topic(Topic::DescribeAccepted("test_job")) - .topic(Topic::DescribeRejected("test_job")) - .topic(Topic::UpdateAccepted("test_job")) - .topic(Topic::UpdateRejected("test_job")) - .send(mqtt) - .unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 2); - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Unsubscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec![ - "$aws/things/test_client/jobs/notify", - "$aws/things/test_client/jobs/notify-next", - "$aws/things/test_client/jobs/get/accepted", - "$aws/things/test_client/jobs/get/rejected", - "$aws/things/test_client/jobs/start-next/accepted", - ] - ); - - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Unsubscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - assert_eq!( - topics, - vec![ - "$aws/things/test_client/jobs/start-next/rejected", - "$aws/things/test_client/jobs/test_job/get/accepted", - "$aws/things/test_client/jobs/test_job/get/rejected", - "$aws/things/test_client/jobs/test_job/update/accepted", - "$aws/things/test_client/jobs/test_job/update/rejected" - ] - ); - } -} diff --git a/src/jobs/update.rs b/src/jobs/update.rs index 5a3903d..867d0ac 100644 --- a/src/jobs/update.rs +++ b/src/jobs/update.rs @@ -1,9 +1,6 @@ -use mqttrust::{Mqtt, QoS}; use serde::Serialize; -use crate::jobs::{ - data_types::JobStatus, JobTopic, MAX_CLIENT_TOKEN_LEN, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN, -}; +use crate::jobs::{data_types::JobStatus, MAX_CLIENT_TOKEN_LEN}; use super::{JobError, StatusDetailsOwned}; @@ -70,7 +67,6 @@ pub struct UpdateJobExecutionRequest<'a> { } pub struct Update<'a> { - job_id: &'a str, status: JobStatus, client_token: Option<&'a str>, status_details: Option<&'a StatusDetailsOwned>, @@ -82,11 +78,8 @@ pub struct Update<'a> { } impl<'a> Update<'a> { - pub fn new(job_id: &'a str, status: JobStatus) -> Self { - assert!(job_id.len() < MAX_JOB_ID_LEN); - + pub fn new(status: JobStatus) -> Self { Self { - job_id, status, status_details: None, include_job_document: false, @@ -149,42 +142,30 @@ impl<'a> Update<'a> { } } - pub fn topic_payload( - self, - client_id: &str, - ) -> Result< - ( - heapless::String<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>, - heapless::Vec, - ), - JobError, - > { - let payload = serde_json_core::to_vec(&UpdateJobExecutionRequest { - execution_number: self.execution_number, - include_job_document: self.include_job_document.then(|| true), - expected_version: self.expected_version, - include_job_execution_state: self.include_job_execution_state.then(|| true), - status: self.status, - status_details: self.status_details, - step_timeout_in_minutes: self.step_timeout_in_minutes, - client_token: self.client_token, - }) + pub fn payload(self, buf: &mut [u8]) -> Result { + let payload_len = serde_json_core::to_slice( + &UpdateJobExecutionRequest { + execution_number: self.execution_number, + include_job_document: self.include_job_document.then_some(true), + expected_version: self.expected_version, + include_job_execution_state: self.include_job_execution_state.then_some(true), + status: self.status, + status_details: self.status_details, + step_timeout_in_minutes: self.step_timeout_in_minutes, + client_token: self.client_token, + }, + buf, + ) .map_err(|_| JobError::Encoding)?; - Ok((JobTopic::Update(self.job_id).format(client_id)?, payload)) - } - - pub fn send(self, mqtt: &M, qos: QoS) -> Result<(), JobError> { - let (topic, payload) = self.topic_payload(mqtt.client_id())?; - - mqtt.publish(topic.as_str(), &payload, qos)?; - - Ok(()) + Ok(payload_len) } } #[cfg(test)] mod test { + use crate::jobs::JobTopic; + use super::*; use serde_json_core::to_string; @@ -208,17 +189,21 @@ mod test { #[test] fn topic_payload() { - let (topic, payload) = Update::new("test_job_id", JobStatus::Failed) + let mut buf = [0u8; 512]; + let topic = JobTopic::Update("test_job_id") + .format::<64>("test_client") + .unwrap(); + let payload_len = Update::new(JobStatus::Failed) .client_token("test_client:token_update") .step_timeout_in_minutes(50) .execution_number(5) .expected_version(2) .include_job_document() .include_job_execution_state() - .topic_payload("test_client") + .payload(&mut buf) .unwrap(); - assert_eq!(payload, br#"{"executionNumber":5,"expectedVersion":2,"includeJobDocument":true,"includeJobExecutionState":true,"status":"FAILED","stepTimeoutInMinutes":50,"clientToken":"test_client:token_update"}"#); + assert_eq!(&buf[..payload_len], br#"{"executionNumber":5,"expectedVersion":2,"includeJobDocument":true,"includeJobExecutionState":true,"status":"FAILED","stepTimeoutInMinutes":50,"clientToken":"test_client:token_update"}"#); assert_eq!( topic.as_str(), diff --git a/src/lib.rs b/src/lib.rs index 23917e8..9e39896 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,15 @@ #![cfg_attr(not(any(test, feature = "std")), no_std)] +#![allow(async_fn_in_trait)] #![allow(incomplete_features)] #![feature(generic_const_exprs)] +#![deny(clippy::float_arithmetic)] // This mod MUST go first, so that the others see its macros. pub(crate) mod fmt; +pub mod defender_metrics; pub mod jobs; #[cfg(any(feature = "ota_mqtt_data", feature = "ota_http_data"))] pub mod ota; pub mod provisioning; pub mod shadows; - -pub use serde_cbor; - -#[cfg(test)] -pub mod test; diff --git a/src/ota/agent.rs b/src/ota/agent.rs deleted file mode 100644 index 6bf0060..0000000 --- a/src/ota/agent.rs +++ /dev/null @@ -1,154 +0,0 @@ -use super::{ - builder::{self, NoTimer}, - control_interface::ControlInterface, - data_interface::{DataInterface, NoInterface}, - encoding::json::OtaJob, - pal::OtaPal, - state::{Error, Events, JobEventData, SmContext, StateMachine, States}, -}; -use crate::jobs::StatusDetails; - -// OTA Agent driving the FSM of an OTA update -pub struct OtaAgent<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - pub(crate) state: StateMachine>, -} - -// Make sure any active OTA session is cleaned up, and the topics are -// unsubscribed on drop. -impl<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32> Drop - for OtaAgent<'a, C, DP, DS, T, ST, PAL, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - fn drop(&mut self) { - let sm_context = self.state.context_mut(); - sm_context.ota_close().ok(); - sm_context.control.cleanup().ok(); - } -} - -impl<'a, C, DP, T, PAL, const TIMER_HZ: u32> - OtaAgent<'a, C, DP, NoInterface, T, NoTimer, PAL, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - T: fugit_timer::Timer, - PAL: OtaPal, -{ - pub fn builder( - control_interface: &'a C, - data_primary: DP, - request_timer: T, - pal: PAL, - ) -> builder::OtaAgentBuilder<'a, C, DP, NoInterface, T, NoTimer, PAL, TIMER_HZ> { - builder::OtaAgentBuilder::new(control_interface, data_primary, request_timer, pal) - } -} - -/// Public interface of the OTA Agent -impl<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32> OtaAgent<'a, C, DP, DS, T, ST, PAL, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - pub fn init(&mut self) { - if matches!(self.state(), &States::Ready) { - self.state.process_event(Events::Start).ok(); - } else { - self.state.process_event(Events::Resume).ok(); - } - } - - pub fn job_update( - &mut self, - job_name: &str, - ota_document: &OtaJob, - status_details: Option<&StatusDetails>, - ) -> Result<&States, Error> { - self.state - .process_event(Events::ReceivedJobDocument(JobEventData { - job_name, - ota_document, - status_details, - })) - } - - pub fn timer_callback(&mut self) -> Result<(), Error> { - let ctx = self.state.context_mut(); - if ctx.request_timer.wait().is_ok() { - return self.state.process_event(Events::RequestTimer).map(drop); - } - - if let Some(ref mut self_test_timer) = ctx.self_test_timer { - if self_test_timer.wait().is_ok() { - error!( - "Self test failed to complete within {} ms", - ctx.config.self_test_timeout_ms - ); - ctx.pal.reset_device().ok(); - } - } - Ok(()) - } - - pub fn process_event(&mut self) -> Result<&States, Error> { - if let Some(event) = self.state.context_mut().events.dequeue() { - self.state.process_event(event) - } else { - Ok(self.state()) - } - } - - pub fn handle_message(&mut self, payload: &mut [u8]) -> Result<&States, Error> { - self.state.process_event(Events::ReceivedFileBlock(payload)) - } - - pub fn check_for_update(&mut self) -> Result<&States, Error> { - if matches!( - self.state(), - States::WaitingForJob | States::RequestingJob | States::WaitingForFileBlock - ) { - self.state.process_event(Events::RequestJobDocument) - } else { - Ok(self.state()) - } - } - - pub fn abort(&mut self) -> Result<&States, Error> { - self.state.process_event(Events::UserAbort) - } - - pub fn suspend(&mut self) -> Result<&States, Error> { - // Stop the request timer - self.state.context_mut().request_timer.cancel().ok(); - - // Send event to OTA agent task. - self.state.process_event(Events::Suspend) - } - - pub fn resume(&mut self) -> Result<&States, Error> { - // Send event to OTA agent task - self.state.process_event(Events::Resume) - } - - pub fn state(&self) -> &States { - self.state.state() - } -} diff --git a/src/ota/builder.rs b/src/ota/builder.rs deleted file mode 100644 index 603d568..0000000 --- a/src/ota/builder.rs +++ /dev/null @@ -1,237 +0,0 @@ -use crate::ota::{ - config::Config, - control_interface::ControlInterface, - data_interface::DataInterface, - pal::OtaPal, - state::{SmContext, StateMachine}, -}; - -use super::{agent::OtaAgent, data_interface::NoInterface, pal::ImageState}; - -pub struct NoTimer; - -impl fugit_timer::Timer for NoTimer { - type Error = (); - - fn now(&mut self) -> fugit_timer::TimerInstantU32 { - todo!() - } - - fn start( - &mut self, - _duration: fugit_timer::TimerDurationU32, - ) -> Result<(), Self::Error> { - todo!() - } - - fn cancel(&mut self) -> Result<(), Self::Error> { - todo!() - } - - fn wait(&mut self) -> nb::Result<(), Self::Error> { - todo!() - } -} - -pub struct OtaAgentBuilder<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - control: &'a C, - data_primary: DP, - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - data_secondary: Option, - #[cfg(not(all(feature = "ota_mqtt_data", feature = "ota_http_data")))] - data_secondary: core::marker::PhantomData, - pal: PAL, - request_timer: T, - self_test_timer: Option, - config: Config, -} - -impl<'a, C, DP, T, PAL, const TIMER_HZ: u32> - OtaAgentBuilder<'a, C, DP, NoInterface, T, NoTimer, PAL, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - T: fugit_timer::Timer, - PAL: OtaPal, -{ - pub fn new(control_interface: &'a C, data_primary: DP, request_timer: T, pal: PAL) -> Self { - Self { - control: control_interface, - data_primary, - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - data_secondary: None, - #[cfg(not(all(feature = "ota_mqtt_data", feature = "ota_http_data")))] - data_secondary: core::marker::PhantomData, - pal, - request_timer, - self_test_timer: None, - config: Config::default(), - } - } -} - -impl<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32> - OtaAgentBuilder<'a, C, DP, DS, T, ST, PAL, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - pub fn data_secondary( - self, - interface: D, - ) -> OtaAgentBuilder<'a, C, DP, D, T, ST, PAL, TIMER_HZ> { - OtaAgentBuilder { - control: self.control, - data_primary: self.data_primary, - data_secondary: Some(interface), - pal: self.pal, - request_timer: self.request_timer, - self_test_timer: self.self_test_timer, - config: self.config, - } - } - - pub fn block_size(self, block_size: usize) -> Self { - Self { - config: Config { - block_size, - ..self.config - }, - ..self - } - } - - pub fn max_request_momentum(self, max_request_momentum: u8) -> Self { - Self { - config: Config { - max_request_momentum, - ..self.config - }, - ..self - } - } - - pub fn activate_delay(self, activate_delay: u8) -> Self { - Self { - config: Config { - activate_delay, - ..self.config - }, - ..self - } - } - - pub fn request_wait_ms(self, request_wait_ms: u32) -> Self { - Self { - config: Config { - request_wait_ms, - ..self.config - }, - ..self - } - } - - pub fn status_update_frequency(self, status_update_frequency: u32) -> Self { - Self { - config: Config { - status_update_frequency, - ..self.config - }, - ..self - } - } - - pub fn allow_downgrade(self) -> Self { - Self { - config: Config { - allow_downgrade: true, - ..self.config - }, - ..self - } - } - - pub fn with_self_test_timeout( - self, - timer: NST, - timeout_ms: u32, - ) -> OtaAgentBuilder<'a, C, DP, DS, T, NST, PAL, TIMER_HZ> - where - NST: fugit_timer::Timer, - { - OtaAgentBuilder { - control: self.control, - data_primary: self.data_primary, - data_secondary: self.data_secondary, - pal: self.pal, - request_timer: self.request_timer, - self_test_timer: Some(timer), - config: Config { - self_test_timeout_ms: timeout_ms, - ..self.config - }, - } - } - - pub fn build(self) -> OtaAgent<'a, C, DP, DS, T, ST, PAL, TIMER_HZ> { - OtaAgent { - state: StateMachine::new(SmContext { - events: heapless::spsc::Queue::new(), - control: self.control, - data_secondary: self.data_secondary, - data_primary: self.data_primary, - active_interface: None, - request_momentum: 0, - request_timer: self.request_timer, - self_test_timer: self.self_test_timer, - pal: self.pal, - config: self.config, - image_state: ImageState::Unknown, - }), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::{ - ota::test::mock::{MockPal, MockTimer}, - test::MockMqtt, - }; - - #[test] - fn enables_allow_downgrade() { - let mqtt = MockMqtt::new(); - - let request_timer = MockTimer::new(); - let self_test_timer = MockTimer::new(); - let pal = MockPal {}; - - let builder = OtaAgentBuilder::new(&mqtt, &mqtt, request_timer, pal) - .with_self_test_timeout(self_test_timer, 32000) - .allow_downgrade(); - - assert!(builder.config.allow_downgrade); - assert!(builder.self_test_timer.is_some()); - assert_eq!(builder.config.self_test_timeout_ms, 32000); - - let agent = builder.build(); - - assert!(agent.state.context().config.allow_downgrade); - } -} diff --git a/src/ota/config.rs b/src/ota/config.rs index 7862273..0d6395c 100644 --- a/src/ota/config.rs +++ b/src/ota/config.rs @@ -1,12 +1,11 @@ +use embassy_time::Duration; + pub struct Config { - pub(crate) block_size: usize, - pub(crate) max_request_momentum: u8, - pub(crate) activate_delay: u8, - pub(crate) request_wait_ms: u32, - pub(crate) status_update_frequency: u32, - pub(crate) allow_downgrade: bool, - pub(crate) unsubscribe_on_shutdown: bool, - pub(crate) self_test_timeout_ms: u32, + pub block_size: usize, + pub max_request_momentum: u8, + pub request_wait: Duration, + pub status_update_frequency: u32, + pub self_test_timeout: Option, } impl Default for Config { @@ -14,12 +13,9 @@ impl Default for Config { Self { block_size: 256, max_request_momentum: 3, - activate_delay: 5, - request_wait_ms: 8000, - status_update_frequency: 24, - allow_downgrade: false, - unsubscribe_on_shutdown: true, - self_test_timeout_ms: 16000, + request_wait: Duration::from_secs(5), + status_update_frequency: 96, + self_test_timeout: None, } } } diff --git a/src/ota/control_interface/mod.rs b/src/ota/control_interface/mod.rs index e1b28e3..ffd8b03 100644 --- a/src/ota/control_interface/mod.rs +++ b/src/ota/control_interface/mod.rs @@ -1,23 +1,21 @@ use crate::jobs::data_types::JobStatus; use super::{ - config::Config, encoding::{json::JobStatusReason, FileContext}, error::OtaError, + ProgressState, }; pub mod mqtt; // Interfaces required for OTA pub trait ControlInterface { - fn init(&self) -> Result<(), OtaError>; - fn request_job(&self) -> Result<(), OtaError>; - fn update_job_status( + async fn request_job(&self) -> Result<(), OtaError>; + async fn update_job_status( &self, - file_ctx: &mut FileContext, - config: &Config, + file_ctx: &FileContext, + progress: &mut ProgressState, status: JobStatus, reason: JobStatusReason, ) -> Result<(), OtaError>; - fn cleanup(&self) -> Result<(), OtaError>; } diff --git a/src/ota/control_interface/mqtt.rs b/src/ota/control_interface/mqtt.rs index f8c34b1..b04da84 100644 --- a/src/ota/control_interface/mqtt.rs +++ b/src/ota/control_interface/mqtt.rs @@ -1,101 +1,195 @@ use core::fmt::Write; -use mqttrust::QoS; +use embassy_sync::blocking_mutex::raw::RawMutex; +use embedded_mqtt::{DeferredPayload, EncodingError, Publish, QoS}; use super::ControlInterface; use crate::jobs::data_types::JobStatus; -use crate::jobs::subscribe::Topic; -use crate::jobs::Jobs; -use crate::ota::config::Config; +use crate::jobs::{JobTopic, Jobs, MAX_JOB_ID_LEN, MAX_THING_NAME_LEN}; use crate::ota::encoding::json::JobStatusReason; use crate::ota::encoding::FileContext; use crate::ota::error::OtaError; +use crate::ota::ProgressState; -impl ControlInterface for T { - /// Initialize the control interface by subscribing to the OTA job - /// notification topics. - fn init(&self) -> Result<(), OtaError> { - Jobs::subscribe::<1>() - .topic(Topic::NotifyNext, QoS::AtLeastOnce) - .send(self)?; - Ok(()) - } - +impl ControlInterface for embedded_mqtt::MqttClient<'_, M> { /// Check for next available OTA job from the job service by publishing a /// "get next job" message to the job service. - fn request_job(&self) -> Result<(), OtaError> { - Jobs::describe().send(self, QoS::AtLeastOnce)?; + async fn request_job(&self) -> Result<(), OtaError> { + // FIXME: Serialize directly into the publish payload through `DeferredPublish` API + let mut buf = [0u8; 512]; + let (topic, payload_len) = Jobs::describe().topic_payload(self.client_id(), &mut buf)?; + + self.publish( + Publish::builder() + .topic_name(&topic) + .payload(&buf[..payload_len]) + .build(), + ) + .await?; Ok(()) } - /// Update the job status on the service side with progress or completion - /// info - fn update_job_status( + /// Update the job status on the service side. + /// + /// Returns a Result indicating success or an error, + /// along with an Option containing the updated status details + /// if they were modified. + async fn update_job_status( &self, - file_ctx: &mut FileContext, - config: &Config, + file_ctx: &FileContext, + progress_state: &mut ProgressState, status: JobStatus, reason: JobStatusReason, ) -> Result<(), OtaError> { - file_ctx + // Update the status details within this function. + progress_state .status_details .insert( - heapless::String::from("self_test"), - heapless::String::from(reason.as_str()), + heapless::String::try_from("self_test").unwrap(), + heapless::String::try_from(reason.as_str()).unwrap(), ) .map_err(|_| OtaError::Overflow)?; - let mut qos = QoS::AtLeastOnce; + let qos = QoS::AtLeastOnce; - if let (JobStatus::InProgress, _) | (JobStatus::Succeeded, _) = (status, reason) { - let total_blocks = - ((file_ctx.filesize + config.block_size - 1) / config.block_size) as u32; - let received_blocks = total_blocks - file_ctx.blocks_remaining as u32; - - // Output a status update once in a while. Always update first and - // last status - if file_ctx.blocks_remaining != 0 - && received_blocks != 0 - && received_blocks % config.status_update_frequency != 0 - { - return Ok(()); - } + if let JobStatus::InProgress | JobStatus::Succeeded = status { + let received_blocks = progress_state.total_blocks - progress_state.blocks_remaining; // Don't override the progress on succeeded, nor on self-test - // active. (Cases where progess counter is lost due to device + // active. (Cases where progress counter is lost due to device // restarts) if status != JobStatus::Succeeded && reason != JobStatusReason::SelfTestActive { let mut progress = heapless::String::new(); progress - .write_fmt(format_args!("{}/{}", received_blocks, total_blocks)) + .write_fmt(format_args!( + "{}/{}", + received_blocks, progress_state.total_blocks + )) .map_err(|_| OtaError::Overflow)?; - file_ctx + progress_state .status_details - .insert(heapless::String::from("progress"), progress) + .insert(heapless::String::try_from("progress").unwrap(), progress) .map_err(|_| OtaError::Overflow)?; } // Downgrade progress updates to QOS 0 to avoid overloading MQTT - // buffers during active streaming - if status == JobStatus::InProgress { - qos = QoS::AtMostOnce; - } + // buffers during active streaming. But make sure to always send and await ack for first update and last update + // if status == JobStatus::InProgress + // && progress_state.blocks_remaining != 0 + // && received_blocks != 0 + // { + // qos = QoS::AtMostOnce; + // } } - Jobs::update(file_ctx.job_name.as_str(), status) - .status_details(&file_ctx.status_details) - .send(self, qos)?; + // let mut sub = self + // .subscribe::<2>( + // Subscribe::builder() + // .topics(&[ + // SubscribeTopic::builder() + // .topic_path( + // JobTopic::UpdateAccepted(file_ctx.job_name.as_str()) + // .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( + // self.client_id(), + // )? + // .as_str(), + // ) + // .build(), + // SubscribeTopic::builder() + // .topic_path( + // JobTopic::UpdateRejected(file_ctx.job_name.as_str()) + // .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 34 }>( + // self.client_id(), + // )? + // .as_str(), + // ) + // .build(), + // ]) + // .build(), + // ) + // .await?; - Ok(()) - } + let topic = JobTopic::Update(file_ctx.job_name.as_str()) + .format::<{ MAX_THING_NAME_LEN + MAX_JOB_ID_LEN + 25 }>(self.client_id())?; + let payload = DeferredPayload::new( + |buf| { + Jobs::update(status) + .client_token(self.client_id()) + .status_details(&progress_state.status_details) + .payload(buf) + .map_err(|_| EncodingError::BufferSize) + }, + 512, + ); + + debug!("Updating job status! {:?}", status); + + self.publish( + Publish::builder() + .qos(qos) + .topic_name(&topic) + .payload(payload) + .build(), + ) + .await?; - /// Perform any cleanup operations required for control plane - fn cleanup(&self) -> Result<(), OtaError> { - Jobs::unsubscribe::<1>() - .topic(Topic::NotifyNext) - .send(self)?; Ok(()) + + // loop { + // let message = match with_timeout( + // embassy_time::Duration::from_secs(1), + // sub.next_message(), + // ) + // .await + // { + // Ok(res) => res.ok_or(JobError::Encoding)?, + // Err(_) => return Err(OtaError::Timeout), + // }; + + // // Check if topic is GetAccepted + // match crate::jobs::Topic::from_str(message.topic_name()) { + // Some(crate::jobs::Topic::UpdateAccepted(_)) => { + // // Check client token + // let (response, _) = serde_json_core::from_slice::< + // UpdateJobExecutionResponse>, + // >(message.payload()) + // .map_err(|_| JobError::Encoding)?; + + // if response.client_token != Some(self.client_id()) { + // error!( + // "Unexpected client token received: {}, expected: {}", + // response.client_token.unwrap_or("None"), + // self.client_id() + // ); + // continue; + // } + + // return Ok(()); + // } + // Some(crate::jobs::Topic::UpdateRejected(_)) => { + // let (error_response, _) = + // serde_json_core::from_slice::(message.payload()) + // .map_err(|_| JobError::Encoding)?; + + // if error_response.client_token != Some(self.client_id()) { + // error!( + // "Unexpected client token received: {}, expected: {}", + // error_response.client_token.unwrap_or("None"), + // self.client_id() + // ); + // continue; + // } + + // error!("OTA Update rejected: {:?}", error_response.message); + + // return Err(OtaError::UpdateRejected(error_response.code)); + // } + // _ => { + // error!("Expected Topic name GetRejected or GetAccepted but got something else"); + // } + // } + // } } } diff --git a/src/ota/data_interface/http.rs b/src/ota/data_interface/http.rs index c1ba639..4b53ac9 100644 --- a/src/ota/data_interface/http.rs +++ b/src/ota/data_interface/http.rs @@ -20,23 +20,11 @@ impl DataInterface for HttpInterface { Ok(()) } - fn request_file_block( + fn request_file_blocks( &self, _file_ctx: &mut FileContext, _config: &Config, ) -> Result<(), OtaError> { Ok(()) } - - fn decode_file_block<'b>( - &self, - _file_ctx: &mut FileContext, - _payload: &'b mut [u8], - ) -> Result, OtaError> { - unimplemented!() - } - - fn cleanup(&self, _file_ctx: &mut FileContext, _config: &Config) -> Result<(), OtaError> { - Ok(()) - } } diff --git a/src/ota/data_interface/mod.rs b/src/ota/data_interface/mod.rs index ec02550..450ca8a 100644 --- a/src/ota/data_interface/mod.rs +++ b/src/ota/data_interface/mod.rs @@ -1,15 +1,18 @@ -#[cfg(feature = "ota_http_data")] -pub mod http; +// #[cfg(feature = "ota_http_data")] +// pub mod http; #[cfg(feature = "ota_mqtt_data")] pub mod mqtt; +use core::ops::DerefMut; + use serde::Deserialize; use crate::ota::config::Config; -use super::{encoding::FileContext, error::OtaError}; +use super::{encoding::FileContext, error::OtaError, ProgressState}; #[derive(Debug, Clone, PartialEq, Deserialize)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Protocol { #[serde(rename = "MQTT")] Mqtt, @@ -26,13 +29,13 @@ pub struct FileBlock<'a> { pub block_payload: &'a [u8], } -impl<'a> FileBlock<'a> { +impl FileBlock<'_> { /// Validate the block index and size. If it is NOT the last block, it MUST /// be equal to a full block size. If it IS the last block, it MUST be equal /// to the expected remainder. If the block ID is out of range, that's an /// error. pub fn validate(&self, block_size: usize, filesize: usize) -> bool { - let total_blocks = (filesize + block_size - 1) / block_size; + let total_blocks = filesize.div_ceil(block_size); let last_block_id = total_blocks - 1; (self.block_id < last_block_id && self.block_size == block_size) @@ -41,49 +44,28 @@ impl<'a> FileBlock<'a> { } } -pub trait DataInterface { - const PROTOCOL: Protocol; - - fn init_file_transfer(&self, file_ctx: &mut FileContext) -> Result<(), OtaError>; - fn request_file_block( - &self, - file_ctx: &mut FileContext, - config: &Config, - ) -> Result<(), OtaError>; - fn decode_file_block<'a>( - &self, - file_ctx: &mut FileContext, - payload: &'a mut [u8], - ) -> Result, OtaError>; - fn cleanup(&self, file_ctx: &mut FileContext, config: &Config) -> Result<(), OtaError>; +pub trait BlockTransfer { + async fn next_block(&mut self) -> Result>, OtaError>; } -pub struct NoInterface; - -impl DataInterface for NoInterface { - const PROTOCOL: Protocol = Protocol::Mqtt; +pub trait DataInterface { + const PROTOCOL: Protocol; - fn init_file_transfer(&self, _file_ctx: &mut FileContext) -> Result<(), OtaError> { - unreachable!() - } + type ActiveTransfer<'t>: BlockTransfer + where + Self: 't; - fn request_file_block( + async fn init_file_transfer( &self, - _file_ctx: &mut FileContext, - _config: &Config, - ) -> Result<(), OtaError> { - unreachable!() - } + file_ctx: &FileContext, + ) -> Result, OtaError>; - fn decode_file_block<'a>( + async fn request_file_blocks( &self, - _file_ctx: &mut FileContext, - _payload: &'a mut [u8], - ) -> Result, OtaError> { - unreachable!() - } + file_ctx: &FileContext, + progress_state: &mut ProgressState, + config: &Config, + ) -> Result<(), OtaError>; - fn cleanup(&self, _file_ctx: &mut FileContext, _config: &Config) -> Result<(), OtaError> { - unreachable!() - } + fn decode_file_block<'a>(&self, payload: &'a mut [u8]) -> Result, OtaError>; } diff --git a/src/ota/data_interface/mqtt.rs b/src/ota/data_interface/mqtt.rs index 87e5999..865617e 100644 --- a/src/ota/data_interface/mqtt.rs +++ b/src/ota/data_interface/mqtt.rs @@ -1,9 +1,14 @@ use core::fmt::{Display, Write}; +use core::ops::DerefMut; use core::str::FromStr; -use mqttrust::{Mqtt, QoS, SubscribeTopic}; +use embassy_sync::blocking_mutex::raw::RawMutex; +use embedded_mqtt::{ + DeferredPayload, EncodingError, MqttClient, Publish, Subscribe, SubscribeTopic, Subscription, +}; use crate::ota::error::OtaError; +use crate::ota::ProgressState; use crate::{ jobs::{MAX_STREAM_ID_LEN, MAX_THING_NAME_LEN}, ota::{ @@ -13,6 +18,8 @@ use crate::{ }, }; +use super::BlockTransfer; + #[derive(Debug, Clone, Copy, PartialEq)] pub enum Encoding { Cbor, @@ -50,7 +57,7 @@ pub enum Topic<'a> { impl<'a> Topic<'a> { pub fn from_str(s: &'a str) -> Option { let tt = s.splitn(8, '/').collect::>(); - Some(match (tt.get(0), tt.get(1), tt.get(2), tt.get(3)) { + Some(match (tt.first(), tt.get(1), tt.get(2), tt.get(3)) { (Some(&"$aws"), Some(&"things"), _, Some(&"streams")) => { // This is a stream topic! Figure out which match (tt.get(4), tt.get(5), tt.get(6), tt.get(7)) { @@ -89,7 +96,7 @@ enum OtaTopic<'a> { Get(Encoding, &'a str), } -impl<'a> OtaTopic<'a> { +impl OtaTopic<'_> { pub fn format(&self, client_id: &str) -> Result, OtaError> { let mut topic_path = heapless::String::new(); match self { @@ -116,295 +123,114 @@ impl<'a> OtaTopic<'a> { } } -impl<'a, M> DataInterface for &'a M -where - M: Mqtt, -{ +impl BlockTransfer for Subscription<'_, '_, M, 1> { + async fn next_block(&mut self) -> Result>, OtaError> { + let next = self.next_message().await; + if next.is_none() { + warn!("[OTA] Data stream ended (subscription closed due to clean session/disconnect)"); + } + Ok(next) + } +} + +impl<'a, M: RawMutex> DataInterface for MqttClient<'a, M> { const PROTOCOL: Protocol = Protocol::Mqtt; + type ActiveTransfer<'t> + = Subscription<'a, 't, M, 1> + where + Self: 't; + /// Init file transfer by subscribing to the OTA data stream topic - fn init_file_transfer(&self, file_ctx: &mut FileContext) -> Result<(), OtaError> { + async fn init_file_transfer( + &self, + file_ctx: &FileContext, + ) -> Result, OtaError> { let topic_path = OtaTopic::Data(Encoding::Cbor, file_ctx.stream_name.as_str()) .format::<256>(self.client_id())?; - let topic = SubscribeTopic { - topic_path: topic_path.as_str(), - qos: mqttrust::QoS::AtLeastOnce, - }; + + let topics = [SubscribeTopic::builder() + .topic_path(topic_path.as_str()) + .build()]; debug!("Subscribing to: [{:?}]", &topic_path); - self.subscribe(&[topic])?; + let sub = self + .subscribe::<1>(Subscribe::builder().topics(&topics).build()) + .await?; - Ok(()) + info!( + "[OTA] Subscribed to data stream {}", + file_ctx.stream_name.as_str() + ); + + Ok(sub) } /// Request file block by publishing to the get stream topic - fn request_file_block( + async fn request_file_blocks( &self, - file_ctx: &mut FileContext, + file_ctx: &FileContext, + progress_state: &mut ProgressState, config: &Config, ) -> Result<(), OtaError> { - // Reset number of blocks requested - file_ctx.request_block_remaining = file_ctx.bitmap.len() as u32; - - let buf = &mut [0u8; 32]; - let len = cbor::to_slice( - &cbor::GetStreamRequest { - // Arbitrary client token sent in the stream "GET" message - client_token: None, - stream_version: None, - file_id: file_ctx.fileid, - block_size: config.block_size, - block_offset: Some(file_ctx.block_offset), - block_bitmap: Some(&file_ctx.bitmap), - number_of_blocks: None, + progress_state.request_block_remaining = progress_state.bitmap.len() as u32; + + let payload = DeferredPayload::new( + |buf| { + cbor::to_slice( + &cbor::GetStreamRequest { + // Arbitrary client token sent in the stream "GET" message + client_token: None, + stream_version: None, + file_id: file_ctx.fileid, + block_size: config.block_size, + block_offset: Some(progress_state.block_offset), + block_bitmap: Some(&progress_state.bitmap), + number_of_blocks: Some(progress_state.request_block_remaining), + }, + buf, + ) + .map_err(|_| EncodingError::BufferSize) }, - buf, - ) - .map_err(|_| OtaError::Encoding)?; + 32, + ); + debug!( + "Requesting more file blocks. Remaining: {}", + progress_state.request_block_remaining + ); + info!( + "[OTA] Requesting blocks stream={} offset={} bitmap_len={} blocks_remaining={}", + file_ctx.stream_name.as_str(), + progress_state.block_offset, + progress_state.bitmap.len(), + progress_state.blocks_remaining + ); self.publish( - OtaTopic::Get(Encoding::Cbor, file_ctx.stream_name.as_str()) - .format::<{ MAX_STREAM_ID_LEN + MAX_THING_NAME_LEN + 30 }>(self.client_id())? - .as_str(), - &buf[..len], - QoS::AtMostOnce, - )?; + Publish::builder() + .topic_name( + OtaTopic::Get(Encoding::Cbor, file_ctx.stream_name.as_str()) + .format::<{ MAX_STREAM_ID_LEN + MAX_THING_NAME_LEN + 30 }>( + self.client_id(), + )? + .as_str(), + ) + // .qos(embedded_mqtt::QoS::AtMostOnce) + .payload(payload) + .build(), + ) + .await?; Ok(()) } /// Decode a cbor encoded fileblock received from streaming service - fn decode_file_block<'c>( - &self, - _file_ctx: &mut FileContext, - payload: &'c mut [u8], - ) -> Result, OtaError> { + fn decode_file_block<'c>(&self, payload: &'c mut [u8]) -> Result, OtaError> { Ok( - serde_cbor::de::from_mut_slice::(payload) + minicbor_serde::from_slice::(payload) .map_err(|_| OtaError::Encoding)? .into(), ) } - - /// Perform any cleanup operations required for data plane - fn cleanup(&self, file_ctx: &mut FileContext, config: &Config) -> Result<(), OtaError> { - if config.unsubscribe_on_shutdown { - // Unsubscribe from data stream topics - self.unsubscribe(&[ - OtaTopic::Data(Encoding::Cbor, file_ctx.stream_name.as_str()) - .format::<256>(self.client_id())? - .as_str(), - ])?; - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use mqttrust::{encoding::v4::decode_slice, Packet, SubscribeTopic}; - - use super::*; - use crate::{ota::test::test_file_ctx, test::MockMqtt}; - - #[test] - fn protocol_fits() { - assert_eq!(<&MockMqtt as DataInterface>::PROTOCOL, Protocol::Mqtt); - } - - #[test] - fn init_file_transfer_subscribes() { - let mqtt = &MockMqtt::new(); - - let mut file_ctx = test_file_ctx(&Config::default()); - - mqtt.init_file_transfer(&mut file_ctx).unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 1); - - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Subscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - assert_eq!( - topics, - vec![SubscribeTopic { - topic_path: "$aws/things/test_client/streams/test_stream/data/cbor", - qos: QoS::AtLeastOnce - }] - ); - } - - #[test] - fn request_file_block_publish() { - let mqtt = &MockMqtt::new(); - - let config = Config::default(); - let mut file_ctx = test_file_ctx(&config); - - mqtt.request_file_block(&mut file_ctx, &config).unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 1); - - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - - let publish = match decode_slice(bytes.as_slice()).unwrap() { - Some(Packet::Publish(s)) => s, - _ => panic!(), - }; - - assert_eq!( - publish, - mqttrust::encoding::v4::publish::Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic_name: "$aws/things/test_client/streams/test_stream/get/cbor", - payload: &[ - 164, 97, 102, 0, 97, 108, 25, 1, 0, 97, 111, 0, 97, 98, 68, 255, 255, 255, 127 - ], - pid: None - } - ); - } - - #[test] - fn decode_file_block_cbor() { - let mqtt = &MockMqtt::new(); - - let mut file_ctx = test_file_ctx(&Config::default()); - - let payload = &mut [ - 191, 97, 102, 0, 97, 105, 0, 97, 108, 25, 4, 0, 97, 112, 89, 4, 0, 141, 62, 28, 246, - 80, 193, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 255, - ]; - - let file_blk = mqtt.decode_file_block(&mut file_ctx, payload).unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - assert_eq!(file_blk.file_id, 0); - assert_eq!(file_blk.block_id, 0); - assert_eq!( - file_blk.block_payload, - &[ - 141, 62, 28, 246, 80, 193, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - ] - ); - assert_eq!(file_blk.block_size, 1024); - assert_eq!(file_blk.client_token, None); - } - - #[test] - fn cleanup_unsubscribe() { - let mqtt = &MockMqtt::new(); - - let config = Config::default(); - - let mut file_ctx = test_file_ctx(&config); - - mqtt.cleanup(&mut file_ctx, &config).unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 1); - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Unsubscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec!["$aws/things/test_client/streams/test_stream/data/cbor"] - ); - } - - #[test] - fn cleanup_no_unsubscribe() { - let mqtt = &MockMqtt::new(); - - let mut config = Config::default(); - config.unsubscribe_on_shutdown = false; - - let mut file_ctx = test_file_ctx(&config); - - mqtt.cleanup(&mut file_ctx, &config).unwrap(); - - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - } } diff --git a/src/ota/encoding/cbor.rs b/src/ota/encoding/cbor.rs index a30f45e..b9b1ffc 100644 --- a/src/ota/encoding/cbor.rs +++ b/src/ota/encoding/cbor.rs @@ -76,9 +76,10 @@ pub fn to_slice(value: &T, slice: &mut [u8]) -> Result where T: serde::ser::Serialize, { - let mut serializer = serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(slice)); + let mut serializer = + minicbor_serde::Serializer::new(minicbor::encode::write::Cursor::new(slice)); value.serialize(&mut serializer).map_err(|_| ())?; - Ok(serializer.into_inner().bytes_written()) + Ok(serializer.into_encoder().writer().position()) } impl<'a> From> for FileBlock<'a> { @@ -170,7 +171,7 @@ mod test { 0, 0, 0, 0, 0, 0, 255, ]; - let response: GetStreamResponse = serde_cbor::de::from_mut_slice(payload).unwrap(); + let response: GetStreamResponse = minicbor_serde::from_slice(payload).unwrap(); assert_eq!( response, @@ -261,14 +262,14 @@ mod test { // Check the last request (All requests in between will have same bitmap as first request, with different block_offset) { - let bitmap = Bitmap::new(file_size, BLOCK_SIZE, block_offset as u32); + let bitmap = Bitmap::new(file_size, BLOCK_SIZE, block_offset); let req = GetStreamRequest { client_token: Some("rdy"), stream_version: None, file_id: 0, block_size: BLOCK_SIZE, - block_offset: Some(block_offset as u32), + block_offset: Some(block_offset), number_of_blocks: None, block_bitmap: Some(&bitmap), }; diff --git a/src/ota/encoding/json.rs b/src/ota/encoding/json.rs index ae08e4d..cfa7621 100644 --- a/src/ota/encoding/json.rs +++ b/src/ota/encoding/json.rs @@ -32,7 +32,8 @@ pub struct FileDescription<'a> { #[serde(rename = "fileid")] pub fileid: u8, #[serde(rename = "certfile")] - pub certfile: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + pub certfile: Option<&'a str>, #[serde(rename = "update_data_url")] #[serde(skip_serializing_if = "Option::is_none")] pub update_data_url: Option<&'a str>, @@ -58,21 +59,27 @@ pub struct FileDescription<'a> { pub file_type: Option, } -impl<'a> FileDescription<'a> { - pub fn signature(&self) -> Signature { +impl FileDescription<'_> { + pub fn signature(&self) -> Option { if let Some(sig) = self.sha1_rsa { - return Signature::Sha1Rsa(heapless::String::from(sig)); + return Some(Signature::Sha1Rsa(heapless::String::try_from(sig).unwrap())); } if let Some(sig) = self.sha256_rsa { - return Signature::Sha256Rsa(heapless::String::from(sig)); + return Some(Signature::Sha256Rsa( + heapless::String::try_from(sig).unwrap(), + )); } if let Some(sig) = self.sha1_ecdsa { - return Signature::Sha1Ecdsa(heapless::String::from(sig)); + return Some(Signature::Sha1Ecdsa( + heapless::String::try_from(sig).unwrap(), + )); } if let Some(sig) = self.sha256_ecdsa { - return Signature::Sha256Ecdsa(heapless::String::from(sig)); + return Some(Signature::Sha256Ecdsa( + heapless::String::try_from(sig).unwrap(), + )); } - unreachable!() + None } } @@ -132,6 +139,7 @@ mod tests { (JobStatusReason::Accepted, "accepted"), (JobStatusReason::Rejected, "rejected"), (JobStatusReason::Aborted, "aborted"), + (JobStatusReason::Pal(123), "pal err"), ]; for (reason, exp) in reasons { @@ -147,4 +155,28 @@ mod tests { ); } } + + #[test] + fn deserializ() { + let data = r#"{ + "protocols": [ + "MQTT" + ], + "streamname": "AFR_OTA-d11032e9-38d5-4dca-8c7c-1e6f24533ede", + "files": [ + { + "filepath": "3.8.4", + "filesize": 537600, + "fileid": 0, + "certfile": null, + "fileType": 0, + "update_data_url": null, + "auth_scheme": null, + "sig--": null + } + ] + }"#; + + serde_json_core::from_str::(data).unwrap(); + } } diff --git a/src/ota/encoding/mod.rs b/src/ota/encoding/mod.rs index 257a1ea..eabf039 100644 --- a/src/ota/encoding/mod.rs +++ b/src/ota/encoding/mod.rs @@ -3,23 +3,24 @@ pub mod cbor; pub mod json; use core::ops::{Deref, DerefMut}; -use core::str::FromStr; use serde::{Serialize, Serializer}; use crate::jobs::StatusDetailsOwned; -use self::json::{JobStatusReason, OtaJob, Signature}; +use self::json::{JobStatusReason, Signature}; +use super::config::Config; +use super::data_interface::Protocol; use super::error::OtaError; -use super::{config::Config, pal::Version}; +use super::JobEventData; -#[derive(Clone, PartialEq)] +#[derive(Clone, Debug, PartialEq)] pub struct Bitmap(bitmaps::Bitmap<32>); impl Bitmap { pub fn new(file_size: usize, block_size: usize, block_offset: u32) -> Self { // Total number of blocks in file, rounded up - let total_num_blocks = (file_size + block_size - 1) / block_size; + let total_num_blocks = file_size.div_ceil(block_size); Self(bitmaps::Bitmap::mask(core::cmp::min( 32 - 1, @@ -59,11 +60,12 @@ pub struct FileContext { pub filepath: heapless::String<64>, pub filesize: usize, pub fileid: u8, - pub certfile: heapless::String<64>, + pub certfile: Option>, pub update_data_url: Option>, pub auth_scheme: Option>, - pub signature: Signature, + pub signature: Option, pub file_type: Option, + pub protocols: heapless::Vec, pub status_details: StatusDetailsOwned, pub block_offset: u32, @@ -76,62 +78,76 @@ pub struct FileContext { impl FileContext { pub fn new_from( - job_name: &str, - ota_job: &OtaJob, - status_details: Option, + job_data: JobEventData<'_>, file_idx: usize, config: &Config, - current_version: Version, ) -> Result { - let file_desc = ota_job + if job_data + .ota_document + .files + .get(file_idx) + .map(|f| f.filesize) + .unwrap_or_default() + == 0 + { + return Err(OtaError::ZeroFileSize); + } + + let file_desc = job_data + .ota_document .files .get(file_idx) .ok_or(OtaError::InvalidFile)? .clone(); - // Initialize new `status_details' if not already present - let status = if let Some(details) = status_details { - details - } else { - let mut status = StatusDetailsOwned::new(); - status - .insert( - heapless::String::from("updated_by"), - current_version.to_string(), - ) - .map_err(|_| OtaError::Overflow)?; - status - }; - let signature = file_desc.signature(); let block_offset = 0; let bitmap = Bitmap::new(file_desc.filesize, config.block_size, block_offset); Ok(FileContext { - filepath: heapless::String::from(file_desc.filepath), + filepath: heapless::String::try_from(file_desc.filepath).unwrap(), filesize: file_desc.filesize, + protocols: job_data.ota_document.protocols, fileid: file_desc.fileid, - certfile: heapless::String::from(file_desc.certfile), - update_data_url: file_desc.update_data_url.map(heapless::String::from), - auth_scheme: file_desc.auth_scheme.map(heapless::String::from), + certfile: file_desc + .certfile + .map(|cert| heapless::String::try_from(cert).unwrap()), + update_data_url: file_desc + .update_data_url + .map(|s| heapless::String::try_from(s).unwrap()), + auth_scheme: file_desc + .auth_scheme + .map(|s| heapless::String::try_from(s).unwrap()), signature, file_type: file_desc.file_type, - status_details: status, - - job_name: heapless::String::from(job_name), + status_details: job_data + .status_details + .map(|s| { + s.iter() + .map(|(&k, &v)| { + ( + heapless::String::try_from(k).unwrap(), + heapless::String::try_from(v).unwrap(), + ) + }) + .collect() + }) + .unwrap_or_default(), + + job_name: heapless::String::try_from(job_data.job_name).unwrap(), block_offset, request_block_remaining: bitmap.len() as u32, - blocks_remaining: (file_desc.filesize + config.block_size - 1) / config.block_size, - stream_name: heapless::String::from(ota_job.streamname), + blocks_remaining: file_desc.filesize.div_ceil(config.block_size), + stream_name: heapless::String::try_from(job_data.ota_document.streamname).unwrap(), bitmap, }) } pub fn self_test(&self) -> bool { self.status_details - .get(&heapless::String::from("self_test")) + .get(&heapless::String::try_from("self_test").unwrap()) .and_then(|f| f.parse().ok()) .map(|reason: JobStatusReason| { reason == JobStatusReason::SigCheckPassed @@ -139,12 +155,6 @@ impl FileContext { }) .unwrap_or(false) } - - pub fn updated_by(&self) -> Option { - self.status_details - .get(&heapless::String::from("updated_by")) - .and_then(|s| Version::from_str(s.as_str()).ok()) - } } #[cfg(test)] @@ -156,6 +166,6 @@ mod tests { let bitmap = Bitmap::new(255000, 256, 0); let true_indices: Vec = bitmap.into_iter().collect(); - assert_eq!((0..31).into_iter().collect::>(), true_indices); + assert_eq!((0..31).collect::>(), true_indices); } } diff --git a/src/ota/error.rs b/src/ota/error.rs index ad533c4..4dd2966 100644 --- a/src/ota/error.rs +++ b/src/ota/error.rs @@ -1,12 +1,11 @@ -use crate::jobs::JobError; +use crate::jobs::{data_types::ErrorCode, JobError}; use super::pal::OtaPalError; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum OtaError { NoActiveJob, - SignalEventFailed, Momentum, MomentumAbort, InvalidInterface, @@ -14,27 +13,34 @@ pub enum OtaError { BlockOutOfRange, ZeroFileSize, Overflow, + DataStreamEnded, + UnexpectedTopic, InvalidFile, - Mqtt(mqttrust::MqttError), + UpdateRejected(ErrorCode), + Write( + #[cfg_attr(feature = "defmt", defmt(Debug2Format))] + embedded_storage_async::nor_flash::NorFlashErrorKind, + ), + Mqtt(embedded_mqtt::Error), Encoding, Pal, - Timer, + Timeout, } impl OtaError { pub fn is_retryable(&self) -> bool { - matches!(self, Self::Encoding) + matches!(self, Self::Encoding | Self::Timeout) } } -impl From for OtaError { - fn from(e: mqttrust::MqttError) -> Self { +impl From for OtaError { + fn from(e: embedded_mqtt::Error) -> Self { Self::Mqtt(e) } } -impl From> for OtaError { - fn from(_e: OtaPalError) -> Self { +impl From for OtaError { + fn from(_e: OtaPalError) -> Self { Self::Pal } } @@ -44,7 +50,7 @@ impl From for OtaError { match e { JobError::Overflow => Self::Overflow, JobError::Encoding => Self::Encoding, - JobError::Mqtt(m) => Self::Mqtt(m), + JobError::Mqtt(e) => Self::Mqtt(e), } } } diff --git a/src/ota/mod.rs b/src/ota/mod.rs index a627e88..86ebc93 100644 --- a/src/ota/mod.rs +++ b/src/ota/mod.rs @@ -1,47 +1,629 @@ -//! ## Over-the-air (OTA) flashing of firmware -//! -//! AWS IoT OTA works by using AWS IoT Jobs to manage firmware transfer and -//! status reporting of OTA. -//! -//! The OTA Jobs API makes use of the following special MQTT Topics: -//! - $aws/things/{thing_name}/jobs/$next/get/accepted -//! - $aws/things/{thing_name}/jobs/notify-next -//! - $aws/things/{thing_name}/jobs/$next/get -//! - $aws/things/{thing_name}/jobs/{job_id}/update -//! - $aws/things/{thing_name}/streams/{stream_id}/data/cbor -//! - $aws/things/{thing_name}/streams/{stream_id}/get/cbor -//! -//! Most of the data structures for the Jobs API has been copied from Rusoto: -//! -//! -//! ### OTA Flow: -//! 1. Device subscribes to notification topics for AWS IoT jobs and listens for -//! update messages. -//! 2. When an update is available, the OTA agent publishes requests to AWS IoT -//! and receives updates using the HTTP or MQTT protocol, depending on the -//! settings you chose. -//! 3. The OTA agent checks the digital signature of the downloaded files and, -//! if the files are valid, installs the firmware update to the appropriate -//! flash bank. -//! -//! The OTA depends on working, and correctly setup: -//! - Bootloader -//! - MQTT Client -//! - Code sign verification -//! - CBOR deserializer - -pub mod agent; -pub mod builder; pub mod config; pub mod control_interface; pub mod data_interface; pub mod encoding; pub mod error; pub mod pal; -pub mod state; + +use core::ops::DerefMut as _; #[cfg(feature = "ota_mqtt_data")] pub use data_interface::mqtt::{Encoding, Topic}; +use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex, signal::Signal}; +use embedded_storage_async::nor_flash::{NorFlash, NorFlashError as _}; + +use crate::{ + jobs::{data_types::JobStatus, StatusDetailsOwned}, + ota::{data_interface::BlockTransfer, encoding::json::JobStatusReason}, +}; + +use self::{ + control_interface::ControlInterface, + data_interface::DataInterface, + encoding::{Bitmap, FileContext}, + pal::{ImageState, ImageStateReason}, +}; + +#[derive(PartialEq)] +pub struct JobEventData<'a> { + pub job_name: &'a str, + pub ota_document: encoding::json::OtaJob<'a>, + pub status_details: Option>, +} + +pub struct Updater; + +impl Updater { + pub async fn check_for_job<'a, C: ControlInterface>( + control: &C, + ) -> Result<(), error::OtaError> { + control.request_job().await?; + Ok(()) + } + + pub async fn perform_ota<'a, 'b, C: ControlInterface, D: DataInterface>( + control: &C, + data: &D, + file_ctx: FileContext, + pal: &mut impl pal::OtaPal, + config: &config::Config, + ) -> Result<(), error::OtaError> { + info!( + "[OTA] Starting perform_ota for job={} stream={} size={}", + file_ctx.job_name, file_ctx.stream_name, file_ctx.filesize + ); + let progress_state = Mutex::new(ProgressState { + total_blocks: file_ctx.filesize.div_ceil(config.block_size), + blocks_remaining: file_ctx.filesize.div_ceil(config.block_size), + block_offset: file_ctx.block_offset, + request_block_remaining: file_ctx.bitmap.len() as u32, + bitmap: file_ctx.bitmap.clone(), + file_size: file_ctx.filesize, + request_momentum: None, + status_details: file_ctx.status_details.clone(), + }); + + // Create the JobUpdater + let mut job_updater = JobUpdater::new(&file_ctx, &progress_state, config, control); + + match job_updater.initialize::(pal).await? { + Some(()) => {} + None => return Ok(()), + }; + + info!("Job document was accepted. Attempting to begin the update"); + + // Spawn the request momentum future + let momentum_fut = Self::handle_momentum(data, config, &file_ctx, &progress_state); + + // Spawn the status update future + let status_update_fut = job_updater.handle_status_updates(); + + // Spawn the data handling future + let data_fut = async { + // Create/Open the OTA file on the file system + let mut block_writer = match pal.create_file_for_rx(&file_ctx).await { + Ok(block_writer) => block_writer, + Err(e) => { + job_updater + .set_image_state_with_reason( + pal, + ImageState::Aborted(ImageStateReason::Pal(e)), + ) + .await?; + + pal.close_file(&file_ctx).await?; + return Err(e.into()); + } + }; + + info!("Initialized file handler! Requesting file blocks"); + + // Outer loop to handle resubscription on clean session + loop { + // Prepare the storage layer on receiving a new file + let mut subscription = data.init_file_transfer(&file_ctx).await?; + + { + let mut progress = progress_state.lock().await; + data.request_file_blocks(&file_ctx, &mut progress, config) + .await?; + } + + info!("Awaiting file blocks!"); + + // Inner loop to process blocks + loop { + // Select over the futures + match subscription.next_block().await { + Ok(Some(mut payload)) => { + // Decode the file block received + let mut progress = progress_state.lock().await; + + match Self::ingest_data_block( + data, + &mut block_writer, + config, + &mut progress, + payload.deref_mut(), + ) + .await + { + Ok(true) => { + // ... (Handle end of file) ... + match pal.close_file(&file_ctx).await { + Err(e) => { + // FIXME: This seems like duplicate status update, as it will also report during cleanup + // job_updater.signal_update( + // JobStatus::Failed, + // JobStatusReason::Pal(0), + // ); + + return Err(e.into()); + } + Ok(_) if file_ctx.file_type == Some(0) => { + job_updater.signal_update( + JobStatus::InProgress, + JobStatusReason::SigCheckPassed, + ); + return Ok(()); + } + Ok(_) => { + job_updater.signal_update( + JobStatus::Succeeded, + JobStatusReason::Accepted, + ); + return Ok(()); + } + } + } + Ok(false) => { + // ... (Handle successful block processing) ... + progress.request_momentum = Some(0); + + // Update the job status to reflect the download progress + if progress.blocks_remaining + % config.status_update_frequency as usize + == 0 + { + job_updater.signal_update( + JobStatus::InProgress, + JobStatusReason::Receiving, + ); + } + + if progress.request_block_remaining > 1 { + progress.request_block_remaining -= 1; + } else { + data.request_file_blocks(&file_ctx, &mut progress, config) + .await?; + } + } + Err(e) if e.is_retryable() => { + // ... (Handle retryable errors) ... + error!("Failed block validation: {:?}! Retrying", e); + } + Err(e) => { + // ... (Handle fatal errors) ... + return Err(e); + } + } + } + Ok(None) => { + warn!("[OTA] Data stream subscription ended (clean session/disconnect). Resubscribing and resuming..."); + + let blocks_remaining = { + let progress = progress_state.lock().await; + progress.blocks_remaining + }; + + info!("[OTA] Resuming OTA: {} blocks remaining", blocks_remaining); + + // Break inner loop to trigger resubscription in outer loop + break; + } + + // Handle status update future results + Err(e) => { + error!("Status update error: {:?}", e); + return Err(e); + } + } + } // End of inner block processing loop + } // End of outer resubscribe loop + }; + + let (data_res, _) = embassy_futures::join::join( + data_fut, + embassy_futures::select::select(status_update_fut, momentum_fut), + ) + .await; + + // Cleanup and update the job status accordingly + match data_res { + Ok(()) => { + let event = if let Some(0) = file_ctx.file_type { + pal::OtaEvent::Activate + } else { + pal::OtaEvent::UpdateComplete + }; + + info!( + "OTA Download finished! Running complete callback: {:?}", + event + ); + + pal.complete_callback(event).await?; + + Ok(()) + } + Err(error::OtaError::MomentumAbort) => { + warn!("[OTA] Momentum abort triggered"); + job_updater + .set_image_state_with_reason( + pal, + ImageState::Aborted(ImageStateReason::MomentumAbort), + ) + .await?; + + Err(error::OtaError::MomentumAbort) + } + Err(e) => { + // Signal the error status + job_updater + .update_job_status(JobStatus::Failed, JobStatusReason::Pal(0)) + .await?; + + pal.complete_callback(pal::OtaEvent::Fail).await?; + info!("Application callback! OtaEvent::Fail"); + + Err(e) + } + } + } + + async fn ingest_data_block<'a, D: DataInterface>( + data: &D, + block_writer: &mut impl NorFlash, + config: &config::Config, + progress: &mut ProgressState, + payload: &mut [u8], + ) -> Result { + let block = data.decode_file_block(payload)?; + + if block.validate(config.block_size, progress.file_size) { + if block.block_id < progress.block_offset as usize + || !progress + .bitmap + .get(block.block_id - progress.block_offset as usize) + { + info!( + "Block {:?} is a DUPLICATE. {:?} blocks remaining.", + block.block_id, progress.blocks_remaining + ); + + // Just return same progress as before + return Ok(false); + } + + info!( + "Received block {}. {:?} blocks remaining.", + block.block_id, progress.blocks_remaining + ); + + block_writer + .write( + (block.block_id * config.block_size) as u32, + block.block_payload, + ) + .await + .map_err(|e| error::OtaError::Write(e.kind()))?; + + let block_offset = progress.block_offset; + progress + .bitmap + .set(block.block_id - block_offset as usize, false); + + progress.blocks_remaining -= 1; + + if progress.blocks_remaining == 0 { + info!("Received final expected block of file."); + + // Return true to indicate end of file. + Ok(true) + } else { + if progress.bitmap.is_empty() { + progress.block_offset += 31; + progress.bitmap = encoding::Bitmap::new( + progress.file_size, + config.block_size, + progress.block_offset, + ); + } + + Ok(false) + } + } else { + error!( + "Error! Block {:?} out of expected range! Size {:?}", + block.block_id, block.block_size + ); + + Err(error::OtaError::BlockOutOfRange) + } + } + + async fn handle_momentum( + data: &D, + config: &config::Config, + file_ctx: &FileContext, + progress_state: &Mutex, + ) -> Result<(), error::OtaError> { + loop { + embassy_time::Timer::after(config.request_wait).await; + + let mut progress = progress_state.lock().await; + + if progress.blocks_remaining == 0 { + // No more blocks to request + break; + } + + let Some(request_momentum) = &mut progress.request_momentum else { + continue; + }; + + // Increment momentum + *request_momentum += 1; + + if *request_momentum == 1 { + continue; + } + + if *request_momentum <= config.max_request_momentum { + warn!("Momentum requesting more blocks!"); + + // Request data blocks + data.request_file_blocks(file_ctx, &mut progress, config) + .await?; + } else { + // Too much momentum, abort + return Err(error::OtaError::MomentumAbort); + } + } + + Ok(()) + } +} + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct ProgressState { + pub total_blocks: usize, + pub blocks_remaining: usize, + pub file_size: usize, + pub block_offset: u32, + pub request_block_remaining: u32, + pub request_momentum: Option, + #[cfg_attr(feature = "defmt", defmt(Debug2Format))] + pub bitmap: Bitmap, + #[cfg_attr(feature = "defmt", defmt(Debug2Format))] + pub status_details: StatusDetailsOwned, +} + +pub struct JobUpdater<'a, C: ControlInterface> { + pub file_ctx: &'a FileContext, + pub progress_state: &'a Mutex, + pub config: &'a config::Config, + pub control: &'a C, + pub status_update_signal: Signal, +} + +impl<'a, C: ControlInterface> JobUpdater<'a, C> { + pub fn new( + file_ctx: &'a FileContext, + progress_state: &'a Mutex, + config: &'a config::Config, + control: &'a C, + ) -> Self { + Self { + file_ctx, + progress_state, + config, + control, + status_update_signal: Signal::::new(), + } + } + + async fn initialize( + &mut self, + pal: &mut PAL, + ) -> Result, error::OtaError> { + // If the job is in self test mode, don't start an OTA update but + // instead do the following: + // + // If the firmware that performed the update was older than the + // currently running firmware, set the image state to "Testing." This is + // the success path. + // + // If it's the same or newer, reject the job since either the firmware + // was not accepted during self test or an incorrect image was sent by + // the OTA operator. + let platform_self_test = pal + .get_platform_image_state() + .await + .is_ok_and(|i| i == pal::PalImageState::PendingCommit); + + match (self.file_ctx.self_test(), platform_self_test) { + (true, true) => { + // Run self-test! + self.set_image_state_with_reason( + pal, + ImageState::Testing(ImageStateReason::VersionCheck), + ) + .await?; + + info!("Beginning self-test"); + + let test_fut = pal.complete_callback(pal::OtaEvent::StartTest); + + match self.config.self_test_timeout { + Some(timeout) => embassy_time::with_timeout(timeout, test_fut) + .await + .map_err(|_| error::OtaError::Timeout)?, + None => test_fut.await, + }?; + + let mut progress = self.progress_state.lock().await; + self.control + .update_job_status( + self.file_ctx, + &mut progress, + JobStatus::Succeeded, + JobStatusReason::Accepted, + ) + .await?; + + return Ok(None); + } + (false, false) => {} + (false, true) => { + // Received a job that is not in self-test but platform is, so + // reboot the device to allow roll back to previous image. + error!("Rejecting new image and rebooting: The platform is in the self-test state while the job is not."); + pal.reset_device().await?; + return Err(error::OtaError::ResetFailed); + } + (true, false) => { + // The job is in self test but the platform image state is not so it + // could be an attack on the platform image state. Reject the update + // (this should also cause the image to be erased), aborting the job + // and reset the device. + error!("Rejecting new image and rebooting: the job is in the self-test state while the platform is not."); + self.set_image_state_with_reason( + pal, + ImageState::Rejected(ImageStateReason::ImageStateMismatch), + ) + .await?; + + pal.reset_device().await?; + return Err(error::OtaError::ResetFailed); + } + } + + if !self.file_ctx.protocols.contains(&D::PROTOCOL) { + error!("Unable to handle current OTA job with given data interface ({:?}). Supported protocols: {:?}. Aborting current update.", D::PROTOCOL, self.file_ctx.protocols); + self.set_image_state_with_reason( + pal, + ImageState::Aborted(ImageStateReason::InvalidDataProtocol), + ) + .await?; + return Err(error::OtaError::InvalidInterface); + } + + Ok(Some(())) + } + + async fn handle_status_updates(&self) -> Result<(), error::OtaError> { + loop { + // Wait for a signal from the main loop + let (status, reason) = self.status_update_signal.wait().await; + + // Update the job status based on the signal + let mut progress = self.progress_state.lock().await; + self.control + .update_job_status(self.file_ctx, &mut progress, status, reason) + .await?; + + match status { + JobStatus::Queued | JobStatus::InProgress => {} + _ => return Ok(()), + } + } + } + + async fn set_image_state_with_reason( + &self, + pal: &mut PAL, + image_state: ImageState, + ) -> Result<(), error::OtaError> { + // Call the platform specific code to set the image state + let image_state = match pal.set_platform_image_state(image_state).await { + Err(e) if !matches!(image_state, ImageState::Aborted(_)) => { + // If the platform image state couldn't be set correctly, force + // fail the update by setting the image state to "Rejected" + // unless it's already in "Aborted". + + // Capture the failure reason if not already set (and we're not + // already Aborted as checked above). Otherwise Keep the + // original reject reason code since it is possible for the PAL + // to fail to update the image state in some cases (e.g. a reset + // already caused the bundle rollback and we failed to rollback + // again). + + // Intentionally override reason since we failed within this + // function + ImageState::Rejected(ImageStateReason::Pal(e)) + } + _ => image_state, + }; + + // Now update the image state and job status on server side + let mut progress = self.progress_state.lock().await; + + match image_state { + ImageState::Testing(_) => { + // We discovered we're ready for test mode, put job status + // in self_test active + self.control + .update_job_status( + self.file_ctx, + &mut progress, + JobStatus::InProgress, + JobStatusReason::SelfTestActive, + ) + .await?; + } + ImageState::Accepted => { + // Now that we have accepted the firmware update, we can + // complete the job + self.control + .update_job_status( + self.file_ctx, + &mut progress, + JobStatus::Succeeded, + JobStatusReason::Accepted, + ) + .await?; + } + ImageState::Rejected(_) => { + // The firmware update was rejected, complete the job as + // FAILED (Job service will not allow us to set REJECTED + // after the job has been started already). + + self.control + .update_job_status( + self.file_ctx, + &mut progress, + JobStatus::Failed, + JobStatusReason::Rejected, + ) + .await?; + } + _ => { + // The firmware update was aborted, complete the job as + // FAILED (Job service will not allow us to set REJECTED + // after the job has been started already). + + self.control + .update_job_status( + self.file_ctx, + &mut progress, + JobStatus::Failed, + JobStatusReason::Aborted, + ) + .await?; + } + } + Ok(()) + } + + // Function to signal the status update future + pub fn signal_update(&self, status: JobStatus, reason: JobStatusReason) { + self.status_update_signal.signal((status, reason)); + } + + // Function to update the job status + pub async fn update_job_status( + &mut self, + status: JobStatus, + reason: JobStatusReason, + ) -> Result<(), error::OtaError> { + let mut progress = self.progress_state.lock().await; -#[cfg(test)] -pub mod test; + self.control + .update_job_status(self.file_ctx, &mut progress, status, reason) + .await?; + Ok(()) + } +} diff --git a/src/ota/pal.rs b/src/ota/pal.rs index 2163a5d..7764e50 100644 --- a/src/ota/pal.rs +++ b/src/ota/pal.rs @@ -1,23 +1,35 @@ //! Platform abstraction trait for OTA updates - -use core::fmt::Write; -use core::str::FromStr; +use embedded_storage_async::nor_flash::NorFlash; use super::encoding::FileContext; -use super::state::ImageStateReason; +#[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub enum ImageState { +pub enum ImageStateReason { + NewerJob, + FailedIngest, + MomentumAbort, + ImageStateMismatch, + SignatureCheckPassed, + InvalidDataProtocol, + UserAbort, + VersionCheck, + Pal(OtaPalError), +} + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum ImageState { Unknown, - Aborted(ImageStateReason), - Rejected(ImageStateReason), + Aborted(ImageStateReason), + Rejected(ImageStateReason), Accepted, - Testing(ImageStateReason), + Testing(ImageStateReason), } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub enum OtaPalError { +pub enum OtaPalError { SignatureCheckFailed, FileWriteFailed, FileTooLarge, @@ -27,13 +39,7 @@ pub enum OtaPalError { BadImageState, CommitFailed, VersionCheck, - Custom(E), -} - -impl From for OtaPalError { - fn from(value: E) -> Self { - Self::Custom(value) - } + Other, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -62,92 +68,9 @@ pub enum OtaEvent { UpdateComplete, } -#[derive(Debug, Clone, Eq)] -pub struct Version { - major: u8, - minor: u8, - patch: u8, -} - -#[cfg(feature = "defmt")] -impl defmt::Format for Version { - fn format(&self, fmt: defmt::Formatter) { - defmt::write!(fmt, "{=u8}.{=u8}.{=u8}", self.major, self.minor, self.patch) - } -} - -impl Default for Version { - fn default() -> Self { - Self::new(0, 0, 0) - } -} - -impl FromStr for Version { - type Err = (); - - fn from_str(s: &str) -> Result { - let mut iter = s.split('.'); - Ok(Self { - major: iter.next().and_then(|v| v.parse().ok()).ok_or(())?, - minor: iter.next().and_then(|v| v.parse().ok()).ok_or(())?, - patch: iter.next().and_then(|v| v.parse().ok()).ok_or(())?, - }) - } -} - -impl Version { - pub fn new(major: u8, minor: u8, patch: u8) -> Self { - Self { - major, - minor, - patch, - } - } - - pub fn to_string(&self) -> heapless::String { - let mut s = heapless::String::new(); - s.write_fmt(format_args!("{}.{}.{}", self.major, self.minor, self.patch)) - .unwrap(); - s - } -} - -impl core::cmp::PartialEq for Version { - #[inline] - fn eq(&self, other: &Self) -> bool { - self.major == other.major && self.minor == other.minor && self.patch == other.patch - } -} - -impl core::cmp::PartialOrd for Version { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl core::cmp::Ord for Version { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - match self.major.cmp(&other.major) { - core::cmp::Ordering::Equal => {} - r => return r, - } - - match self.minor.cmp(&other.minor) { - core::cmp::Ordering::Equal => {} - r => return r, - } - - match self.patch.cmp(&other.patch) { - core::cmp::Ordering::Equal => {} - r => return r, - } - - core::cmp::Ordering::Equal - } -} /// Platform abstraction layer for OTA jobs pub trait OtaPal { - type Error; + type BlockWriter: NorFlash; /// OTA abort. /// @@ -156,7 +79,7 @@ pub trait OtaPal { /// aborted. /// /// - `file`: [`FileContext`] File description of the job being aborted - fn abort(&mut self, file: &FileContext) -> Result<(), OtaPalError>; + async fn abort(&mut self, file: &FileContext) -> Result<(), OtaPalError>; /// Activate the newest MCU image received via OTA. /// @@ -168,8 +91,8 @@ pub trait OtaPal { /// /// **return**: The OTA PAL layer error code combined with the MCU specific /// error code. - fn activate_new_image(&mut self) -> Result<(), OtaPalError> { - self.reset_device() + async fn activate_new_image(&mut self) -> Result<(), OtaPalError> { + self.reset_device().await } /// OTA create file to store received data. @@ -179,7 +102,10 @@ pub trait OtaPal { /// is created. /// /// - `file`: [`FileContext`] File description of the job being aborted - fn create_file_for_rx(&mut self, file: &FileContext) -> Result<(), OtaPalError>; + async fn create_file_for_rx( + &mut self, + file: &FileContext, + ) -> Result<&mut Self::BlockWriter, OtaPalError>; /// Get the state of the OTA update image. /// @@ -196,7 +122,7 @@ pub trait OtaPal { /// timer is not started. /// /// **return** An [`PalImageState`]. - fn get_platform_image_state(&mut self) -> Result>; + async fn get_platform_image_state(&mut self) -> Result; /// Attempt to set the state of the OTA update image. /// @@ -208,10 +134,10 @@ pub trait OtaPal { /// /// **return** The [`OtaPalError`] error code combined with the MCU specific /// error code. - fn set_platform_image_state( + async fn set_platform_image_state( &mut self, - image_state: ImageState, - ) -> Result<(), OtaPalError>; + image_state: ImageState, + ) -> Result<(), OtaPalError>; /// Reset the device. /// @@ -222,7 +148,7 @@ pub trait OtaPal { /// /// **return** The OTA PAL layer error code combined with the MCU specific /// error code. - fn reset_device(&mut self) -> Result<(), OtaPalError>; + async fn reset_device(&mut self) -> Result<(), OtaPalError>; /// Authenticate and close the underlying receive file in the specified OTA /// context. @@ -234,23 +160,7 @@ pub trait OtaPal { /// /// **return** The OTA PAL layer error code combined with the MCU specific /// error code. - fn close_file(&mut self, file: &FileContext) -> Result<(), OtaPalError>; - - /// Write a block of data to the specified file at the given offset. - /// - /// - `file`: [`FileContext`] File description of the job being aborted. - /// - `block_offset`: Byte offset to write to from the beginning of the - /// file. - /// - `block_payload`: Byte array of data to write. - /// - /// **return** The number of bytes written on a success, or a negative error - /// code from the platform abstraction layer. - fn write_block( - &mut self, - file: &FileContext, - block_offset: usize, - block_payload: &[u8], - ) -> Result>; + async fn close_file(&mut self, file: &FileContext) -> Result<(), OtaPalError>; /// OTA update complete. /// @@ -284,9 +194,9 @@ pub trait OtaPal { /// the OTA update job has failed in some way and should be rejected. /// /// - `event` [`OtaEvent`] An OTA update event from the `OtaEvent` enum. - fn complete_callback(&mut self, event: OtaEvent) -> Result<(), OtaPalError> { + async fn complete_callback(&mut self, event: OtaEvent) -> Result<(), OtaPalError> { match event { - OtaEvent::Activate => self.activate_new_image(), + OtaEvent::Activate => self.activate_new_image().await, OtaEvent::Fail | OtaEvent::UpdateComplete => { // Nothing special to do. The OTA agent handles it Ok(()) @@ -294,7 +204,7 @@ pub trait OtaPal { OtaEvent::StartTest => { // Accept the image since it was a good transfer // and networking and services are all working. - self.set_platform_image_state(ImageState::Accepted)?; + self.set_platform_image_state(ImageState::Accepted).await?; Ok(()) } OtaEvent::SelfTestFailed => { @@ -308,7 +218,4 @@ pub trait OtaPal { } } } - - /// - fn get_active_firmware_version(&self) -> Result>; } diff --git a/src/ota/state.rs b/src/ota/state.rs deleted file mode 100644 index 3f4e576..0000000 --- a/src/ota/state.rs +++ /dev/null @@ -1,1181 +0,0 @@ -use smlang::statemachine; - -use super::config::Config; -use super::control_interface::ControlInterface; -use super::data_interface::{DataInterface, Protocol}; -use super::encoding::json::JobStatusReason; -use super::encoding::json::OtaJob; -use super::encoding::FileContext; -use super::pal::OtaPal; -use super::pal::OtaPalError; - -use crate::jobs::{data_types::JobStatus, StatusDetails}; -use crate::ota::encoding::Bitmap; -use crate::ota::pal::OtaEvent; - -use fugit_timer::ExtU32; - -use super::{ - error::OtaError, - pal::{ImageState, PalImageState}, -}; - -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub enum ImageStateReason { - NewerJob, - FailedIngest, - MomentumAbort, - ImageStateMismatch, - SignatureCheckPassed, - InvalidDataProtocol, - UserAbort, - VersionCheck, - Pal(OtaPalError), -} - -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum RestartReason { - Activate(u8), - Restart(u8), -} - -impl RestartReason { - #[must_use] - pub fn inc(self) -> Self { - match self { - Self::Activate(cnt) => Self::Activate(cnt + 1), - Self::Restart(cnt) => Self::Restart(cnt + 1), - } - } -} - -#[derive(PartialEq)] -pub struct JobEventData<'a> { - pub job_name: &'a str, - pub ota_document: &'a OtaJob<'a>, - pub status_details: Option<&'a StatusDetails<'a>>, -} - -statemachine! { - guard_error: OtaError, - transitions: { - *Ready + Start [start_handler] = RequestingJob, - RequestingJob | WaitingForFileBlock + RequestJobDocument [request_job_handler] = WaitingForJob, - RequestingJob + RequestTimer [request_job_handler] = WaitingForJob, - RequestingJob + ContinueJob = WaitingForFileBlock, - RequestingJob + ReplacementJob(JobEventData<'a>) [process_job_handler] = CreatingFile, - WaitingForJob + RequestJobDocument [request_job_handler] = WaitingForJob, - WaitingForJob + ReceivedJobDocument(JobEventData<'a>) [process_job_handler] = CreatingFile, - CreatingFile + StartSelfTest [in_self_test_handler] = WaitingForJob, - CreatingFile + CreateFile [init_file_handler] = RequestingFileBlock, - CreatingFile + RequestTimer [init_file_handler] = RequestingFileBlock, - CreatingFile | WaitingForJob | Restarting + Restart(RestartReason) [restart_handler] = Restarting, - RequestingFileBlock | WaitingForFileBlock + RequestFileBlock [request_data_handler] = WaitingForFileBlock, - RequestingFileBlock | WaitingForFileBlock + RequestTimer [request_data_handler] = WaitingForFileBlock, - WaitingForFileBlock + ReceivedFileBlock(&'a mut [u8]) [process_data_handler] = WaitingForFileBlock, - WaitingForFileBlock + ReceivedJobDocument(JobEventData<'a>) [job_notification_handler] = RequestingJob, - WaitingForFileBlock + CloseFile [close_file_handler] = WaitingForJob, - Suspended | RequestingJob | WaitingForJob | CreatingFile | RequestingFileBlock | WaitingForFileBlock + Resume [resume_job_handler] = RequestingJob, - Ready | RequestingJob | WaitingForJob | CreatingFile | RequestingFileBlock | WaitingForFileBlock + Suspend = Suspended, - Ready | RequestingJob | WaitingForJob | CreatingFile | RequestingFileBlock | WaitingForFileBlock + UserAbort [user_abort_handler] = WaitingForJob, - Ready | RequestingJob | WaitingForJob | CreatingFile | RequestingFileBlock | WaitingForFileBlock + Shutdown [shutdown_handler] = Ready, - } -} - -#[cfg(feature = "defmt")] -impl defmt::Format for Error { - fn format(&self, fmt: defmt::Formatter) { - match self { - Error::InvalidEvent => defmt::write!(fmt, "Error::InvalidEvent"), - Error::GuardFailed(e) => defmt::write!(fmt, "Error::GuardFailed({:?})", e), - } - } -} - -pub(crate) enum Interface { - Primary(FileContext), - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - Secondary(FileContext), -} - -impl Interface { - pub const fn file_ctx(&self) -> &FileContext { - match self { - Interface::Primary(i) => i, - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - Interface::Secondary(i) => i, - } - } - - pub fn mut_file_ctx(&mut self) -> &mut FileContext { - match self { - Interface::Primary(i) => i, - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - Interface::Secondary(i) => i, - } - } -} - -macro_rules! data_interface { - ($self:ident.$func:ident $(,$y:expr),*) => { - match $self.active_interface { - Some(Interface::Primary(ref mut ctx)) => $self.data_primary.$func(ctx, $($y),*), - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - Some(Interface::Secondary(ref mut ctx)) => $self.data_secondary.as_mut().ok_or(OtaError::InvalidInterface)?.$func(ctx, $($y),*), - _ => Err(OtaError::InvalidInterface) - } - }; -} - -// Context of current OTA Job, keeping state -pub(crate) struct SmContext<'a, C, DP, DS, T, ST, PAL, const L: usize, const TIMER_HZ: u32> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - pub(crate) events: heapless::spsc::Queue, L>, - pub(crate) control: &'a C, - pub(crate) data_primary: DP, - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - pub(crate) data_secondary: Option, - #[cfg(not(all(feature = "ota_mqtt_data", feature = "ota_http_data")))] - pub(crate) data_secondary: core::marker::PhantomData, - pub(crate) active_interface: Option, - pub(crate) pal: PAL, - pub(crate) request_momentum: u8, - pub(crate) request_timer: T, - pub(crate) self_test_timer: Option, - pub(crate) config: Config, - pub(crate) image_state: ImageState, -} - -impl<'a, C, DP, DS, T, ST, PAL, const L: usize, const TIMER_HZ: u32> - SmContext<'a, C, DP, DS, T, ST, PAL, L, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - /// Called to update the filecontext structure from the job - fn get_file_context_from_job( - &mut self, - job_name: &str, - ota_document: &OtaJob, - status_details: Option, - ) -> Result { - let file_idx = 0; - - if ota_document - .files - .get(file_idx) - .map(|f| f.filesize) - .unwrap_or_default() - == 0 - { - return Err(OtaError::ZeroFileSize); - } - - // If there's an active job, verify that it's the same as what's being - // reported now - let cur_file_ctx = self.active_interface.as_mut().map(|i| i.mut_file_ctx()); - let file_ctx = if let Some(file_ctx) = cur_file_ctx { - if file_ctx.stream_name != ota_document.streamname { - info!("New job document received, aborting current job"); - - // Abort the current job - // TODO:?? - self.pal - .set_platform_image_state(ImageState::Aborted(ImageStateReason::NewerJob))?; - - // Abort any active file access and release the file resource, - // if needed - self.pal.abort(file_ctx)?; - - // Cleanup related to selected protocol - data_interface!(self.cleanup, &self.config)?; - - // Set new active job - Ok(FileContext::new_from( - job_name, - ota_document, - status_details.map(|s| { - s.iter() - .map(|(&k, &v)| (heapless::String::from(k), heapless::String::from(v))) - .collect() - }), - file_idx, - &self.config, - self.pal.get_active_firmware_version()?, - )?) - } else { - // The same job is being reported so update the url - info!("New job document ID is identical to the current job: Updating the URL based on the new job document"); - file_ctx.update_data_url = ota_document - .files - .get(0) - .map(|f| f.update_data_url.map(heapless::String::from)) - .ok_or(OtaError::InvalidFile)?; - - Err(file_ctx.clone()) - } - } else { - Ok(FileContext::new_from( - job_name, - ota_document, - status_details.map(|s| { - s.iter() - .map(|(&k, &v)| (heapless::String::from(k), heapless::String::from(v))) - .collect() - }), - file_idx, - &self.config, - self.pal.get_active_firmware_version()?, - )?) - }; - - // If the job is in self test mode, don't start an OTA update but - // instead do the following: - // - // If the firmware that performed the update was older than the - // currently running firmware, set the image state to "Testing." This is - // the success path. - // - // If it's the same or newer, reject the job since either the firmware - // was not accepted during self test or an incorrect image was sent by - // the OTA operator. - let mut file_ctx = match file_ctx { - Ok(mut file_ctx) if file_ctx.self_test() => { - self.handle_self_test_job(&mut file_ctx)?; - return Ok(file_ctx); - } - Ok(file_ctx) => { - info!("Job document was accepted. Attempting to begin the update"); - file_ctx - } - Err(file_ctx) => { - info!("Job document for receiving an update received"); - // Don't create file again on update. - return Ok(file_ctx); - } - }; - - // Create/Open the OTA file on the file system - if let Err(e) = self.pal.create_file_for_rx(&file_ctx) { - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - &mut file_ctx, - ImageState::Aborted(ImageStateReason::Pal(e)), - )?; - - self.ota_close()?; - // FIXME: - return Err(OtaError::Pal); - // return Err(e.into()); - } - - Ok(file_ctx) - } - - fn select_interface( - &self, - file_ctx: FileContext, - protocols: &[Protocol], - ) -> Result { - if protocols.contains(&DP::PROTOCOL) { - Ok(Interface::Primary(file_ctx)) - } else { - #[cfg(all(feature = "ota_mqtt_data", feature = "ota_http_data"))] - if protocols.contains(&DS::PROTOCOL) && self.data_secondary.is_some() { - Ok(Interface::Secondary(file_ctx)) - } else { - Err(file_ctx) - } - - #[cfg(not(all(feature = "ota_mqtt_data", feature = "ota_http_data")))] - Err(file_ctx) - } - } - - /// Check if the current image is `PendingCommit` and thus is in selftest - fn platform_in_selftest(&mut self) -> bool { - // Get the platform state from the OTA pal layer - self.pal - .get_platform_image_state() - .map_or(false, |i| i == PalImageState::PendingCommit) - } - - /// Validate update version when receiving job doc in self test state - fn handle_self_test_job(&mut self, file_ctx: &mut FileContext) -> Result<(), OtaError> { - info!("In self test mode"); - - let active_version = self.pal.get_active_firmware_version().unwrap_or_default(); - - let version_check = if file_ctx.fileid == 0 && file_ctx.file_type == Some(0) { - // Only check for versions if the target is self & always allow - // updates if updated_by is not present. - file_ctx - .updated_by() - .map_or(true, |updated_by| active_version > updated_by) - } else { - true - }; - - info!("Version check: {:?}", version_check); - - if self.config.allow_downgrade || version_check { - // The running firmware version is newer than the firmware that - // performed the update or downgrade is allowed so this means we're - // ready to start the self test phase. - // - // Set image state accordingly and update job status with self test - // identifier. - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - file_ctx, - ImageState::Testing(ImageStateReason::VersionCheck), - )?; - - Ok(()) - } else { - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - file_ctx, - ImageState::Rejected(ImageStateReason::VersionCheck), - )?; - - self.pal.complete_callback(OtaEvent::SelfTestFailed)?; - - // Handle self-test failure in the platform specific implementation, - // example, reset the device in case of firmware upgrade. - self.events - .enqueue(Events::Restart(RestartReason::Restart(0))) - .map_err(|_| OtaError::SignalEventFailed)?; - Ok(()) - } - } - - fn set_image_state_with_reason( - control: &C, - _pal: &mut PAL, - config: &Config, - file_ctx: &mut FileContext, - image_state: ImageState, - ) -> Result, OtaError> { - // debug!("set_image_state_with_reason {:?}", image_state); - // Call the platform specific code to set the image state - - // FIXME: - // let image_state = match pal.set_platform_image_state(image_state) { - // Err(e) if !matches!(image_state, ImageState::Aborted(_)) => { - // If the platform image state couldn't be set correctly, force - // fail the update by setting the image state to "Rejected" - // unless it's already in "Aborted". - - // Capture the failure reason if not already set (and we're not - // already Aborted as checked above). Otherwise Keep the - // original reject reason code since it is possible for the PAL - // to fail to update the image state in some cases (e.g. a reset - // already caused the bundle rollback and we failed to rollback - // again). - // - // Intentionally override reason since we failed within this - // function - // ImageState::Rejected(ImageStateReason::Pal(e)) - // } - // _ => image_state, - // }; - - // Now update the image state and job status on server side - match image_state { - ImageState::Testing(_) => { - // We discovered we're ready for test mode, put job status - // in self_test active - control.update_job_status( - file_ctx, - config, - JobStatus::InProgress, - JobStatusReason::SelfTestActive, - )?; - } - ImageState::Accepted => { - // Now that we have accepted the firmware update, we can - // complete the job - control.update_job_status( - file_ctx, - config, - JobStatus::Succeeded, - JobStatusReason::Accepted, - )?; - } - ImageState::Rejected(_) => { - // The firmware update was rejected, complete the job as - // FAILED (Job service will not allow us to set REJECTED - // after the job has been started already). - control.update_job_status( - file_ctx, - config, - JobStatus::Failed, - JobStatusReason::Rejected, - )?; - } - _ => { - // The firmware update was aborted, complete the job as - // FAILED (Job service will not allow us to set REJECTED - // after the job has been started already). - control.update_job_status( - file_ctx, - config, - JobStatus::Failed, - JobStatusReason::Aborted, - )?; - } - } - Ok(image_state) - } - - pub fn ota_close(&mut self) -> Result<(), OtaError> { - // Cleanup related to selected protocol. - data_interface!(self.cleanup, &self.config)?; - - // Abort any active file access and release the file resource, if needed - let file_ctx = self - .active_interface - .as_ref() - .ok_or(OtaError::InvalidInterface)? - .file_ctx(); - - self.pal.abort(file_ctx)?; - - self.active_interface = None; - Ok(()) - } - - fn ingest_data_block(&mut self, payload: &mut [u8]) -> Result { - let block = data_interface!(self.decode_file_block, payload)?; - - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - - if block.validate(self.config.block_size, file_ctx.filesize) { - if block.block_id < file_ctx.block_offset as usize - || !file_ctx - .bitmap - .get(block.block_id - file_ctx.block_offset as usize) - { - info!( - "Block {:?} is a DUPLICATE. {:?} blocks remaining.", - block.block_id, file_ctx.blocks_remaining - ); - - // Just return same progress as before - return Ok(false); - } - - info!( - "Received block {}. {:?} blocks remaining.", - block.block_id, file_ctx.blocks_remaining - ); - - self.pal.write_block( - file_ctx, - block.block_id * self.config.block_size, - block.block_payload, - )?; - - file_ctx - .bitmap - .set(block.block_id - file_ctx.block_offset as usize, false); - - file_ctx.blocks_remaining -= 1; - - if file_ctx.blocks_remaining == 0 { - info!("Received final expected block of file."); - - // Stop the request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - self.pal.close_file(file_ctx)?; - - // Return true to indicate end of file. - Ok(true) - } else { - if file_ctx.bitmap.is_empty() { - file_ctx.block_offset += 31; - file_ctx.bitmap = Bitmap::new( - file_ctx.filesize, - self.config.block_size, - file_ctx.block_offset, - ); - } - - Ok(false) - } - } else { - error!( - "Error! Block {:?} out of expected range! Size {:?}", - block.block_id, block.block_size - ); - - Err(OtaError::BlockOutOfRange) - } - } -} - -impl<'a, C, DP, DS, T, ST, PAL, const L: usize, const TIMER_HZ: u32> StateMachineContext - for SmContext<'a, C, DP, DS, T, ST, PAL, L, TIMER_HZ> -where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, -{ - fn restart_handler(&mut self, reason: &RestartReason) -> Result<(), OtaError> { - debug!("restart_handler"); - match reason { - RestartReason::Activate(cnt) if *cnt > self.config.activate_delay => { - info!("Application callback! OtaEvent::Activate"); - self.pal.complete_callback(OtaEvent::Activate)?; - } - RestartReason::Restart(cnt) if *cnt > self.config.activate_delay => { - self.pal.reset_device()?; - } - r => { - self.events - .enqueue(Events::Restart(r.inc())) - .map_err(|_| OtaError::SignalEventFailed)?; - } - } - Ok(()) - } - - /// Start timers and initiate request for job document - fn start_handler(&mut self) -> Result<(), OtaError> { - debug!("start_handler"); - // Start self-test timer, if platform is in self-test. - if self.platform_in_selftest() { - // Start self-test timer - if let Some(ref mut self_test_timer) = self.self_test_timer { - self_test_timer - .start(self.config.self_test_timeout_ms.millis()) - .map_err(|_| OtaError::Timer)?; - } - } - - // Initialize the control interface - self.control.init()?; - - // Send event to OTA task to get job document - self.events - .enqueue(Events::RequestJobDocument) - .map_err(|_| OtaError::SignalEventFailed) - } - - fn resume_job_handler(&mut self) -> Result<(), OtaError> { - debug!("resume_job_handler"); - - // Initialize the control interface - self.control.init()?; - - // Send signal to request job document - self.events - .enqueue(Events::RequestJobDocument) - .map_err(|_| OtaError::SignalEventFailed) - } - - /// Initiate a request for a job - fn request_job_handler(&mut self) -> Result<(), OtaError> { - debug!("request_job_handler"); - match self.control.request_job() { - Err(e) => { - if self.request_momentum < self.config.max_request_momentum { - // Start request timer - self.request_timer - .start(self.config.request_wait_ms.millis()) - .map_err(|_| OtaError::Timer)?; - - self.request_momentum += 1; - Err(e) - } else { - // Stop request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Send shutdown event to the OTA Agent task - self.events - .enqueue(Events::Shutdown) - .map_err(|_| OtaError::SignalEventFailed)?; - - // Too many requests have been sent without a response or - // too many failures when trying to publish the request - // message. Abort. - Err(OtaError::MomentumAbort) - } - } - Ok(_) => { - // Stop request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Reset the request momentum - self.request_momentum = 0; - Ok(()) - } - } - } - - /// Initialize and handle file transfer - fn init_file_handler(&mut self) -> Result<(), OtaError> { - debug!("init_file_handler"); - match data_interface!(self.init_file_transfer) { - Err(e) => { - if self.request_momentum < self.config.max_request_momentum { - // Start request timer - self.request_timer - .start(self.config.request_wait_ms.millis()) - .map_err(|_| OtaError::Timer)?; - - self.request_momentum += 1; - Err(e) - } else { - // Stop request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Send shutdown event to the OTA Agent task - self.events - .enqueue(Events::Shutdown) - .map_err(|_| OtaError::SignalEventFailed)?; - - // Too many requests have been sent without a response or - // too many failures when trying to publish the request - // message. Abort. - - Err(OtaError::MomentumAbort) - } - } - Ok(_) => { - // Reset the request momentum - self.request_momentum = 0; - - // TODO: Reset the OTA statistics - - info!("Initialized file handler! Requesting file blocks"); - - self.events - .enqueue(Events::RequestFileBlock) - .map_err(|_| OtaError::SignalEventFailed)?; - - Ok(()) - } - } - } - - /// Handle self test - fn in_self_test_handler(&mut self) -> Result<(), OtaError> { - info!("Beginning self-test"); - // Check the platform's OTA update image state. It should also be in - // self test - let in_self_test = self.platform_in_selftest(); - // Clear self-test flag - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - - if in_self_test { - self.pal.complete_callback(OtaEvent::StartTest)?; - info!("Application callback! OtaEvent::StartTest"); - - self.image_state = ImageState::Accepted; - self.control.update_job_status( - file_ctx, - &self.config, - JobStatus::Succeeded, - JobStatusReason::Accepted, - )?; - - file_ctx - .status_details - .insert( - heapless::String::from("self_test"), - heapless::String::from(JobStatusReason::Accepted.as_str()), - ) - .map_err(|_| OtaError::Overflow)?; - - // Stop the self test timer as it is no longer required - if let Some(ref mut self_test_timer) = self.self_test_timer { - self_test_timer.cancel().map_err(|_| OtaError::Timer)?; - } - } else { - // The job is in self test but the platform image state is not so it - // could be an attack on the platform image state. Reject the update - // (this should also cause the image to be erased), aborting the job - // and reset the device. - error!("Rejecting new image and rebooting: the job is in the self-test state while the platform is not."); - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - file_ctx, - ImageState::Rejected(ImageStateReason::ImageStateMismatch), - )?; - - self.events - .enqueue(Events::Restart(RestartReason::Restart(0))) - .map_err(|_| OtaError::SignalEventFailed)?; - } - Ok(()) - } - - /// Update file context from job document - fn process_job_handler(&mut self, data: &JobEventData<'_>) -> Result<(), OtaError> { - let JobEventData { - job_name, - ota_document, - status_details, - } = data; - - let file_ctx = self.get_file_context_from_job( - job_name, - ota_document, - status_details.map(Clone::clone), - )?; - - match self.select_interface(file_ctx, &ota_document.protocols) { - Ok(interface) => { - info!("Setting OTA data interface"); - self.active_interface = Some(interface); - } - Err(mut file_ctx) => { - // Failed to set the data interface so abort the OTA. If there - // is a valid job id, then a job status update will be sent. - - error!("Failed to set OTA data interface. Aborting current update."); - - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - &mut file_ctx, - ImageState::Aborted(ImageStateReason::InvalidDataProtocol), - )?; - return Err(OtaError::InvalidInterface); - } - } - - if self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .file_ctx() - .self_test() - { - // If the OTA job is in the self_test state, alert the application layer. - if matches!(self.image_state, ImageState::Testing(_)) { - self.events - .enqueue(Events::StartSelfTest) - .map_err(|_| OtaError::SignalEventFailed)?; - - Ok(()) - } else { - Err(OtaError::InvalidFile) - } - } else { - if !self.platform_in_selftest() { - // Received a valid context so send event to request file blocks - self.events - .enqueue(Events::CreateFile) - .map_err(|_| OtaError::SignalEventFailed)?; - } else { - // Received a job that is not in self-test but platform is, so - // reboot the device to allow roll back to previous image. - error!("Rejecting new image and rebooting: The platform is in the self-test state while the job is not."); - self.events - .enqueue(Events::Restart(RestartReason::Restart(0))) - .map_err(|_| OtaError::SignalEventFailed)?; - } - Ok(()) - } - } - - /// Request for data blocks - fn request_data_handler(&mut self) -> Result<(), OtaError> { - debug!("request_data_handler"); - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - if file_ctx.blocks_remaining > 0 { - // Start the request timer - self.request_timer - .start(self.config.request_wait_ms.millis()) - .map_err(|_| OtaError::Timer)?; - - if self.request_momentum <= self.config.max_request_momentum { - // Each request increases the momentum until a response is - // received. Too much momentum is interpreted as a failure to - // communicate and will cause us to abort the OTA. - self.request_momentum += 1; - - // Request data blocks - data_interface!(self.request_file_block, &self.config) - } else { - // Stop the request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Failed to send data request abort and close file. - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - file_ctx, - ImageState::Aborted(ImageStateReason::MomentumAbort), - )?; - - warn!("Shutdown [request_data_handler]"); - self.events - .enqueue(Events::Shutdown) - .map_err(|_| OtaError::SignalEventFailed)?; - - // Reset the request momentum - self.request_momentum = 0; - - // Too many requests have been sent without a response or too - // many failures when trying to publish the request message. - // Abort. - Err(OtaError::MomentumAbort) - } - } else { - Err(OtaError::BlockOutOfRange) - } - } - - /// Upon receiving a new job document cancel current job if present and - /// initiate new download - fn job_notification_handler(&mut self, data: &JobEventData<'_>) -> Result<(), OtaError> { - if let Some(ref mut interface) = self.active_interface { - if interface.file_ctx().job_name.as_str() == data.job_name { - self.events - .enqueue(Events::ContinueJob) - .map_err(|_| OtaError::SignalEventFailed)?; - return Ok(()); - } else { - // Stop the request timer - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Abort the current job - // TODO: This should never write to current image flags?! - self.pal - .set_platform_image_state(ImageState::Aborted(ImageStateReason::NewerJob))?; - self.ota_close()?; - } - } - - // Start the new job! - Ok(()) - } - - /// Process incoming data blocks - fn process_data_handler(&mut self, payload: &mut [u8]) -> Result<(), OtaError> { - debug!("process_data_handler"); - // Decode the file block received - match self.ingest_data_block(payload) { - Ok(true) => { - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - - // File is completed! Update progress accordingly. - let (status, reason, event) = if let Some(0) = file_ctx.file_type { - ( - JobStatus::InProgress, - JobStatusReason::SigCheckPassed, - OtaEvent::Activate, - ) - } else { - ( - JobStatus::Succeeded, - JobStatusReason::Accepted, - OtaEvent::UpdateComplete, - ) - }; - - self.control - .update_job_status(file_ctx, &self.config, status, reason)?; - - // Send event to close file. - self.events - .enqueue(Events::CloseFile) - .map_err(|_| OtaError::SignalEventFailed)?; - - // TODO: Last file block processed, increment the statistics - // otaAgent.statistics.otaPacketsProcessed++; - - match event { - OtaEvent::Activate => { - self.events - .enqueue(Events::Restart(RestartReason::Activate(0))) - .map_err(|_| OtaError::SignalEventFailed)?; - } - event => self.pal.complete_callback(event)?, - }; - } - Ok(false) => { - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - - // File block processed, increment the statistics. - // otaAgent.statistics.otaPacketsProcessed++; - - // Reset the momentum counter since we received a good block - self.request_momentum = 0; - - // We're actively receiving a file so update the job status as - // needed - self.control.update_job_status( - file_ctx, - &self.config, - JobStatus::InProgress, - JobStatusReason::Receiving, - )?; - - if file_ctx.request_block_remaining > 1 { - file_ctx.request_block_remaining -= 1; - } else { - // Start the request timer. - self.request_timer - .start(self.config.request_wait_ms.millis()) - .map_err(|_| OtaError::Timer)?; - - self.events - .enqueue(Events::RequestFileBlock) - .map_err(|_| OtaError::SignalEventFailed)?; - } - } - Err(e) if e.is_retryable() => { - warn!("Failed to ingest data block, Error is retryable! ingest_data_block returned error {:?}", e); - } - Err(e) => { - let file_ctx = self - .active_interface - .as_mut() - .ok_or(OtaError::InvalidInterface)? - .mut_file_ctx(); - - error!("Failed to ingest data block, rejecting image: ingest_data_block returned error {:?}", e); - - // Call the platform specific code to reject the image - // TODO: This should never write to current image flags?! - self.pal.set_platform_image_state(ImageState::Rejected( - ImageStateReason::FailedIngest, - ))?; - - // TODO: Pal reason - self.control.update_job_status( - file_ctx, - &self.config, - JobStatus::Failed, - JobStatusReason::Pal(0), - )?; - - // Stop the request timer. - self.request_timer.cancel().map_err(|_| OtaError::Timer)?; - - // Send event to close file. - self.events - .enqueue(Events::CloseFile) - .map_err(|_| OtaError::SignalEventFailed)?; - - self.pal.complete_callback(OtaEvent::Fail)?; - info!("Application callback! OtaEvent::Fail"); - return Err(e); - } - } - - // TODO: Application callback for event processed. - // otaAgent.OtaAppCallback( OtaJobEventProcessed, ( const void * ) pEventData ); - Ok(()) - } - - /// Close file opened for download - fn close_file_handler(&mut self) -> Result<(), OtaError> { - self.ota_close() - } - - /// Handle user interrupt to abort task - fn user_abort_handler(&mut self) -> Result<(), OtaError> { - warn!("User abort OTA!"); - if let Some(ref mut interface) = self.active_interface { - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - interface.mut_file_ctx(), - ImageState::Aborted(ImageStateReason::UserAbort), - )?; - self.ota_close() - } else { - Err(OtaError::NoActiveJob) - } - } - - /// Handle user interrupt to abort task - fn shutdown_handler(&mut self) -> Result<(), OtaError> { - warn!("Shutting down OTA!"); - if let Some(ref mut interface) = self.active_interface { - self.image_state = Self::set_image_state_with_reason( - self.control, - &mut self.pal, - &self.config, - interface.mut_file_ctx(), - ImageState::Aborted(ImageStateReason::UserAbort), - )?; - self.ota_close()?; - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - ota::{ - agent::OtaAgent, - pal::Version, - test::{ - mock::{MockPal, MockTimer}, - test_job_doc, - }, - }, - test::MockMqtt, - }; - - use super::*; - - #[test] - fn version_check_success() { - // The version check is run after swapping & rebooting, so the PAL will - // return the version of the newly flashed firmware, and `FileContext` - // will contain the `updated_by` version, which is the old firmware - // version. - - let mqtt = MockMqtt::new(); - - let request_timer = MockTimer::new(); - let self_test_timer = MockTimer::new(); - let pal = MockPal {}; - - let mut agent = OtaAgent::builder(&mqtt, &mqtt, request_timer, pal) - .with_self_test_timeout(self_test_timer, 32000) - .build(); - - let ota_job = test_job_doc(); - let mut file_ctx = FileContext::new_from( - "Job-name", - &ota_job, - None, - 0, - &Config::default(), - Version::new(0, 1, 0), - ) - .unwrap(); - - let context = agent.state.context_mut(); - - assert_eq!(context.handle_self_test_job(&mut file_ctx), Ok(())); - - assert!( - matches!(context.image_state, ImageState::Testing(_)), - "Unexpected image state" - ); - } - - #[test] - fn version_check_rejected() { - let mqtt = MockMqtt::new(); - - let request_timer = MockTimer::new(); - let self_test_timer = MockTimer::new(); - let pal = MockPal {}; - - let mut agent = OtaAgent::builder(&mqtt, &mqtt, request_timer, pal) - .with_self_test_timeout(self_test_timer, 32000) - .build(); - - let ota_job = test_job_doc(); - let mut file_ctx = FileContext::new_from( - "Job-name", - &ota_job, - None, - 0, - &Config::default(), - Version::new(1, 1, 0), - ) - .unwrap(); - - let context = agent.state.context_mut(); - - assert_eq!(context.handle_self_test_job(&mut file_ctx), Ok(())); - - assert!( - matches!(context.image_state, ImageState::Rejected(_)), - "Unexpected image state" - ); - } - - #[test] - fn version_check_allow_donwgrade() { - let mqtt = MockMqtt::new(); - - let request_timer = MockTimer::new(); - let self_test_timer = MockTimer::new(); - let pal = MockPal {}; - - let mut agent = OtaAgent::builder(&mqtt, &mqtt, request_timer, pal) - .with_self_test_timeout(self_test_timer, 32000) - .allow_downgrade() - .build(); - - let ota_job = test_job_doc(); - let mut file_ctx = FileContext::new_from( - "Job-name", - &ota_job, - None, - 0, - &Config::default(), - Version::new(1, 1, 0), - ) - .unwrap(); - - let context = agent.state.context_mut(); - - assert_eq!(context.handle_self_test_job(&mut file_ctx), Ok(())); - - assert!( - matches!(context.image_state, ImageState::Testing(_)), - "Unexpected image state" - ); - } -} diff --git a/src/ota/test/mock.rs b/src/ota/test/mock.rs deleted file mode 100644 index 42e4ca3..0000000 --- a/src/ota/test/mock.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::ota::{ - encoding::FileContext, - pal::{ImageState, OtaPal, OtaPalError, PalImageState, Version}, -}; - -use super::TEST_TIMER_HZ; - -/// -/// Mock timer used for unit tests. Implements `fugit_timer::Timer` trait. -/// -pub struct MockTimer { - pub is_started: bool, -} -impl MockTimer { - pub fn new() -> Self { - Self { is_started: false } - } -} - -impl fugit_timer::Timer for MockTimer { - type Error = (); - - fn now(&mut self) -> fugit_timer::TimerInstantU32 { - todo!() - } - - fn start( - &mut self, - _duration: fugit_timer::TimerDurationU32, - ) -> Result<(), Self::Error> { - self.is_started = true; - Ok(()) - } - - fn cancel(&mut self) -> Result<(), Self::Error> { - self.is_started = false; - Ok(()) - } - - fn wait(&mut self) -> nb::Result<(), Self::Error> { - Ok(()) - } -} - -/// -/// Mock Platform abstration layer used for unit tests. Implements `OtaPal` -/// trait. -/// -pub struct MockPal {} - -impl OtaPal for MockPal { - type Error = (); - - fn abort(&mut self, _file: &FileContext) -> Result<(), OtaPalError> { - Ok(()) - } - - fn create_file_for_rx(&mut self, _file: &FileContext) -> Result<(), OtaPalError> { - Ok(()) - } - - fn get_platform_image_state(&mut self) -> Result> { - Ok(PalImageState::Valid) - } - - fn set_platform_image_state( - &mut self, - _image_state: ImageState, - ) -> Result<(), OtaPalError> { - Ok(()) - } - - fn reset_device(&mut self) -> Result<(), OtaPalError> { - Ok(()) - } - - fn close_file(&mut self, _file: &FileContext) -> Result<(), OtaPalError> { - Ok(()) - } - - fn write_block( - &mut self, - _file: &FileContext, - _block_offset: usize, - block_payload: &[u8], - ) -> Result> { - Ok(block_payload.len()) - } - - fn get_active_firmware_version(&self) -> Result> { - Ok(Version::new(1, 0, 0)) - } -} diff --git a/src/ota/test/mod.rs b/src/ota/test/mod.rs deleted file mode 100644 index c535337..0000000 --- a/src/ota/test/mod.rs +++ /dev/null @@ -1,523 +0,0 @@ -use super::{ - config::Config, - data_interface::Protocol, - encoding::{ - json::{FileDescription, OtaJob}, - FileContext, - }, - pal::Version, -}; - -pub mod mock; - -pub const TEST_TIMER_HZ: u32 = 8_000_000; - -pub fn test_job_doc() -> OtaJob<'static> { - OtaJob { - protocols: heapless::Vec::from_slice(&[Protocol::Mqtt]).unwrap(), - streamname: "test_stream", - files: heapless::Vec::from_slice(&[FileDescription { - filepath: "", - filesize: 123456, - fileid: 0, - certfile: "cert", - update_data_url: None, - auth_scheme: None, - sha1_rsa: Some(""), - file_type: Some(0), - sha256_rsa: None, - sha1_ecdsa: None, - sha256_ecdsa: None, - }]) - .unwrap(), - } -} - -pub fn test_file_ctx(config: &Config) -> FileContext { - let ota_job = test_job_doc(); - FileContext::new_from("Job-name", &ota_job, None, 0, config, Version::default()).unwrap() -} - -pub mod ota_tests { - use crate::jobs::data_types::{DescribeJobExecutionResponse, JobExecution, JobStatus}; - use crate::ota::data_interface::Protocol; - use crate::ota::encoding::json::{FileDescription, OtaJob}; - use crate::ota::error::OtaError; - use crate::ota::state::{Error, Events, States}; - use crate::ota::test::test_job_doc; - use crate::ota::{ - agent::OtaAgent, - control_interface::ControlInterface, - data_interface::{DataInterface, NoInterface}, - pal::OtaPal, - test::mock::{MockPal, MockTimer}, - }; - use crate::test::MockMqtt; - use mqttrust::encoding::v4::{decode_slice, utils::Pid, PacketType}; - use mqttrust::{MqttError, Packet, QoS, SubscribeTopic}; - use serde::Deserialize; - use serde_json_core::from_slice; - - use super::TEST_TIMER_HZ; - - /// All known job document that the device knows how to process. - #[derive(Debug, PartialEq, Deserialize)] - pub enum JobDetails<'a> { - #[serde(rename = "afr_ota")] - #[serde(borrow)] - Ota(OtaJob<'a>), - - #[serde(other)] - Unknown, - } - - fn new_agent( - mqtt: &MockMqtt, - ) -> OtaAgent<'_, MockMqtt, &MockMqtt, NoInterface, MockTimer, MockTimer, MockPal, TEST_TIMER_HZ> - { - let request_timer = MockTimer::new(); - let self_test_timer = MockTimer::new(); - let pal = MockPal {}; - - OtaAgent::builder(mqtt, mqtt, request_timer, pal) - .with_self_test_timeout(self_test_timer, 16000) - .build() - } - - fn run_to_state<'a, C, DP, DS, T, ST, PAL, const TIMER_HZ: u32>( - agent: &mut OtaAgent<'a, C, DP, DS, T, ST, PAL, TIMER_HZ>, - state: States, - ) where - C: ControlInterface, - DP: DataInterface, - DS: DataInterface, - T: fugit_timer::Timer, - ST: fugit_timer::Timer, - PAL: OtaPal, - { - if agent.state.state() == &state { - return; - } - - match state { - States::Ready => { - println!( - "Running to 'States::Ready', events: {}", - agent.state.context().events.len() - ); - agent.state.process_event(Events::Shutdown).unwrap(); - } - States::CreatingFile => { - println!( - "Running to 'States::CreatingFile', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::WaitingForJob); - - let job_doc = test_job_doc(); - agent.job_update("Test-job", &job_doc, None).unwrap(); - agent.state.context_mut().events.dequeue(); - } - States::RequestingFileBlock => { - println!( - "Running to 'States::RequestingFileBlock', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::CreatingFile); - agent.state.process_event(Events::CreateFile).unwrap(); - agent.state.context_mut().events.dequeue(); - } - States::RequestingJob => { - println!( - "Running to 'States::RequestingJob', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::Ready); - agent.state.process_event(Events::Start).unwrap(); - agent.state.context_mut().events.dequeue(); - } - States::Suspended => { - println!( - "Running to 'States::Suspended', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::Ready); - agent.suspend().unwrap(); - } - States::WaitingForFileBlock => { - println!( - "Running to 'States::Suspended', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::RequestingFileBlock); - agent.state.process_event(Events::RequestFileBlock).unwrap(); - agent.state.context_mut().events.dequeue(); - } - States::WaitingForJob => { - println!( - "Running to 'States::WaitingForJob', events: {}", - agent.state.context().events.len() - ); - run_to_state(agent, States::RequestingJob); - agent.check_for_update().unwrap(); - } - States::Restarting => {} - } - } - - pub fn set_pid(buf: &mut [u8], pid: Pid) -> Result<(), ()> { - let mut offset = 0; - let (header, _) = mqttrust::encoding::v4::decoder::read_header(buf, &mut offset) - .map_err(|_| ())? - .ok_or(())?; - - match (header.typ, header.qos) { - (PacketType::Publish, QoS::AtLeastOnce | QoS::ExactlyOnce) => { - if buf[offset..].len() < 2 { - return Err(()); - } - let len = ((buf[offset] as usize) << 8) | buf[offset + 1] as usize; - - offset += 2; - if len > buf[offset..].len() { - return Err(()); - } else { - offset += len; - } - } - (PacketType::Subscribe | PacketType::Unsubscribe | PacketType::Suback, _) => {} - ( - PacketType::Puback - | PacketType::Pubrec - | PacketType::Pubrel - | PacketType::Pubcomp - | PacketType::Unsuback, - _, - ) => {} - _ => return Ok(()), - } - - pid.to_buffer(buf, &mut offset).map_err(|_| ()) - } - - #[test] - fn ready_when_stopped() { - let mqtt = MockMqtt::new(); - let mut ota_agent = new_agent(&mqtt); - - assert!(matches!(ota_agent.state.state(), &States::Ready)); - run_to_state(&mut ota_agent, States::Ready); - assert!(matches!(ota_agent.state.state(), &States::Ready)); - assert_eq!(ota_agent.state.context().events.len(), 0); - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - } - - #[test] - fn abort_when_stopped() { - let mqtt = MockMqtt::new(); - let mut ota_agent = new_agent(&mqtt); - - run_to_state(&mut ota_agent, States::Ready); - assert_eq!(ota_agent.state.context().events.len(), 0); - - assert_eq!( - ota_agent.abort().err(), - Some(Error::GuardFailed(OtaError::NoActiveJob)) - ); - ota_agent.process_event().unwrap(); - assert!(matches!(ota_agent.state.state(), &States::Ready)); - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - } - - #[test] - fn resume_when_stopped() { - let mqtt = MockMqtt::new(); - let mut ota_agent = new_agent(&mqtt); - - run_to_state(&mut ota_agent, States::Ready); - assert_eq!(ota_agent.state.context().events.len(), 0); - - assert!(matches!( - ota_agent.resume().err().unwrap(), - Error::InvalidEvent - )); - ota_agent.process_event().unwrap(); - assert!(matches!(ota_agent.state.state(), &States::Ready)); - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - } - - #[test] - fn resume_when_suspended() { - let mqtt = MockMqtt::new(); - let mut ota_agent = new_agent(&mqtt); - - run_to_state(&mut ota_agent, States::Suspended); - assert_eq!(ota_agent.state.context().events.len(), 0); - - assert!(matches!( - ota_agent.resume().unwrap(), - &States::RequestingJob - )); - assert_eq!(mqtt.tx.borrow_mut().len(), 1); - } - - #[test] - fn check_for_update() { - let mqtt = MockMqtt::new(); - let mut ota_agent = new_agent(&mqtt); - - run_to_state(&mut ota_agent, States::RequestingJob); - assert!(matches!(ota_agent.state.state(), &States::RequestingJob)); - - assert_eq!(ota_agent.state.context().events.len(), 0); - - assert!(matches!( - ota_agent.check_for_update().unwrap(), - &States::WaitingForJob - )); - - let bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Subscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec![SubscribeTopic { - topic_path: "$aws/things/test_client/jobs/notify-next", - qos: QoS::AtLeastOnce - }] - ); - - let mut bytes = mqtt.tx.borrow_mut().pop_front().unwrap(); - set_pid(bytes.as_mut_slice(), Pid::new()).expect("Failed to set valid PID"); - let packet = decode_slice(bytes.as_slice()).unwrap(); - - let publish = match packet { - Some(Packet::Publish(p)) => p, - _ => panic!(), - }; - - assert_eq!( - publish, - mqttrust::encoding::v4::publish::Publish { - dup: false, - qos: QoS::AtLeastOnce, - retain: false, - topic_name: "$aws/things/test_client/jobs/$next/get", - payload: &[123, 125], - pid: Some(Pid::new()), - } - ); - assert_eq!(mqtt.tx.borrow_mut().len(), 0); - } - - #[test] - #[ignore] - fn request_job_retry_fail() { - let mut mqtt = MockMqtt::new(); - - // Let MQTT publish fail so request job will also fail - mqtt.publish_fail(); - - let mut ota_agent = new_agent(&mqtt); - - // Place the OTA Agent into the state for requesting a job - run_to_state(&mut ota_agent, States::RequestingJob); - assert!(matches!(ota_agent.state.state(), &States::RequestingJob)); - assert_eq!(ota_agent.state.context().events.len(), 0); - - assert_eq!( - ota_agent.check_for_update().err(), - Some(Error::GuardFailed(OtaError::Mqtt(MqttError::Full))) - ); - - // Fail the maximum number of attempts to request a job document - for _ in 0..ota_agent.state.context().config.max_request_momentum { - ota_agent.process_event().unwrap(); - assert!(ota_agent.state.context().request_timer.is_started); - ota_agent.timer_callback().ok(); - assert!(matches!(ota_agent.state.state(), &States::RequestingJob)); - } - - // Attempt to request another job document after failing the maximum - // number of times, triggering a shutdown event. - ota_agent.process_event().unwrap(); - assert!(matches!(ota_agent.state.state(), &States::Ready)); - assert_eq!(mqtt.tx.borrow_mut().len(), 4); - } - - #[test] - fn init_file_transfer_mqtt() { - let mqtt = MockMqtt::new(); - - let mut ota_agent = new_agent(&mqtt); - - // Place the OTA Agent into the state for creating file - run_to_state(&mut ota_agent, States::CreatingFile); - assert!(matches!(ota_agent.state.state(), &States::CreatingFile)); - assert_eq!(ota_agent.state.context().events.len(), 0); - - ota_agent.process_event().unwrap(); - assert!(matches!(ota_agent.state.state(), &States::CreatingFile)); - ota_agent.process_event().unwrap(); - - ota_agent.state.process_event(Events::CreateFile).unwrap(); - - // Above will automatically enqueue `RequestFileBlock` - assert!(matches!( - ota_agent.state.state(), - &States::RequestingFileBlock - )); - - // Check the latest MQTT message - let bytes = mqtt.tx.borrow_mut().pop_back().unwrap(); - - let packet = decode_slice(bytes.as_slice()).unwrap(); - let topics = match packet { - Some(Packet::Subscribe(ref s)) => s.topics().collect::>(), - _ => panic!(), - }; - - assert_eq!( - topics, - vec![SubscribeTopic { - topic_path: "$aws/things/test_client/streams/test_stream/data/cbor", - qos: QoS::AtLeastOnce - }] - ); - - // Should still contain: - // - subscription to `$aws/things/test_client/jobs/notify-next` - // - publish to `$aws/things/test_client/jobs/$next/get` - assert_eq!(mqtt.tx.borrow_mut().len(), 2); - } - - #[test] - fn request_file_block_mqtt() { - let mqtt = MockMqtt::new(); - - let mut ota_agent = new_agent(&mqtt); - - // Place the OTA Agent into the state for requesting file block - run_to_state(&mut ota_agent, States::RequestingFileBlock); - assert!(matches!( - ota_agent.state.state(), - &States::RequestingFileBlock - )); - assert_eq!(ota_agent.state.context().events.len(), 0); - - ota_agent - .state - .process_event(Events::RequestFileBlock) - .unwrap(); - - assert!(matches!( - ota_agent.state.state(), - &States::WaitingForFileBlock - )); - - let bytes = mqtt.tx.borrow_mut().pop_back().unwrap(); - - let publish = match decode_slice(bytes.as_slice()).unwrap() { - Some(Packet::Publish(p)) => p, - _ => panic!(), - }; - - // Check the latest MQTT message - assert_eq!( - publish, - mqttrust::encoding::v4::publish::Publish { - dup: false, - qos: QoS::AtMostOnce, - retain: false, - topic_name: "$aws/things/test_client/streams/test_stream/get/cbor", - payload: &[ - 164, 97, 102, 0, 97, 108, 25, 1, 0, 97, 111, 0, 97, 98, 68, 255, 255, 255, 127 - ], - pid: None - } - ); - - // Should still contain: - // - subscription to `$aws/things/test_client/jobs/notify-next` - // - publish to `$aws/things/test_client/jobs/$next/get` - // - subscription to - // `$aws/things/test_client/streams/test_stream/data/cbor` - assert_eq!(mqtt.tx.borrow_mut().len(), 3); - } - - #[test] - fn deserialize_describe_job_execution_response_ota() { - let payload = br#"{ - "clientToken":"0:rustot-test", - "timestamp":1624445100, - "execution":{ - "jobId":"AFR_OTA-rustot_test_1", - "status":"QUEUED", - "queuedAt":1624440618, - "lastUpdatedAt":1624440618, - "versionNumber":1, - "executionNumber":1, - "jobDocument":{ - "afr_ota":{ - "protocols":["MQTT"], - "streamname":"AFR_OTA-0ba01295-9417-4ba7-9a99-4b31fb03d252", - "files":[{ - "filepath":"IMG_test.jpg", - "filesize":2674792, - "fileid":0, - "certfile":"nope", - "fileType":0, - "sig-sha256-ecdsa":"This is my signature! Better believe it!" - }] - } - } - } - }"#; - - let (response, _) = - from_slice::>(payload).unwrap(); - - assert_eq!( - response, - DescribeJobExecutionResponse { - execution: Some(JobExecution { - execution_number: Some(1), - job_document: Some(JobDetails::Ota(OtaJob { - protocols: heapless::Vec::from_slice(&[Protocol::Mqtt]).unwrap(), - streamname: "AFR_OTA-0ba01295-9417-4ba7-9a99-4b31fb03d252", - files: heapless::Vec::from_slice(&[FileDescription { - filepath: "IMG_test.jpg", - filesize: 2674792, - fileid: 0, - certfile: "nope", - update_data_url: None, - auth_scheme: None, - sha1_rsa: None, - sha256_rsa: None, - sha1_ecdsa: None, - sha256_ecdsa: Some("This is my signature! Better believe it!"), - file_type: Some(0), - }]) - .unwrap(), - })), - job_id: "AFR_OTA-rustot_test_1", - last_updated_at: 1624440618, - queued_at: 1624440618, - status_details: None, - status: JobStatus::Queued, - version_number: 1, - approximate_seconds_before_timed_out: None, - started_at: None, - thing_name: None, - }), - timestamp: 1624445100, - client_token: Some("0:rustot-test"), - } - ); - } -} diff --git a/src/provisioning/data_types.rs b/src/provisioning/data_types.rs index 7349425..b3e1088 100644 --- a/src/provisioning/data_types.rs +++ b/src/provisioning/data_types.rs @@ -1,4 +1,3 @@ -use heapless::LinearMap; use serde::{Deserialize, Serialize}; /// To receive error responses, subscribe to @@ -94,7 +93,7 @@ pub struct CreateKeysAndCertificateResponse<'a> { /// **:** The provisioning template name. #[derive(Debug, PartialEq, Serialize)] // #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct RegisterThingRequest<'a, const P: usize> { +pub struct RegisterThingRequest<'a, P: Serialize> { /// The token to prove ownership of the certificate. The token is generated /// by AWS IoT when you create a certificate over MQTT. #[serde(rename = "certificateOwnershipToken")] @@ -102,8 +101,8 @@ pub struct RegisterThingRequest<'a, const P: usize> { /// Optional. Key-value pairs from the device that are used by the /// pre-provisioning hooks to evaluate the registration request. - #[serde(rename = "parameters")] - pub parameters: Option>, + #[serde(rename = "parameters", skip_serializing_if = "Option::is_none")] + pub parameters: Option

, } /// Subscribe to @@ -113,12 +112,101 @@ pub struct RegisterThingRequest<'a, const P: usize> { /// **:** The provisioning template name. #[derive(Debug, PartialEq, Deserialize)] // #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct RegisterThingResponse<'a, const P: usize> { +pub struct RegisterThingResponse<'a, C> { /// The device configuration defined in the template. #[serde(rename = "deviceConfiguration")] - pub device_configuration: LinearMap<&'a str, &'a str, P>, + pub device_configuration: Option, /// The name of the IoT thing created during provisioning. #[serde(rename = "thingName")] pub thing_name: &'a str, } + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Serialize)] + struct Parameters<'a> { + some_key: &'a str, + } + + #[derive(Debug, Deserialize, PartialEq)] + struct DeviceConfiguration { + some_key: heapless::String<64>, + } + + #[test] + fn serialize_optional_parameters() { + let register_request = RegisterThingRequest { + certificate_ownership_token: "my_ownership_token", + parameters: Some(Parameters { + some_key: "optional_key", + }), + }; + + let json = serde_json_core::to_string::<_, 128>(®ister_request).unwrap(); + assert_eq!( + json.as_str(), + r#"{"certificateOwnershipToken":"my_ownership_token","parameters":{"some_key":"optional_key"}}"# + ); + + let register_request_none: RegisterThingRequest<'_, Parameters> = RegisterThingRequest { + certificate_ownership_token: "my_ownership_token", + parameters: None, + }; + + let json = serde_json_core::to_string::<_, 128>(®ister_request_none).unwrap(); + assert_eq!( + json.as_str(), + r#"{"certificateOwnershipToken":"my_ownership_token"}"# + ); + } + + #[test] + fn deserialize_optional_device_configuration() { + let register_response = + r#"{"thingName":"my_thing","deviceConfiguration":{"some_key":"optional_key"}}"#; + + let (response, _) = + serde_json_core::from_str::>( + register_response, + ) + .unwrap(); + assert_eq!( + response, + RegisterThingResponse { + thing_name: "my_thing", + device_configuration: Some(DeviceConfiguration { + some_key: heapless::String::try_from("optional_key").unwrap() + }), + } + ); + + let register_response_none = r#"{"thingName":"my_thing"}"#; + + let (response, _) = + serde_json_core::from_str::>(register_response_none).unwrap(); + assert_eq!( + response, + RegisterThingResponse { + thing_name: "my_thing", + device_configuration: None, + } + ); + + // // FIXME + // let register_response_none = r#"{"thingName":"my_thing","deviceConfiguration":{}}"#; + + // let (response, _) = + // serde_json_core::from_str::>(®ister_response_none) + // .unwrap(); + // assert_eq!( + // response, + // RegisterThingResponse { + // thing_name: "my_thing", + // device_configuration: None, + // } + // ); + } +} diff --git a/src/provisioning/error.rs b/src/provisioning/error.rs index e43c01c..20961fb 100644 --- a/src/provisioning/error.rs +++ b/src/provisioning/error.rs @@ -1,20 +1,17 @@ #[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub enum Error { Overflow, InvalidPayload, InvalidState, - Mqtt(mqttrust::MqttError), - DeserializeJson(serde_json_core::de::Error), + Mqtt(embedded_mqtt::Error), + DeserializeJson(#[cfg_attr(feature = "defmt", defmt(Debug2Format))] serde_json_core::de::Error), DeserializeCbor, + CertificateStorage, Response(u16), } -impl From for Error { - fn from(e: mqttrust::MqttError) -> Self { - Self::Mqtt(e) - } -} - impl From for Error { fn from(_: serde_json_core::ser::Error) -> Self { Self::Overflow @@ -27,8 +24,14 @@ impl From for Error { } } -impl From for Error { - fn from(_e: serde_cbor::Error) -> Self { +impl From for Error { + fn from(_e: minicbor_serde::error::DecodeError) -> Self { Self::DeserializeCbor } } + +impl From for Error { + fn from(e: embedded_mqtt::Error) -> Self { + Self::Mqtt(e) + } +} diff --git a/src/provisioning/mod.rs b/src/provisioning/mod.rs index 1c29c56..2d55b59 100644 --- a/src/provisioning/mod.rs +++ b/src/provisioning/mod.rs @@ -2,20 +2,32 @@ pub mod data_types; mod error; pub mod topics; -use heapless::LinearMap; -use mqttrust::Mqtt; -#[cfg(feature = "provision_cbor")] -use serde::Serialize; +use core::future::Future; + +use embassy_sync::blocking_mutex::raw::RawMutex; +use embedded_mqtt::{ + DeferredPayload, EncodingError, Publish, Subscribe, SubscribeTopic, Subscription, +}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +pub use error::Error; use self::{ data_types::{ - CreateCertificateFromCsrResponse, CreateKeysAndCertificateResponse, ErrorResponse, - RegisterThingRequest, RegisterThingResponse, + CreateCertificateFromCsrRequest, CreateCertificateFromCsrResponse, + CreateKeysAndCertificateResponse, ErrorResponse, RegisterThingRequest, + RegisterThingResponse, }, - error::Error, - topics::{PayloadFormat, Subscribe, Topic, Unsubscribe}, + topics::{PayloadFormat, Topic}, }; +pub trait CredentialHandler { + fn store_credentials( + &mut self, + credentials: Credentials<'_>, + ) -> impl Future>; +} + #[derive(Debug)] pub struct Credentials<'a> { pub certificate_id: &'a str, @@ -23,236 +35,377 @@ pub struct Credentials<'a> { pub private_key: Option<&'a str>, } -#[derive(Debug)] -pub enum Response<'a, const P: usize> { - Credentials(Credentials<'a>), - DeviceConfiguration(LinearMap<&'a str, &'a str, P>), -} - -pub struct FleetProvisioner<'a, M> -where - M: Mqtt, -{ - mqtt: &'a M, - template_name: &'a str, - ownership_token: Option>, - payload_format: PayloadFormat, -} - -impl<'a, M> FleetProvisioner<'a, M> -where - M: Mqtt, -{ - /// Instantiate a new `FleetProvisioner`, using `template_name` for the provisioning - pub fn new(mqtt: &'a M, template_name: &'a str) -> Self { - Self { +pub struct FleetProvisioner; + +impl FleetProvisioner { + pub async fn provision<'a, C, M: RawMutex>( + mqtt: &embedded_mqtt::MqttClient<'a, M>, + template_name: &str, + parameters: Option, + credential_handler: &mut impl CredentialHandler, + ) -> Result, Error> + where + C: DeserializeOwned, + { + Self::provision_inner( mqtt, template_name, - ownership_token: None, - payload_format: PayloadFormat::Json, - } + parameters, + None, + credential_handler, + PayloadFormat::Json, + ) + .await } - #[cfg(feature = "provision_cbor")] - pub fn new_cbor(mqtt: &'a M, template_name: &'a str) -> Self { - Self { + pub async fn provision_csr<'a, C, M: RawMutex>( + mqtt: &embedded_mqtt::MqttClient<'a, M>, + template_name: &str, + parameters: Option, + csr: &str, + credential_handler: &mut impl CredentialHandler, + ) -> Result, Error> + where + C: DeserializeOwned, + { + Self::provision_inner( mqtt, template_name, - ownership_token: None, - payload_format: PayloadFormat::Cbor, - } - } - - pub fn initialize(&self) -> Result<(), Error> { - Subscribe::<4>::new() - .topic( - Topic::CreateKeysAndCertificateAccepted(self.payload_format), - mqttrust::QoS::AtLeastOnce, - ) - .topic( - Topic::CreateKeysAndCertificateRejected(self.payload_format), - mqttrust::QoS::AtLeastOnce, - ) - .topic( - Topic::RegisterThingAccepted(self.template_name, self.payload_format), - mqttrust::QoS::AtLeastOnce, - ) - .topic( - Topic::RegisterThingRejected(self.template_name, self.payload_format), - mqttrust::QoS::AtLeastOnce, - ) - .send(self.mqtt)?; - - Ok(()) + parameters, + Some(csr), + credential_handler, + PayloadFormat::Json, + ) + .await } - // TODO: Can we handle this better? If sent from `initialize` it causes a - // race condition with the subscription ack. - pub fn begin(&mut self) -> Result<(), Error> { - self.mqtt.publish( - Topic::CreateKeysAndCertificate(self.payload_format) - .format::<29>()? - .as_str(), - b"", - mqttrust::QoS::AtLeastOnce, - )?; - - Ok(()) + #[cfg(feature = "provision_cbor")] + pub async fn provision_cbor<'a, C, M: RawMutex>( + mqtt: &embedded_mqtt::MqttClient<'a, M>, + template_name: &str, + parameters: Option, + credential_handler: &mut impl CredentialHandler, + ) -> Result, Error> + where + C: DeserializeOwned, + { + Self::provision_inner( + mqtt, + template_name, + parameters, + None, + credential_handler, + PayloadFormat::Cbor, + ) + .await } - pub fn register_thing<'b, const P: usize>( - &mut self, - parameters: Option>, - ) -> Result<(), Error> { - let certificate_ownership_token = self.ownership_token.take().ok_or(Error::InvalidState)?; - - let register_request = RegisterThingRequest { - certificate_ownership_token: &certificate_ownership_token, + #[cfg(feature = "provision_cbor")] + pub async fn provision_csr_cbor<'a, C, M: RawMutex>( + mqtt: &embedded_mqtt::MqttClient<'a, M>, + template_name: &str, + parameters: Option, + csr: &str, + credential_handler: &mut impl CredentialHandler, + ) -> Result, Error> + where + C: DeserializeOwned, + { + Self::provision_inner( + mqtt, + template_name, parameters, - }; - - let payload = &mut [0u8; 1024]; - - let payload_len = match self.payload_format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - let mut serializer = - serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new(payload)); - register_request.serialize(&mut serializer)?; - serializer.into_inner().bytes_written() - } - PayloadFormat::Json => serde_json_core::to_slice(®ister_request, payload)?, - }; - - self.mqtt.publish( - Topic::RegisterThing(self.template_name, self.payload_format) - .format::<69>()? - .as_str(), - &payload[..payload_len], - mqttrust::QoS::AtLeastOnce, - )?; - - Ok(()) + Some(csr), + credential_handler, + PayloadFormat::Cbor, + ) + .await } - pub fn handle_message<'b, const P: usize>( - &mut self, - topic_name: &'b str, - payload: &'b mut [u8], - ) -> Result>, Error> { - match Topic::from_str(topic_name) { + #[cfg(feature = "provision_cbor")] + async fn provision_inner<'a, C, M: RawMutex>( + mqtt: &embedded_mqtt::MqttClient<'a, M>, + template_name: &str, + parameters: Option, + csr: Option<&str>, + credential_handler: &mut impl CredentialHandler, + payload_format: PayloadFormat, + ) -> Result, Error> + where + C: DeserializeOwned, + { + let mut create_subscription = Self::begin(mqtt, csr, payload_format).await?; + let mut message = create_subscription + .next_message() + .await + .ok_or(Error::InvalidState)?; + + let ownership_token = match Topic::from_str(message.topic_name()) { Some(Topic::CreateKeysAndCertificateAccepted(format)) => { - trace!( - "Topic::CreateKeysAndCertificateAccepted {:?}. Payload len: {:?}", - format, - payload.len() - ); - - let response = match format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::(payload)? - } - PayloadFormat::Json => { - serde_json_core::from_slice::(payload)?.0 - } - }; - - self.ownership_token - .replace(heapless::String::from(response.certificate_ownership_token)); - - Ok(Some(Response::Credentials(Credentials { - certificate_id: response.certificate_id, - certificate_pem: response.certificate_pem, - private_key: Some(response.private_key), - }))) + let response = + Self::deserialize::(format, &mut message)?; + + credential_handler + .store_credentials(Credentials { + certificate_id: response.certificate_id, + certificate_pem: response.certificate_pem, + private_key: Some(response.private_key), + }) + .await?; + + response.certificate_ownership_token } + Some(Topic::CreateCertificateFromCsrAccepted(format)) => { - trace!("Topic::CreateCertificateFromCsrAccepted {:?}", format); + let response = Self::deserialize::( + format, + message.payload_mut(), + )?; + + credential_handler + .store_credentials(Credentials { + certificate_id: response.certificate_id, + certificate_pem: response.certificate_pem, + private_key: None, + }) + .await?; + + response.certificate_ownership_token + } - let response = match format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::(payload)? - } - PayloadFormat::Json => { - serde_json_core::from_slice::(payload)?.0 - } - }; + // Error happened! + Some( + Topic::CreateKeysAndCertificateRejected(format) + | Topic::CreateCertificateFromCsrRejected(format), + ) => { + return Err(Self::handle_error(format, message.payload_mut()).unwrap_err()); + } - self.ownership_token - .replace(heapless::String::from(response.certificate_ownership_token)); + t => { + warn!("Got unexpected packet on topic {:?}", t); - Ok(Some(Response::Credentials(Credentials { - certificate_id: response.certificate_id, - certificate_pem: response.certificate_pem, - private_key: None, - }))) + return Err(Error::InvalidState); } - Some(Topic::RegisterThingAccepted(_, format)) => { - trace!("Topic::RegisterThingAccepted {:?}", format); + }; + + let register_request = RegisterThingRequest { + certificate_ownership_token: ownership_token, + parameters, + }; - let response = match format { + let payload = DeferredPayload::new( + |buf| { + Ok(match payload_format { #[cfg(feature = "provision_cbor")] PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::>(payload)? - } - PayloadFormat::Json => { - serde_json_core::from_slice::>(payload)?.0 + let mut serializer = minicbor_serde::Serializer::new( + minicbor::encode::write::Cursor::new(buf), + ); + register_request + .serialize(&mut serializer) + .map_err(|_| EncodingError::BufferSize)?; + serializer.into_encoder().writer().position() } - }; - - assert_eq!(response.thing_name, self.mqtt.client_id()); + PayloadFormat::Json => serde_json_core::to_slice(®ister_request, buf) + .map_err(|_| EncodingError::BufferSize)?, + }) + }, + 1024, + ); + + debug!("Starting RegisterThing"); + + let mut register_subscription = mqtt + .subscribe::<2>( + Subscribe::builder() + .topics(&[SubscribeTopic::builder() + .topic_path( + Topic::RegisterThingAccepted(template_name, payload_format) + .format::<150>()? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + Topic::RegisterThingRejected(template_name, payload_format) + .format::<150>()? + .as_str(), + ) + .build() + ]) + .build(), + ) + .await?; + + mqtt.publish( + Publish::builder() + .topic_name( + Topic::RegisterThing(template_name, payload_format) + .format::<69>()? + .as_str(), + ) + .payload(payload) + .build(), + ) + .await?; + + drop(message); + create_subscription.unsubscribe().await?; + + let mut message = register_subscription + .next_message() + .await + .ok_or(Error::InvalidState)?; + + match Topic::from_str(message.topic_name()) { + Some(Topic::RegisterThingAccepted(_, format)) => { + let response = Self::deserialize::>( + format, + message.payload_mut(), + )?; - Ok(Some(Response::DeviceConfiguration( - response.device_configuration, - ))) + Ok(response.device_configuration) } // Error happened! - Some( - Topic::CreateKeysAndCertificateRejected(format) - | Topic::CreateCertificateFromCsrRejected(format) - | Topic::RegisterThingRejected(_, format), - ) => { - let response = match format { - #[cfg(feature = "provision_cbor")] - PayloadFormat::Cbor => { - serde_cbor::de::from_mut_slice::(payload)? - } - PayloadFormat::Json => serde_json_core::from_slice::(payload)?.0, - }; - - error!("{:?}: {:?}", topic_name, response); - - Err(Error::Response(response.status_code)) + Some(Topic::RegisterThingRejected(_, format)) => { + Err(Self::handle_error(format, message.payload_mut()).unwrap_err()) } t => { trace!("{:?}", t); - Ok(None) + + Err(Error::InvalidState) } } } -} -impl<'a, M> Drop for FleetProvisioner<'a, M> -where - M: Mqtt, -{ - fn drop(&mut self) { - Unsubscribe::<4>::new() - .topic(Topic::CreateKeysAndCertificateAccepted(self.payload_format)) - .topic(Topic::CreateKeysAndCertificateRejected(self.payload_format)) - .topic(Topic::RegisterThingAccepted( - self.template_name, - self.payload_format, - )) - .topic(Topic::RegisterThingRejected( - self.template_name, - self.payload_format, - )) - .send(self.mqtt) - .ok(); + async fn begin<'a, 'b, M: RawMutex>( + mqtt: &'b embedded_mqtt::MqttClient<'a, M>, + csr: Option<&str>, + payload_format: PayloadFormat, + ) -> Result, Error> { + if let Some(csr) = csr { + let subscription = mqtt + .subscribe( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + Topic::CreateCertificateFromCsrRejected(payload_format) + .format::<47>()? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + Topic::CreateCertificateFromCsrAccepted(payload_format) + .format::<47>()? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await?; + + let request = CreateCertificateFromCsrRequest { + certificate_signing_request: csr, + }; + + let payload = DeferredPayload::new( + |buf| { + Ok(match payload_format { + #[cfg(feature = "provision_cbor")] + PayloadFormat::Cbor => { + let mut serializer = minicbor_serde::Serializer::new( + minicbor::encode::write::Cursor::new(buf), + ); + request + .serialize(&mut serializer) + .map_err(|_| EncodingError::BufferSize)?; + serializer.into_encoder().writer().position() + } + PayloadFormat::Json => serde_json_core::to_slice(&request, buf) + .map_err(|_| EncodingError::BufferSize)?, + }) + }, + csr.len() + 32, + ); + + mqtt.publish( + Publish::builder() + .topic_name( + Topic::CreateCertificateFromCsr(payload_format) + .format::<40>()? + .as_str(), + ) + .payload(payload) + .build(), + ) + .await?; + + Ok(subscription) + } else { + let subscription = mqtt + .subscribe( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + Topic::CreateKeysAndCertificateAccepted(payload_format) + .format::<38>()? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + Topic::CreateKeysAndCertificateRejected(payload_format) + .format::<38>()? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await?; + + mqtt.publish( + Publish::builder() + .topic_name( + Topic::CreateKeysAndCertificate(payload_format) + .format::<29>()? + .as_str(), + ) + .payload(b"") + .build(), + ) + .await?; + + Ok(subscription) + } + } + + fn deserialize<'a, R: Deserialize<'a>>( + payload_format: PayloadFormat, + payload: &'a mut [u8], + ) -> Result { + Ok(match payload_format { + #[cfg(feature = "provision_cbor")] + PayloadFormat::Cbor => minicbor_serde::from_slice::(payload)?, + PayloadFormat::Json => serde_json_core::from_slice::(payload)?.0, + }) + } + + fn handle_error(format: PayloadFormat, payload: &mut [u8]) -> Result<(), Error> { + let response = match format { + #[cfg(feature = "provision_cbor")] + PayloadFormat::Cbor => minicbor_serde::from_slice::(payload)?, + PayloadFormat::Json => serde_json_core::from_slice::(payload)?.0, + }; + + error!("{:?}", response); + + Err(Error::Response(response.status_code)) } } diff --git a/src/provisioning/topics.rs b/src/provisioning/topics.rs index 9ced90a..bef6710 100644 --- a/src/provisioning/topics.rs +++ b/src/provisioning/topics.rs @@ -3,7 +3,6 @@ use core::fmt::Write; use core::str::FromStr; use heapless::String; -use mqttrust::{Mqtt, QoS, SubscribeTopic}; use super::Error; @@ -59,18 +58,27 @@ pub enum Topic<'a> { CreateCertificateFromCsr(PayloadFormat), // ---- Incoming Topics + /// `$aws/provisioning-templates//provision//+` + RegisterThingAny(&'a str, PayloadFormat), + /// `$aws/provisioning-templates//provision//accepted` RegisterThingAccepted(&'a str, PayloadFormat), /// `$aws/provisioning-templates//provision//rejected` RegisterThingRejected(&'a str, PayloadFormat), + /// `$aws/certificates/create//+` + CreateKeysAndCertificateAny(PayloadFormat), + /// `$aws/certificates/create//accepted` CreateKeysAndCertificateAccepted(PayloadFormat), /// `$aws/certificates/create//rejected` CreateKeysAndCertificateRejected(PayloadFormat), + /// `$aws/certificates/create-from-csr//+` + CreateCertificateFromCsrAny(PayloadFormat), + /// `$aws/certificates/create-from-csr//accepted` CreateCertificateFromCsrAccepted(PayloadFormat), @@ -88,7 +96,7 @@ impl<'a> Topic<'a> { pub fn from_str(s: &'a str) -> Option { let tt = s.splitn(6, '/').collect::>(); - match (tt.get(0), tt.get(1)) { + match (tt.first(), tt.get(1)) { (Some(&"$aws"), Some(&"provisioning-templates")) => { // This is a register thing topic, now figure out which one. @@ -99,7 +107,7 @@ impl<'a> Topic<'a> { Some(payload_format), Some(&"accepted"), ) => Some(Topic::RegisterThingAccepted( - *template_name, + template_name, PayloadFormat::from_str(payload_format).ok()?, )), ( @@ -108,7 +116,7 @@ impl<'a> Topic<'a> { Some(payload_format), Some(&"rejected"), ) => Some(Topic::RegisterThingRejected( - *template_name, + template_name, PayloadFormat::from_str(payload_format).ok()?, )), _ => None, @@ -169,6 +177,14 @@ impl<'a> Topic<'a> { payload_format, )) } + Topic::RegisterThingAny(template_name, payload_format) => { + topic_path.write_fmt(format_args!( + "{}/{}/provision/{}/#", + Self::PROVISIONING_PREFIX, + template_name, + payload_format, + )) + } Topic::RegisterThingAccepted(template_name, payload_format) => { topic_path.write_fmt(format_args!( "{}/{}/provision/{}/accepted", @@ -192,6 +208,9 @@ impl<'a> Topic<'a> { payload_format, )), + Topic::CreateKeysAndCertificateAny(payload_format) => topic_path.write_fmt( + format_args!("{}/create/{}/#", Self::CERT_PREFIX, payload_format), + ), Topic::CreateKeysAndCertificateAccepted(payload_format) => topic_path.write_fmt( format_args!("{}/create/{}/accepted", Self::CERT_PREFIX, payload_format), ), @@ -204,120 +223,26 @@ impl<'a> Topic<'a> { Self::CERT_PREFIX, payload_format, )), - Topic::CreateCertificateFromCsrAccepted(payload_format) => topic_path.write_fmt( - format_args!("{}/create-from-csr/{}", Self::CERT_PREFIX, payload_format), - ), - Topic::CreateCertificateFromCsrRejected(payload_format) => topic_path.write_fmt( - format_args!("{}/create-from-csr/{}", Self::CERT_PREFIX, payload_format), + Topic::CreateCertificateFromCsrAny(payload_format) => topic_path.write_fmt( + format_args!("{}/create-from-csr/{}/+", Self::CERT_PREFIX, payload_format), ), + Topic::CreateCertificateFromCsrAccepted(payload_format) => { + topic_path.write_fmt(format_args!( + "{}/create-from-csr/{}/accepted", + Self::CERT_PREFIX, + payload_format + )) + } + Topic::CreateCertificateFromCsrRejected(payload_format) => { + topic_path.write_fmt(format_args!( + "{}/create-from-csr/{}/rejected", + Self::CERT_PREFIX, + payload_format + )) + } } .map_err(|_| Error::Overflow)?; Ok(topic_path) } } - -#[derive(Default)] -pub struct Subscribe<'a, const N: usize> { - topics: heapless::Vec<(Topic<'a>, QoS), N>, -} - -impl<'a, const N: usize> Subscribe<'a, N> { - pub fn new() -> Self { - Self::default() - } - - pub fn topic(self, topic: Topic<'a>, qos: QoS) -> Self { - // Ignore attempts to subscribe to outgoing topics - if topic.direction() != Direction::Incoming { - return self; - } - - if self.topics.iter().any(|(t, _)| t == &topic) { - return self; - } - - let mut topics = self.topics; - topics.push((topic, qos)).ok(); - - Self { topics } - } - - pub fn topics(self) -> Result, QoS), N>, Error> { - self.topics - .iter() - .map(|(topic, qos)| Ok((topic.clone().format()?, *qos))) - .collect() - } - - pub fn send(self, mqtt: &M) -> Result<(), Error> { - if self.topics.is_empty() { - return Ok(()); - } - - let topic_paths = self.topics()?; - - debug!("Subscribing! {:?}", topic_paths); - - let topics: heapless::Vec<_, N> = topic_paths - .iter() - .map(|(s, qos)| SubscribeTopic { - topic_path: s.as_str(), - qos: *qos, - }) - .collect(); - - for t in topics.chunks(5) { - mqtt.subscribe(t)?; - } - Ok(()) - } -} - -#[derive(Default)] -pub struct Unsubscribe<'a, const N: usize> { - topics: heapless::Vec, N>, -} - -impl<'a, const N: usize> Unsubscribe<'a, N> { - pub fn new() -> Self { - Self::default() - } - - pub fn topic(self, topic: Topic<'a>) -> Self { - // Ignore attempts to subscribe to outgoing topics - if topic.direction() != Direction::Incoming { - return self; - } - - if self.topics.iter().any(|t| t == &topic) { - return self; - } - - let mut topics = self.topics; - topics.push(topic).ok(); - Self { topics } - } - - pub fn topics(self) -> Result, N>, Error> { - self.topics - .iter() - .map(|topic| topic.clone().format()) - .collect() - } - - pub fn send(self, mqtt: &M) -> Result<(), Error> { - if self.topics.is_empty() { - return Ok(()); - } - - let topic_paths = self.topics()?; - let topics: heapless::Vec<_, N> = topic_paths.iter().map(|s| s.as_str()).collect(); - - for t in topics.chunks(5) { - mqtt.unsubscribe(t)?; - } - - Ok(()) - } -} diff --git a/src/shadows/README.md b/src/shadows/README.md index a1ec0b0..9ea3132 100644 --- a/src/shadows/README.md +++ b/src/shadows/README.md @@ -8,4 +8,4 @@ You can find an example of how to use this crate for iot shadow states in the `t pfx identity files can be created from a set of device certificate and private key using OpenSSL as: `openssl pkcs12 -export -out identity.pfx -inkey private.pem.key -in certificate.pem.crt -certfile root-ca.pem` -The example functions as a CI integration test, that is run against `Blackbirds` integration account on every PR. This test will run through a statemachine of shadow delete, updates and gets from both device & cloud side with assertions in between. +The example functions as a CI integration test, that is run against Factbirds integration account on every PR. This test will run through a statemachine of shadow delete, updates and gets from both device & cloud side with assertions in between. diff --git a/src/shadows/dao.rs b/src/shadows/dao.rs index 875c5c2..c18d6c5 100644 --- a/src/shadows/dao.rs +++ b/src/shadows/dao.rs @@ -2,53 +2,23 @@ use serde::{de::DeserializeOwned, Serialize}; use super::{Error, ShadowState}; -pub trait ShadowDAO { - fn read(&mut self) -> Result; - fn write(&mut self, state: &S) -> Result<(), Error>; +pub trait ShadowDAO { + async fn read(&mut self) -> Result; + async fn write(&mut self, state: &S) -> Result<(), Error>; } -impl ShadowDAO for () { - fn read(&mut self) -> Result { - Err(Error::NoPersistance) - } - - fn write(&mut self, _state: &S) -> Result<(), Error> { - Err(Error::NoPersistance) - } -} - -pub struct EmbeddedStorageDAO(T); - -impl From for EmbeddedStorageDAO -where - T: embedded_storage::Storage, -{ - fn from(v: T) -> Self { - Self::new(v) - } -} +const U32_SIZE: usize = 4; -impl EmbeddedStorageDAO +impl ShadowDAO for T where - T: embedded_storage::Storage, -{ - pub fn new(storage: T) -> Self { - Self(storage) - } -} - -const U32_SIZE: usize = core::mem::size_of::(); - -impl ShadowDAO for EmbeddedStorageDAO -where - S: ShadowState + DeserializeOwned, - T: embedded_storage::Storage, + S: ShadowState + Serialize + DeserializeOwned, + T: embedded_storage_async::nor_flash::NorFlash, [(); S::MAX_PAYLOAD_SIZE + U32_SIZE]:, { - fn read(&mut self) -> Result { + async fn read(&mut self) -> Result { let buf = &mut [0u8; S::MAX_PAYLOAD_SIZE + U32_SIZE]; - self.0.read(OFFSET, buf).map_err(|_| Error::DaoRead)?; + self.read(0, buf).await.map_err(|_| Error::DaoRead)?; match buf[..U32_SIZE].try_into() { Ok(len_bytes) => { @@ -58,29 +28,28 @@ where } Ok( - serde_cbor::de::from_mut_slice::( - &mut buf[U32_SIZE..len as usize + U32_SIZE], - ) - .map_err(|_| Error::InvalidPayload)?, + minicbor_serde::from_slice(&buf[U32_SIZE..len as usize + U32_SIZE]) + .map_err(|_| Error::InvalidPayload)?, ) } _ => Err(Error::InvalidPayload), } } - fn write(&mut self, state: &S) -> Result<(), Error> { - assert!(S::MAX_PAYLOAD_SIZE <= self.0.capacity() - OFFSET as usize); + async fn write(&mut self, state: &S) -> Result<(), Error> { + assert!(S::MAX_PAYLOAD_SIZE <= self.capacity()); let buf = &mut [0u8; S::MAX_PAYLOAD_SIZE + U32_SIZE]; - let mut serializer = serde_cbor::ser::Serializer::new(serde_cbor::ser::SliceWrite::new( + let mut serializer = minicbor_serde::Serializer::new(minicbor::encode::write::Cursor::new( &mut buf[U32_SIZE..], - )) - .packed_format(); + )); + state .serialize(&mut serializer) .map_err(|_| Error::InvalidPayload)?; - let len = serializer.into_inner().bytes_written(); + + let len = serializer.into_encoder().writer().position(); if len > S::MAX_PAYLOAD_SIZE { return Err(Error::Overflow); @@ -88,11 +57,11 @@ where buf[..U32_SIZE].copy_from_slice(&(len as u32).to_le_bytes()); - self.0 - .write(OFFSET, &buf[..len + U32_SIZE]) + self.write(0, &buf[..len + U32_SIZE]) + .await .map_err(|_| Error::DaoWrite)?; - debug!("Wrote {} bytes to DAO @ {}", len + U32_SIZE, OFFSET); + debug!("Wrote {} bytes to DAO", len + U32_SIZE); Ok(()) } @@ -124,11 +93,11 @@ where #[cfg(any(feature = "std", test))] impl ShadowDAO for StdIODAO where - S: ShadowState + DeserializeOwned, + S: ShadowState + Serialize + DeserializeOwned, T: std::io::Write + std::io::Read, [(); S::MAX_PAYLOAD_SIZE]:, { - fn read(&mut self) -> Result { + async fn read(&mut self) -> Result { let bytes = &mut [0u8; S::MAX_PAYLOAD_SIZE]; self.0.read(bytes).map_err(|_| Error::DaoRead)?; @@ -136,7 +105,7 @@ where Ok(shadow) } - fn write(&mut self, state: &S) -> Result<(), Error> { + async fn write(&mut self, state: &S) -> Result<(), Error> { let bytes = serde_json_core::to_vec::<_, { S::MAX_PAYLOAD_SIZE }>(state) .map_err(|_| Error::Overflow)?; diff --git a/src/shadows/data_types.rs b/src/shadows/data_types.rs index 7a25453..18fd2a3 100644 --- a/src/shadows/data_types.rs +++ b/src/shadows/data_types.rs @@ -1,19 +1,13 @@ use serde::{Deserialize, Serialize}; -#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] pub enum Patch { - #[serde(rename = "unset")] + #[default] Unset, - #[serde(rename = "set")] Set(T), } -impl Default for Patch { - fn default() -> Self { - Self::Unset - } -} - impl Clone for Patch where T: Clone, @@ -32,22 +26,21 @@ impl From for Patch { } } -#[derive(Debug, Serialize, Deserialize)] -pub struct State { - #[serde(rename = "desired")] - pub desired: Option, - #[serde(rename = "reported")] - pub reported: Option, +#[derive(Serialize)] +pub struct RequestState { + #[serde(skip_serializing_if = "Option::is_none")] + pub desired: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reported: Option, } -#[derive(Debug, Serialize, Deserialize)] -pub struct DeltaState { - #[serde(rename = "desired")] - pub desired: Option, - #[serde(rename = "reported")] - pub reported: Option, - #[serde(rename = "delta")] - pub delta: Option, +#[derive(Deserialize)] +pub struct DeltaState { + pub desired: Option, + + pub reported: Option, + + pub delta: Option, } /// A request state document has the following format: @@ -61,13 +54,15 @@ pub struct DeltaState { /// response by the client token. /// - **version** — If used, the Device Shadow service processes the update only /// if the specified version matches the latest version it has. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Serialize)] #[serde(default)] -pub struct Request<'a, T> { - pub state: State, +pub struct Request<'a, D, R> { + pub state: RequestState, + #[serde(rename = "clientToken")] #[serde(skip_serializing_if = "Option::is_none")] pub client_token: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] pub version: Option, } @@ -92,13 +87,16 @@ pub struct Request<'a, T> { /// - **version** — The current version of the document for the device's shadow /// shared in AWS IoT. It is increased by one over the previous version of the /// document. -#[derive(Debug, Serialize, Deserialize)] -pub struct AcceptedResponse<'a, T> { - pub state: DeltaState, +#[derive(Deserialize)] +pub struct AcceptedResponse<'a, D, R> { + pub state: DeltaState, + // pub metadata: Metadata<>. pub timestamp: u64, + #[serde(rename = "clientToken")] #[serde(skip_serializing_if = "Option::is_none")] pub client_token: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] pub version: Option, } @@ -115,9 +113,10 @@ pub struct AcceptedResponse<'a, T> { /// - **version** — The current version of the document for the device's shadow /// shared in AWS IoT. It is increased by one over the previous version of the /// document. -#[derive(Debug, Serialize, Deserialize)] -pub struct DeltaResponse<'a, T> { - pub state: Option, +#[derive(Deserialize)] +pub struct DeltaResponse<'a, U> { + pub state: Option, + // pub metadata: Metadata<>. pub timestamp: u64, #[serde(rename = "clientToken")] #[serde(skip_serializing_if = "Option::is_none")] @@ -172,7 +171,7 @@ mod tests { exp_map .0 .insert( - heapless::String::from("1"), + heapless::String::try_from("1").unwrap(), Patch::Set(Test { field: true }), ) .unwrap(); @@ -189,7 +188,7 @@ mod tests { exp_map .0 .insert( - heapless::String::from("1"), + heapless::String::try_from("1").unwrap(), Patch::Set(Test { field: true }), ) .unwrap(); @@ -215,7 +214,7 @@ mod tests { let mut exp_map = TestMap(heapless::LinearMap::default()); exp_map .0 - .insert(heapless::String::from("1"), Patch::Unset) + .insert(heapless::String::try_from("1").unwrap(), Patch::Unset) .unwrap(); let (patch, _) = serde_json_core::from_str::(payload).unwrap(); diff --git a/src/shadows/error.rs b/src/shadows/error.rs index f7cd84b..54bd0b1 100644 --- a/src/shadows/error.rs +++ b/src/shadows/error.rs @@ -1,9 +1,4 @@ use core::convert::TryFrom; -use core::fmt::Display; -use core::str::FromStr; - -use heapless::String; -use mqttrust::MqttError; use super::data_types::ErrorResponse; @@ -11,21 +6,15 @@ use super::data_types::ErrorResponse; #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Error { Overflow, - NoPersistance, + NoPersistence, DaoRead, DaoWrite, InvalidPayload, WrongShadowName, - Mqtt(MqttError), + MqttError(embedded_mqtt::Error), ShadowError(ShadowError), } -impl From for Error { - fn from(e: MqttError) -> Self { - Self::Mqtt(e) - } -} - impl From for Error { fn from(e: ShadowError) -> Self { Self::ShadowError(e) @@ -47,7 +36,6 @@ pub enum ShadowError { Unauthorized, Forbidden, NotFound, - NoNamedShadow(String<64>), VersionConflict, PayloadTooLarge, UnsupportedEncoding, @@ -70,7 +58,7 @@ impl ShadowError { ShadowError::Unauthorized => 401, ShadowError::Forbidden => 403, - ShadowError::NotFound | ShadowError::NoNamedShadow(_) => 404, + ShadowError::NotFound => 404, ShadowError::VersionConflict => 409, ShadowError::PayloadTooLarge => 413, ShadowError::UnsupportedEncoding => 415, @@ -85,7 +73,7 @@ impl<'a> TryFrom> for ShadowError { fn try_from(e: ErrorResponse<'a>) -> Result { Ok(match e.code { - 400 | 404 => Self::from_str(e.message)?, + 400 | 404 => ShadowError::NotFound, 401 => ShadowError::Unauthorized, 403 => ShadowError::Forbidden, 409 => ShadowError::VersionConflict, @@ -97,67 +85,3 @@ impl<'a> TryFrom> for ShadowError { }) } } - -impl Display for ShadowError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - Self::InvalidJson => write!(f, "Invalid JSON"), - Self::MissingState => write!(f, "Missing required node: state"), - Self::MalformedState => write!(f, "State node must be an object"), - Self::MalformedDesired => write!(f, "Desired node must be an object"), - Self::MalformedReported => write!(f, "Reported node must be an object"), - Self::InvalidVersion => write!(f, "Invalid version"), - Self::InvalidClientToken => write!(f, "Invalid clientToken"), - Self::JsonTooDeep => { - write!(f, "JSON contains too many levels of nesting; maximum is 6") - } - Self::InvalidStateNode => write!(f, "State contains an invalid node"), - Self::Unauthorized => write!(f, "Unauthorized"), - Self::Forbidden => write!(f, "Forbidden"), - Self::NotFound => write!(f, "Thing not found"), - Self::NoNamedShadow(shadow_name) => { - write!(f, "No shadow exists with name: {}", shadow_name) - } - Self::VersionConflict => write!(f, "Version conflict"), - Self::PayloadTooLarge => write!(f, "The payload exceeds the maximum size allowed"), - Self::UnsupportedEncoding => write!( - f, - "Unsupported documented encoding; supported encoding is UTF-8" - ), - Self::TooManyRequests => write!(f, "The Device Shadow service will generate this error message when there are more than 10 in-flight requests on a single connection"), - Self::InternalServerError => write!(f, "Internal service failure"), - } - } -} - -// TODO: This seems like an extremely brittle way of doing this??! -impl FromStr for ShadowError { - type Err = (); - - fn from_str(s: &str) -> Result { - Ok(match s.trim() { - "Invalid JSON" => Self::InvalidJson, - "Missing required node: state" => Self::MissingState, - "State node must be an object" => Self::MalformedState, - "Desired node must be an object" => Self::MalformedDesired, - "Reported node must be an object" => Self::MalformedReported, - "Invalid version" => Self::InvalidVersion, - "Invalid clientToken" => Self::InvalidClientToken, - "JSON contains too many levels of nesting; maximum is 6" => Self::JsonTooDeep, - "State contains an invalid node" => Self::InvalidStateNode, - "Unauthorized" => Self::Unauthorized, - "Forbidden" => Self::Forbidden, - "Thing not found" => Self::NotFound, - // TODO: - "No shadow exists with name: " => Self::NoNamedShadow(String::new()), - "Version conflict" => Self::VersionConflict, - "The payload exceeds the maximum size allowed" => Self::PayloadTooLarge, - "Unsupported documented encoding; supported encoding is UTF-8" => { - Self::UnsupportedEncoding - } - "The Device Shadow service will generate this error message when there are more than 10 in-flight requests on a single connection" => Self::TooManyRequests, - "Internal service failure" => Self::InternalServerError, - _ => return Err(()), - }) - } -} diff --git a/src/shadows/mod.rs b/src/shadows/mod.rs index 31ce599..22b6739 100644 --- a/src/shadows/mod.rs +++ b/src/shadows/mod.rs @@ -1,21 +1,25 @@ pub mod dao; pub mod data_types; -mod error; -mod shadow_diff; +pub mod error; pub mod topics; -use core::marker::PhantomData; +pub use rustot_derive; -use mqttrust::{Mqtt, QoS}; +use core::{marker::PhantomData, ops::DerefMut}; pub use data_types::Patch; +use embassy_sync::{ + blocking_mutex::raw::{NoopRawMutex, RawMutex}, + mutex::Mutex, +}; +use embedded_mqtt::{DeferredPayload, Publish, Subscribe, SubscribeTopic, ToPayload}; pub use error::Error; -use serde::de::DeserializeOwned; -pub use shadow_derive as derive; -pub use shadow_diff::ShadowPatch; +use serde::{de::DeserializeOwned, Serialize}; -use data_types::{AcceptedResponse, DeltaResponse, ErrorResponse}; -use topics::{Direction, Subscribe, Topic, Unsubscribe}; +use data_types::{ + AcceptedResponse, DeltaResponse, DeltaState, ErrorResponse, Request, RequestState, +}; +use topics::Topic; use self::dao::ShadowDAO; @@ -25,312 +29,443 @@ const CLASSIC_SHADOW: &str = "Classic"; pub trait ShadowState: ShadowPatch { const NAME: Option<&'static str>; + const PREFIX: &'static str = "$aws"; const MAX_PAYLOAD_SIZE: usize = 512; } -struct ShadowHandler<'a, M: Mqtt, S: ShadowState> -where - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, -{ - mqtt: &'a M, +pub trait ShadowPatch: Default + Clone + Sized { + // Contains all fields from `Self` as optionals + type Delta: DeserializeOwned + Serialize + Clone + Default; + + // Contains all fields from `Delta` + additional optional fields + type Reported: From + Serialize + Default; + + fn apply_patch(&mut self, delta: Self::Delta); +} + +struct ShadowHandler<'a, 'm, M: RawMutex, S> { + mqtt: &'m embedded_mqtt::MqttClient<'a, M>, + subscription: Mutex>>, _shadow: PhantomData, } -impl<'a, M: Mqtt, S: ShadowState> ShadowHandler<'a, M, S> -where - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, -{ - /// Subscribes to all the topics required for keeping a shadow in sync - pub fn subscribe(&self) -> Result<(), Error> { - Subscribe::<7>::new() - .topic(Topic::GetAccepted, QoS::AtLeastOnce) - .topic(Topic::GetRejected, QoS::AtLeastOnce) - .topic(Topic::DeleteAccepted, QoS::AtLeastOnce) - .topic(Topic::DeleteRejected, QoS::AtLeastOnce) - .topic(Topic::UpdateAccepted, QoS::AtLeastOnce) - .topic(Topic::UpdateRejected, QoS::AtLeastOnce) - .topic(Topic::UpdateDelta, QoS::AtLeastOnce) - .send(self.mqtt, S::NAME)?; +impl<'a, M: RawMutex, S: ShadowState> ShadowHandler<'a, '_, M, S> { + async fn handle_delta(&self) -> Result, Error> { + // Loop to automatically retry on clean session + loop { + let mut sub_ref = self.subscription.lock().await; + + let delta_subscription = match sub_ref.deref_mut() { + Some(sub) => sub, + None => { + info!("Subscribing to delta topic"); + self.mqtt.wait_connected().await; + + let sub = self + .mqtt + .subscribe::<2>( + Subscribe::builder() + .topics(&[SubscribeTopic::builder() + .topic_path( + topics::Topic::UpdateDelta + .format::<64>(S::PREFIX, self.mqtt.client_id(), S::NAME)? + .as_str(), + ) + .build()]) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + let _ = sub_ref.insert(sub); + + let delta_state = self.get_shadow().await?; + + return Ok(delta_state.delta); + } + }; - Ok(()) - } + let delta_message = match delta_subscription.next_message().await { + Some(msg) => msg, + None => { + // Clear subscription if we get clean session + info!( + "[{:?}] Clean session detected, resubscribing to delta topic", + S::NAME.unwrap_or(CLASSIC_SHADOW) + ); + sub_ref.take(); + // Drop the lock and continue the loop to retry + drop(sub_ref); + continue; + } + }; - /// Unsubscribes from all the topics required for keeping a shadow in sync - pub fn unsubscribe(&self) -> Result<(), Error> { - Unsubscribe::<7>::new() - .topic(Topic::GetAccepted) - .topic(Topic::GetRejected) - .topic(Topic::DeleteAccepted) - .topic(Topic::DeleteRejected) - .topic(Topic::UpdateAccepted) - .topic(Topic::UpdateRejected) - .topic(Topic::UpdateDelta) - .send(self.mqtt, S::NAME)?; + // Update the device's state to match the desired state in the + // message body. + debug!( + "[{:?}] Received shadow delta event.", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); - Ok(()) - } + // Buffer to temporarily hold escaped characters data + let mut buf = [0u8; 64]; + + // Use from_slice_escaped to properly handle escaped characters + let (delta, _) = serde_json_core::from_slice_escaped::>( + delta_message.payload(), + &mut buf, + ) + .map_err(|_| Error::InvalidPayload)?; - /// Helper function to check whether a topic name is relevant for this - /// particular shadow. - pub fn should_handle_topic(&mut self, topic: &str) -> bool { - if let Some((_, thing_name, shadow_name)) = Topic::from_str(topic) { - return thing_name == self.mqtt.client_id() && shadow_name == S::NAME; + return Ok(delta.state); } - false } /// Internal helper function for applying a delta state to the actual shadow /// state, and update the cloud shadow. - fn change_shadow_value( - &mut self, - state: &mut S, - delta: Option, - update_desired: Option, - ) -> Result<(), Error> { - if let Some(ref delta) = delta { - state.apply_patch(delta.clone()); + async fn update_shadow( + &self, + desired: Option, + reported: Option, + ) -> Result, Error> { + debug!( + "[{:?}] Updating reported shadow value.", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); + + if desired.is_some() && reported.is_some() { + // Do not edit both reported and desired at the same time + return Err(Error::ShadowError(error::ShadowError::Forbidden)); } - debug!( - "[{:?}] Updating reported shadow value. Update_desired: {:?}", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - update_desired + let request: Request<'_, S::Delta, S::Reported> = Request { + state: RequestState { desired, reported }, + client_token: Some(self.mqtt.client_id()), + version: None, + }; + + let payload = DeferredPayload::new( + |buf: &mut [u8]| { + serde_json_core::to_slice(&request, buf) + .map_err(|_| embedded_mqtt::EncodingError::BufferSize) + }, + S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD, ); - if let Some(update_desired) = update_desired { - let desired = if update_desired { Some(&state) } else { None }; + // Wait for mqtt to connect + self.mqtt.wait_connected().await; - let request = data_types::Request { - state: data_types::State { - reported: Some(&state), - desired, - }, - client_token: None, - version: None, - }; + let mut sub = self.publish_and_subscribe(Topic::Update, payload).await?; - let payload = serde_json_core::to_vec::< - _, - { S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD }, - >(&request) - .map_err(|_| Error::Overflow)?; + //*** WAIT RESPONSE ***/ + debug!("Wait for Accepted or Rejected"); - let update_topic = - Topic::Update.format::(self.mqtt.client_id(), S::NAME)?; - self.mqtt - .publish(update_topic.as_str(), &payload, QoS::AtLeastOnce)?; - } + loop { + let message = sub.next_message().await.ok_or(Error::InvalidPayload)?; - Ok(()) + match Topic::from_str(S::PREFIX, message.topic_name()) { + Some((Topic::UpdateAccepted, _, _)) => { + let mut buf = [0u8; 64]; + let (response, _) = serde_json_core::from_slice_escaped::< + // FIXME: + AcceptedResponse, + >(message.payload(), &mut buf) + .map_err(|_| Error::InvalidPayload)?; + + if response.client_token != Some(self.mqtt.client_id()) { + continue; + } + + return Ok(response.state); + } + Some((Topic::UpdateRejected, _, _)) => { + let mut buf = [0u8; 64]; + let (error_response, _) = serde_json_core::from_slice_escaped::( + message.payload(), + &mut buf, + ) + .map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?; + + if error_response.client_token != Some(self.mqtt.client_id()) { + continue; + } + + return Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )); + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + return Err(Error::WrongShadowName); + } + } + } } /// Initiate a `GetShadow` request, updating the local state from the cloud. - pub fn get_shadow(&self) -> Result<(), Error> { - let get_topic = Topic::Get.format::(self.mqtt.client_id(), S::NAME)?; - self.mqtt - .publish(get_topic.as_str(), b"", QoS::AtLeastOnce)?; - Ok(()) + async fn get_shadow(&self) -> Result, Error> { + // Wait for mqtt to connect + self.mqtt.wait_connected().await; + + let mut sub = self.publish_and_subscribe(Topic::Get, b"").await?; + + let get_message = sub.next_message().await.ok_or(Error::InvalidPayload)?; + + // Check if topic is GetAccepted + // Deserialize message + // Persist shadow and return new shadow + match Topic::from_str(S::PREFIX, get_message.topic_name()) { + Some((Topic::GetAccepted, _, _)) => { + let mut buf = [0u8; 64]; + let (response, _) = serde_json_core::from_slice_escaped::< + AcceptedResponse, + >(get_message.payload(), &mut buf) + .map_err(|_| Error::InvalidPayload)?; + + Ok(response.state) + } + Some((Topic::GetRejected, _, _)) => { + let mut buf = [0u8; 64]; + let (error_response, _) = serde_json_core::from_slice_escaped::( + get_message.payload(), + &mut buf, + ) + .map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?; + + if error_response.code == 404 { + debug!( + "[{:?}] Thing has no shadow document. Creating with defaults...", + S::NAME.unwrap_or(CLASSIC_SHADOW) + ); + self.create_shadow().await?; + } + + Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )) + } + _ => { + error!( + "Expected topic name to be GetRejected or GetAccepted but got something else" + ); + Err(Error::WrongShadowName) + } + } + } + + pub async fn delete_shadow(&self) -> Result<(), Error> { + // Wait for mqtt to connect + self.mqtt.wait_connected().await; + + let mut sub = self + .publish_and_subscribe(topics::Topic::Delete, b"") + .await?; + + let message = sub.next_message().await.ok_or(Error::InvalidPayload)?; + + // Check if topic is DeleteAccepted + match Topic::from_str(S::PREFIX, message.topic_name()) { + Some((Topic::DeleteAccepted, _, _)) => Ok(()), + Some((Topic::DeleteRejected, _, _)) => { + let mut buf = [0u8; 64]; + let (error_response, _) = serde_json_core::from_slice_escaped::( + message.payload(), + &mut buf, + ) + .map_err(|_| Error::ShadowError(error::ShadowError::NotFound))?; + + Err(Error::ShadowError( + error_response + .try_into() + .unwrap_or(error::ShadowError::NotFound), + )) + } + _ => { + error!("Expected Topic name GetRejected or GetAccepted but got something else"); + Err(Error::WrongShadowName) + } + } + } + + pub async fn create_shadow(&self) -> Result, Error> { + debug!( + "[{:?}] Creating initial shadow value.", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); + + self.update_shadow(None, Some(S::Reported::default())).await } - pub fn delete_shadow(&mut self) -> Result<(), Error> { - let delete_topic = Topic::Delete.format::(self.mqtt.client_id(), S::NAME)?; + /// This function will subscribe to accepted and rejected topics and then do a publish. + /// It will only return when something is accepted or rejected + /// Topic is the topic you want to publish to + /// The function will automatically subscribe to the accepted and rejected topic related to the publish topic + async fn publish_and_subscribe( + &self, + topic: topics::Topic, + payload: impl ToPayload, + ) -> Result, Error> { + let (accepted, rejected) = match topic { + Topic::Get => (Topic::GetAccepted, Topic::GetRejected), + Topic::Update => (Topic::UpdateAccepted, Topic::UpdateRejected), + Topic::Delete => (Topic::DeleteAccepted, Topic::DeleteRejected), + _ => return Err(Error::ShadowError(error::ShadowError::Forbidden)), + }; + + //*** SUBSCRIBE ***/ + let sub = self + .mqtt + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + accepted + .format::<65>(S::PREFIX, self.mqtt.client_id(), S::NAME)? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + rejected + .format::<65>(S::PREFIX, self.mqtt.client_id(), S::NAME)? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + //*** PUBLISH REQUEST ***/ + let topic_name = + topic.format::(S::PREFIX, self.mqtt.client_id(), S::NAME)?; self.mqtt - .publish(delete_topic.as_str(), b"", QoS::AtLeastOnce)?; - Ok(()) + .publish( + Publish::builder() + .topic_name(topic_name.as_str()) + .payload(payload) + .build(), + ) + .await + .map_err(Error::MqttError)?; + + Ok(sub) } } -pub struct PersistedShadow<'a, S: ShadowState + DeserializeOwned, M: Mqtt, D: ShadowDAO> -where - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, -{ - handler: ShadowHandler<'a, M, S>, - pub(crate) dao: D, +pub struct PersistedShadow<'a, 'm, S, M: RawMutex, D> { + handler: ShadowHandler<'a, 'm, M, S>, + pub(crate) dao: Mutex, } -impl<'a, S, M, D> PersistedShadow<'a, S, M, D> +impl<'a, 'm, S, M, D> PersistedShadow<'a, 'm, S, M, D> where - S: ShadowState + DeserializeOwned, - M: Mqtt, + S: ShadowState + Serialize + DeserializeOwned, + M: RawMutex, D: ShadowDAO, - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, { /// Instantiate a new shadow that will be automatically persisted to NVM /// based on the passed `DAO`. - pub fn new( - initial_state: S, - mqtt: &'a M, - mut dao: D, - auto_subscribe: bool, - ) -> Result { - if dao.read().is_err() { - dao.write(&initial_state)?; - } - + pub fn new(mqtt: &'m embedded_mqtt::MqttClient<'a, M>, dao: D) -> Self { let handler = ShadowHandler { mqtt, + subscription: Mutex::new(None), _shadow: PhantomData, }; - if auto_subscribe { - handler.subscribe()?; - } - Ok(Self { handler, dao }) - } - /// Subscribes to all the topics required for keeping a shadow in sync - pub fn subscribe(&self) -> Result<(), Error> { - self.handler.subscribe() - } - - /// Unsubscribes from all the topics required for keeping a shadow in sync - pub fn unsubscribe(&self) -> Result<(), Error> { - self.handler.unsubscribe() - } - - /// Helper function to check whether a topic name is relevant for this - /// particular shadow. - pub fn should_handle_topic(&mut self, topic: &str) -> bool { - self.handler.should_handle_topic(topic) + Self { + handler, + dao: Mutex::new(dao), + } } - /// Handle incomming publish messages from the cloud on any topics relevant - /// for this particular shadow. + /// Wait delta will subscribe if not already to Updatedelta and wait for changes /// - /// This function needs to be fed all relevant incoming MQTT payloads in - /// order for the shadow manager to work. - #[must_use] - pub fn handle_message( - &mut self, - topic: &str, - payload: &[u8], - ) -> Result<(S, Option), Error> { - let (topic, thing_name, shadow_name) = - Topic::from_str(topic).ok_or(Error::WrongShadowName)?; - - assert_eq!(thing_name, self.handler.mqtt.client_id()); - assert_eq!(topic.direction(), Direction::Incoming); - - if shadow_name != S::NAME { - return Err(Error::WrongShadowName); - } - - let mut state = self.dao.read()?; - - let delta = match topic { - Topic::GetAccepted => { - // The actions necessary to process the state document in the - // message body. - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(response, _)| { - if let Some(_) = response.state.delta { - debug!( - "[{:?}] Received delta state", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.handler.change_shadow_value( - &mut state, - response.state.delta.clone(), - Some(false), - )?; - } else if let Some(_) = response.state.reported { - self.handler.change_shadow_value( - &mut state, - response.state.reported, - None, - )?; - } - Ok(response.state.delta) - })? - } - Topic::GetRejected | Topic::UpdateRejected => { - // Respond to the error message in the message body. - if let Ok((error, _)) = serde_json_core::from_slice::(payload) { - if error.code == 404 && matches!(topic, Topic::GetRejected) { - debug!( - "[{:?}] Thing has no shadow document. Creating with defaults...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.report_shadow()?; - } else { - error!( - "{:?} request was rejected. code: {:?} message:'{:?}'", - if matches!(topic, Topic::GetRejected) { - "Get" - } else { - "Update" - }, - error.code, - error.message - ); - } - } - None + pub async fn wait_delta(&self) -> Result<(S, Option), Error> { + let mut dao = self.dao.lock().await; + + let mut state = match dao.read().await { + Ok(state) => state, + Err(_) => { + error!("Could not read state from flash writing default"); + let state = S::default(); + dao.write(&state).await?; + state } - Topic::UpdateDelta => { - // Update the device's state to match the desired state in the - // message body. - debug!( - "[{:?}] Received shadow delta event.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(delta, _)| { - if let Some(_) = delta.state { - debug!( - "[{:?}] Delta reports new desired value. Changing local value...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - } - self.handler.change_shadow_value( - &mut state, - delta.state.clone(), - Some(false), - )?; - Ok(delta.state) - })? - } - Topic::UpdateAccepted => { - // Confirm the updated data in the message body matches the - // device state. + }; - debug!( - "[{:?}] Finished updating reported shadow value.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); + // Drop the lock to avoid deadlock + drop(dao); - None - } - _ => None, - }; + let delta = self.handler.handle_delta().await?; // Something has changed as part of handling a message. Persist it // to NVM storage. - if delta.is_some() { - self.dao.write(&state)?; + if let Some(ref delta) = delta { + debug!( + "[{:?}] Delta reports new desired value. Changing local value...", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); + + state.apply_patch(delta.clone()); + + self.handler + .update_shadow(None, Some(state.clone().into())) + .await?; + + self.dao.lock().await.write(&state).await?; } Ok((state, delta)) } /// Get an immutable reference to the internal local state. - pub fn try_get(&mut self) -> Result { - self.dao.read() + pub async fn try_get(&self) -> Result { + self.dao.lock().await.read().await } /// Initiate a `GetShadow` request, updating the local state from the cloud. - pub fn get_shadow(&self) -> Result<(), Error> { - self.handler.get_shadow() + pub async fn get_shadow(&self) -> Result { + let delta_state = self.handler.get_shadow().await?; + + debug!("Persisting new state after get shadow request"); + let mut state = self.dao.lock().await.read().await.unwrap_or_default(); + if let Some(delta) = delta_state.delta { + state.apply_patch(delta.clone()); + self.dao.lock().await.write(&state).await?; + self.handler + .update_shadow(None, Some(state.clone().into())) + .await?; + } + + Ok(state) } - /// Initiate an `UpdateShadow` request, reporting the local state to the cloud. - pub fn report_shadow(&mut self) -> Result<(), Error> { - let mut state = self.dao.read()?; - self.handler - .change_shadow_value(&mut state, None, Some(false))?; + /// Report the state of the shadow. + pub async fn report(&self) -> Result<(), Error> { + let mut dao = self.dao.lock().await; + + let state = match dao.read().await { + Ok(state) => state, + Err(_) => { + error!("Could not read state from flash writing default"); + let state = S::default(); + dao.write(&state).await?; + state + } + }; + + // Drop the lock to avoid deadlock + drop(dao); + + self.handler.update_shadow(None, Some(state.into())).await?; Ok(()) } @@ -340,179 +475,95 @@ where /// and depending on whether the state update is rejected or accepted, it /// will automatically update the local version after response /// - /// The returned `bool` from the update closure will determine wether the + /// The returned `bool` from the update closure will determine whether the /// update is persisted using the `DAO`, or just updated in the cloud. This /// can be handy for activity or status field updates that are not relevant - /// to store persistant on the device, but are required to be part of the + /// to store persistent on the device, but are required to be part of the /// same cloud shadow. - pub fn update bool>(&mut self, f: F) -> Result<(), Error> { - let mut desired = S::PatchState::default(); - let mut state = self.dao.read()?; - let should_persist = f(&state, &mut desired); + pub async fn update(&self, f: F) -> Result<(), Error> { + let mut update = S::Reported::default(); + let mut state = self.dao.lock().await.read().await?; + f(&state, &mut update); - self.handler - .change_shadow_value(&mut state, Some(desired), Some(false))?; + let response = self.handler.update_shadow(None, Some(update)).await?; + + if let Some(delta) = response.delta { + state.apply_patch(delta.clone()); + + self.dao.lock().await.write(&state).await?; + } + + Ok(()) + } - if should_persist { - self.dao.write(&state)?; + /// Updating desired should only be done on user requests e.g. button press or similar. + /// State changes within the device should only change reported state. + pub async fn update_desired(&self, f: F) -> Result<(), Error> { + let mut update = S::Delta::default(); + f(&mut update); + + let response = self.handler.update_shadow(Some(update), None).await?; + + if let Some(delta) = response.delta { + let mut state = self.dao.lock().await.read().await?; + state.apply_patch(delta.clone()); + self.dao.lock().await.write(&state).await?; } Ok(()) } - pub fn delete_shadow(&mut self) -> Result<(), Error> { - self.handler.delete_shadow() + pub async fn delete_shadow(&self) -> Result<(), Error> { + self.handler.delete_shadow().await?; + self.dao.lock().await.write(&S::default()).await?; + Ok(()) } } -pub struct Shadow<'a, S: ShadowState, M: Mqtt> -where - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, -{ +pub struct Shadow<'a, 'm, S, M: RawMutex> { state: S, - handler: ShadowHandler<'a, M, S>, + handler: ShadowHandler<'a, 'm, M, S>, } -impl<'a, S, M> Shadow<'a, S, M> +impl<'a, 'm, S, M> Shadow<'a, 'm, S, M> where S: ShadowState, - M: Mqtt, - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, + M: RawMutex, { /// Instantiate a new non-persisted shadow - pub fn new(state: S, mqtt: &'a M, auto_subscribe: bool) -> Result { + pub fn new(state: S, mqtt: &'m embedded_mqtt::MqttClient<'a, M>) -> Self { let handler = ShadowHandler { mqtt, + subscription: Mutex::new(None), _shadow: PhantomData, }; - if auto_subscribe { - handler.subscribe()?; - } - Ok(Self { handler, state }) + Self { handler, state } } - /// Subscribes to all the topics required for keeping a shadow in sync - pub fn subscribe(&self) -> Result<(), Error> { - self.handler.subscribe() - } - - /// Unsubscribes from all the topics required for keeping a shadow in sync - pub fn unsubscribe(&self) -> Result<(), Error> { - self.handler.unsubscribe() - } - - /// Handle incomming publish messages from the cloud on any topics relevant + /// Handle incoming publish messages from the cloud on any topics relevant /// for this particular shadow. /// /// This function needs to be fed all relevant incoming MQTT payloads in /// order for the shadow manager to work. - #[must_use] - pub fn handle_message( - &mut self, - topic: &str, - payload: &[u8], - ) -> Result<(&S, Option), Error> { - let (topic, thing_name, shadow_name) = - Topic::from_str(topic).ok_or(Error::WrongShadowName)?; - - assert_eq!(thing_name, self.handler.mqtt.client_id()); - assert_eq!(topic.direction(), Direction::Incoming); - - if shadow_name != S::NAME { - return Err(Error::WrongShadowName); - } - - let delta = match topic { - Topic::GetAccepted => { - // The actions necessary to process the state document in the - // message body. - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(response, _)| { - if let Some(_) = response.state.delta { - debug!( - "[{:?}] Received delta state", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.handler.change_shadow_value( - &mut self.state, - response.state.delta.clone(), - Some(false), - )?; - } else if let Some(_) = response.state.reported { - self.handler.change_shadow_value( - &mut self.state, - response.state.reported, - None, - )?; - } - Ok(response.state.delta) - })? - } - Topic::GetRejected | Topic::UpdateRejected => { - // Respond to the error message in the message body. - if let Ok((error, _)) = serde_json_core::from_slice::(payload) { - if error.code == 404 && matches!(topic, Topic::GetRejected) { - debug!( - "[{:?}] Thing has no shadow document. Creating with defaults...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); - self.report_shadow()?; - } else { - error!( - "{:?} request was rejected. code: {:?} message:'{:?}'", - if matches!(topic, Topic::GetRejected) { - "Get" - } else { - "Update" - }, - error.code, - error.message - ); - } - } - None - } - Topic::UpdateDelta => { - // Update the device's state to match the desired state in the - // message body. - debug!( - "[{:?}] Received shadow delta event.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); + pub async fn wait_delta(&mut self) -> Result<(&S, Option), Error> { + let delta = self.handler.handle_delta().await?; - serde_json_core::from_slice::>(payload) - .map_err(|_| Error::InvalidPayload) - .and_then(|(delta, _)| { - if let Some(_) = delta.state { - debug!( - "[{:?}] Delta reports new desired value. Changing local value...", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), - ); - } - self.handler.change_shadow_value( - &mut self.state, - delta.state.clone(), - Some(false), - )?; - Ok(delta.state) - })? - } - Topic::UpdateAccepted => { - // Confirm the updated data in the message body matches the - // device state. + // Something has changed as part of handling a message. Persist it + // to NVM storage. + if let Some(ref delta) = delta { + debug!( + "[{:?}] Delta reports new desired value. Changing local value...", + S::NAME.unwrap_or(CLASSIC_SHADOW), + ); - debug!( - "[{:?}] Finished updating reported shadow value.", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW) - ); + self.state.apply_patch(delta.clone()); - None - } - _ => None, - }; + self.handler + .update_shadow(None, Some(self.state.clone().into())) + .await?; + } - Ok((self.get(), delta)) + Ok((&self.state, delta)) } /// Get an immutable reference to the internal local state. @@ -520,10 +571,11 @@ where &self.state } - /// Initiate an `UpdateShadow` request, reporting the local state to the cloud. - pub fn report_shadow(&mut self) -> Result<(), Error> { + /// Report the state of the shadow. + pub async fn report(&mut self) -> Result<(), Error> { self.handler - .change_shadow_value(&mut self.state, None, Some(false))?; + .update_shadow(None, Some(self.state.clone().into())) + .await?; Ok(()) } @@ -532,48 +584,59 @@ where /// This function will update the desired state of the shadow in the cloud, /// and depending on whether the state update is rejected or accepted, it /// will automatically update the local version after response - pub fn update(&mut self, f: F) -> Result<(), Error> { - let mut desired = S::PatchState::default(); - f(&self.state, &mut desired); + pub async fn update(&mut self, f: F) -> Result<(), Error> { + let mut update = S::Reported::default(); + f(&self.state, &mut update); - self.handler - .change_shadow_value(&mut self.state, Some(desired), Some(false))?; + let response = self.handler.update_shadow(None, Some(update)).await?; + + if let Some(delta) = response.delta { + self.state.apply_patch(delta.clone()); + } Ok(()) } /// Initiate a `GetShadow` request, updating the local state from the cloud. - pub fn get_shadow(&self) -> Result<(), Error> { - self.handler.get_shadow() + pub async fn get_shadow(&mut self) -> Result<&S, Error> { + let delta_state = self.handler.get_shadow().await?; + + debug!("Persisting new state after get shadow request"); + if let Some(delta) = delta_state.delta { + self.state.apply_patch(delta.clone()); + self.handler + .update_shadow(None, Some(self.state.clone().into())) + .await?; + } + + Ok(&self.state) } - pub fn delete_shadow(&mut self) -> Result<(), Error> { - self.handler.delete_shadow() + pub async fn delete_shadow(&mut self) -> Result<(), Error> { + self.handler.delete_shadow().await } } -impl<'a, S, M> core::fmt::Debug for Shadow<'a, S, M> +impl core::fmt::Debug for Shadow<'_, '_, S, M> where S: ShadowState + core::fmt::Debug, - M: Mqtt, - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, + M: RawMutex, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!( f, "[{:?}] = {:?}", - S::NAME.unwrap_or_else(|| CLASSIC_SHADOW), + S::NAME.unwrap_or(CLASSIC_SHADOW), self.get() ) } } #[cfg(feature = "defmt")] -impl<'a, S, M> defmt::Format for Shadow<'a, S, M> +impl<'a, 'm, S, M> defmt::Format for Shadow<'a, 'm, S, M> where S: ShadowState + defmt::Format, - M: Mqtt, - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, + M: RawMutex, { fn format(&self, fmt: defmt::Formatter) { defmt::write!( @@ -585,19 +648,59 @@ where } } -impl<'a, S, M> Drop for Shadow<'a, S, M> -where - S: ShadowState, - M: Mqtt, - [(); S::MAX_PAYLOAD_SIZE + PARTIAL_REQUEST_OVERHEAD]:, -{ - fn drop(&mut self) { - self.unsubscribe().ok(); +#[cfg(test)] +mod tests { + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] + struct TestDelta { + field: heapless::String<20>, + } + + #[test] + fn test_from_slice_escaped() { + let delta_message = b"{\"field\":\"\\\\HELLO WORLD\"}"; // FROM 4 backslashes in my string is saved in json as 2 backslashes which will be deserialized to 1 backslash + let mut buf = [0u8; 64]; + + let (delta, _) = serde_json_core::from_slice_escaped::(delta_message, &mut buf) + .expect("Failed to deserialize"); + + println!("{}", delta.field); + + assert_eq!(delta.field.as_str(), "\\HELLO WORLD"); + } + + #[test] + fn test_to_slice_escaping() { + // Create a struct with a backslash in the string + let mut test_data = TestDelta::default(); + test_data.field.push_str("\\HELLO WORLD").unwrap(); + + let mut output = [0u8; 128]; + let bytes_written = + serde_json_core::to_slice(&test_data, &mut output).expect("Failed to serialize"); + + let serialized = &output[..bytes_written]; + let json_str = core::str::from_utf8(serialized).unwrap(); + println!("Serialized JSON: {}", json_str); + + // The JSON should contain \\ (escaped backslash) + assert!( + json_str.contains("\\\\"), + "JSON should contain escaped backslash" + ); + assert_eq!(json_str, r#"{"field":"\\HELLO WORLD"}"#); + + // Now test round-trip: deserialize it back + let mut buf = [0u8; 64]; + let (deserialized, _) = + serde_json_core::from_slice_escaped::(serialized, &mut buf) + .expect("Failed to deserialize"); + + assert_eq!(deserialized.field.as_str(), "\\HELLO WORLD"); } } -// #[cfg(test)] -// mod tests { // use super::*; // use crate as rustot; // use crate::test::MockMqtt; diff --git a/src/shadows/shadow_diff/_impl.rs b/src/shadows/shadow_diff/_impl.rs deleted file mode 100644 index 8854f38..0000000 --- a/src/shadows/shadow_diff/_impl.rs +++ /dev/null @@ -1,55 +0,0 @@ -use serde::{de::DeserializeOwned, Serialize}; - -use crate::shadows::data_types::Patch; - -use super::ShadowPatch; - -macro_rules! impl_shadow_patch { - ($($ident: ty),+) => { - $( - impl ShadowPatch for $ident { - type PatchState = $ident; - - fn apply_patch(&mut self, opt: Self::PatchState) { - *self = opt; - } - } - )+ - }; -} - -// Rust primitive types: https://doc.rust-lang.org/reference/types.html -impl_shadow_patch!(bool); -impl_shadow_patch!(u8, u16, u32, u64, u128, usize); -impl_shadow_patch!(i8, i16, i32, i64, i128, isize); -impl_shadow_patch!(f32, f64); -impl_shadow_patch!(char); - -impl ShadowPatch for Option { - type PatchState = Patch; - - fn apply_patch(&mut self, opt: Self::PatchState) { - if let Patch::Set(v) = opt { - *self = Some(v); - } else { - *self = None; - } - } -} - -// Heapless stuff -impl ShadowPatch for heapless::String { - type PatchState = heapless::String; - - fn apply_patch(&mut self, opt: Self::PatchState) { - *self = opt; - } -} - -impl ShadowPatch for heapless::Vec { - type PatchState = heapless::Vec; - - fn apply_patch(&mut self, opt: Self::PatchState) { - *self = opt; - } -} diff --git a/src/shadows/shadow_diff/mod.rs b/src/shadows/shadow_diff/mod.rs deleted file mode 100644 index 3de0c0f..0000000 --- a/src/shadows/shadow_diff/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod _impl; - -use serde::{de::DeserializeOwned, Serialize}; - -pub trait ShadowPatch: Serialize { - type PatchState: Serialize + DeserializeOwned + Default + Clone; - - fn apply_patch(&mut self, opt: Self::PatchState); -} diff --git a/src/shadows/topics.rs b/src/shadows/topics.rs index c73e35a..630d002 100644 --- a/src/shadows/topics.rs +++ b/src/shadows/topics.rs @@ -2,8 +2,8 @@ use core::fmt::Write; +use embedded_mqtt::QoS; use heapless::String; -use mqttrust::{Mqtt, QoS, SubscribeTopic}; use crate::jobs::MAX_THING_NAME_LEN; @@ -33,19 +33,22 @@ pub enum Topic { UpdateRejected, DeleteAccepted, DeleteRejected, + Any, } impl Topic { - const PREFIX: &'static str = "$aws/things"; + const PREFIX: &'static str = "things"; const SHADOW: &'static str = "shadow"; - pub fn from_str(s: &str) -> Option<(Self, &str, Option<&str>)> { + pub fn from_str<'a>(prefix: &str, s: &'a str) -> Option<(Self, &'a str, Option<&'a str>)> { let tt = s.splitn(9, '/').collect::>(); - match (tt.get(0), tt.get(1), tt.get(2), tt.get(3)) { - (Some(&"$aws"), Some(&"things"), Some(thing_name), Some(&Self::SHADOW)) => { + match (tt.first(), tt.get(1), tt.get(2), tt.get(3)) { + (Some(tt_prefix), Some(tt_things), Some(thing_name), Some(&Self::SHADOW)) + if *tt_prefix == prefix && *tt_things == Self::PREFIX => + { // This is a shadow topic, now figure out which one. let (shadow_name, next_index) = if let Some(&"name") = tt.get(4) { - (tt.get(5).map(|s| *s), 6) + (tt.get(5).copied(), 6) } else { (None, 4) }; @@ -92,6 +95,7 @@ impl Topic { pub fn format( &self, + prefix: &str, thing_name: &str, shadow_name: Option<&'static str>, ) -> Result, Error> { @@ -100,7 +104,8 @@ impl Topic { let mut topic_path = String::new(); match self { Self::Get => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/get", + "{}/{}/{}/{}{}{}/get", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -108,7 +113,8 @@ impl Topic { shadow_name )), Self::Update => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/update", + "{}/{}/{}/{}{}{}/update", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -116,7 +122,8 @@ impl Topic { shadow_name )), Self::Delete => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/delete", + "{}/{}/{}/{}{}{}/delete", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -125,7 +132,8 @@ impl Topic { )), Self::GetAccepted => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/get/accepted", + "{}/{}/{}/{}{}{}/get/accepted", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -133,7 +141,8 @@ impl Topic { shadow_name )), Self::GetRejected => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/get/rejected", + "{}/{}/{}/{}{}{}/get/rejected", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -141,7 +150,8 @@ impl Topic { shadow_name )), Self::UpdateDelta => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/update/delta", + "{}/{}/{}/{}{}{}/update/delta", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -149,7 +159,8 @@ impl Topic { shadow_name )), Self::UpdateAccepted => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/update/accepted", + "{}/{}/{}/{}{}{}/update/accepted", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -157,7 +168,8 @@ impl Topic { shadow_name )), Self::UpdateDocuments => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/update/documents", + "{}/{}/{}/{}{}{}/update/documents", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -165,7 +177,8 @@ impl Topic { shadow_name )), Self::UpdateRejected => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/update/rejected", + "{}/{}/{}/{}{}{}/update/rejected", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -173,7 +186,8 @@ impl Topic { shadow_name )), Self::DeleteAccepted => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/delete/accepted", + "{}/{}/{}/{}{}{}/delete/accepted", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -181,7 +195,17 @@ impl Topic { shadow_name )), Self::DeleteRejected => topic_path.write_fmt(format_args!( - "{}/{}/{}{}{}/delete/rejected", + "{}/{}/{}/{}{}{}/delete/rejected", + prefix, + Self::PREFIX, + thing_name, + Self::SHADOW, + name_prefix, + shadow_name + )), + Self::Any => topic_path.write_fmt(format_args!( + "{}/{}/{}/{}{}{}/#", + prefix, Self::PREFIX, thing_name, Self::SHADOW, @@ -223,6 +247,7 @@ impl Subscribe { pub fn topics( self, + prefix: &str, thing_name: &str, shadow_name: Option<&'static str>, ) -> Result, QoS), N>, Error> { @@ -230,32 +255,9 @@ impl Subscribe { self.topics .iter() - .map(|(topic, qos)| Ok((Topic::from(*topic).format(thing_name, shadow_name)?, *qos))) + .map(|(topic, qos)| Ok(((*topic).format(prefix, thing_name, shadow_name)?, *qos))) .collect() } - - pub fn send(self, mqtt: &M, shadow_name: Option<&'static str>) -> Result<(), Error> { - if self.topics.is_empty() { - return Ok(()); - } - - let topic_paths = self.topics(mqtt.client_id(), shadow_name)?; - - let topics: heapless::Vec<_, N> = topic_paths - .iter() - .map(|(s, qos)| SubscribeTopic { - topic_path: s.as_str(), - qos: *qos, - }) - .collect(); - - debug!("Subscribing!"); - - for t in topics.chunks(5) { - mqtt.subscribe(t)?; - } - Ok(()) - } } #[derive(Default)] @@ -285,6 +287,7 @@ impl Unsubscribe { pub fn topics( self, + prefix: &str, thing_name: &str, shadow_name: Option<&'static str>, ) -> Result, N>, Error> { @@ -292,22 +295,7 @@ impl Unsubscribe { self.topics .iter() - .map(|topic| Topic::from(*topic).format(thing_name, shadow_name)) + .map(|topic| (*topic).format(prefix, thing_name, shadow_name)) .collect() } - - pub fn send(self, mqtt: &M, shadow_name: Option<&'static str>) -> Result<(), Error> { - if self.topics.is_empty() { - return Ok(()); - } - - let topic_paths = self.topics(mqtt.client_id(), shadow_name)?; - let topics: heapless::Vec<_, N> = topic_paths.iter().map(|s| s.as_str()).collect(); - - for t in topics.chunks(5) { - mqtt.unsubscribe(t)?; - } - - Ok(()) - } } diff --git a/src/test/mod.rs b/src/test/mod.rs deleted file mode 100644 index e27c028..0000000 --- a/src/test/mod.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::{cell::RefCell, collections::VecDeque}; - -use mqttrust::{encoding::v4::encode_slice, Mqtt, MqttError, Packet}; - -/// -/// Mock Mqtt client used for unit tests. Implements `mqttrust::Mqtt` trait. -/// -pub struct MockMqtt { - pub tx: RefCell>>, - publish_fail: bool, -} - -impl MockMqtt { - pub fn new() -> Self { - Self { - tx: RefCell::new(VecDeque::new()), - publish_fail: false, - } - } - - pub fn publish_fail(&mut self) { - self.publish_fail = true; - } -} - -impl Mqtt for MockMqtt { - fn send(&self, packet: Packet<'_>) -> Result<(), MqttError> { - let v = &mut [0u8; 1024]; - - let len = encode_slice(&packet, v).map_err(|_| MqttError::Full)?; - let packet = v[..len].iter().cloned().collect(); - self.tx.borrow_mut().push_back(packet); - - Ok(()) - } - - fn client_id(&self) -> &str { - "test_client" - } -} diff --git a/tests/README.md b/tests/README.md index da1c761..32d9cff 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,14 +1,87 @@ -This folder contains a number of examples that shows how to use this crate. +# AWS IoT Rust Examples -
+This repository contains examples demonstrating how to use the AWS IoT SDK for Rust. These examples are also integrated into our CI pipeline as integration tests. + +## Examples ### AWS IoT Fleet Provisioning (`provisioning.rs`) -
-This example can be run by `RUST_LOG=trace AWS_HOSTNAME=xxxxxxxx-ats.iot.eu-west-1.amazonaws.com cargo r --example provisioning --features log`, assuming you have an `examples/secrets/claim_identity.pfx` file with the claiming credentials. +This example demonstrates how to use the AWS IoT Fleet Provisioning service to provision a device. + +**Requirements:** + +* An AWS account with AWS IoT Core and AWS IoT Fleet Provisioning configured. +* A device certificate and private key. You can generate these using OpenSSL or your preferred method. +* A provisioning template configured in your AWS account. + +**To run the example:** + +1. **Create a PKCS #12 (.pfx) identity file:** + If you haven't already, create a PKCS #12 (.pfx) file containing your device certificate and private key. You can use OpenSSL for this: + + ```bash + openssl pkcs12 -export -out claim_identity.pfx -inkey private.pem.key -in certificate.pem.crt -certfile root-ca.pem + ``` + Replace `private.pem.key`, `certificate.pem.crt`, and `root-ca.pem` with your actual file names. + +2. **Store the Identity File:** + Place the `claim_identity.pfx` file in the `tests/secrets/` directory. + +3. **Set Environment Variables:** + Set the following environment variables: + * `IDENTITY_PASSWORD`: The password you set for the `claim_identity.pfx` file. + * `AWS_HOSTNAME`: Your AWS IoT endpoint. You can find this in the AWS IoT console. + +4. **Run the Test:** + + ```bash + cargo test --test provisioning --features "log,std" + ``` + +### AWS IoT OTA (`ota_mqtt.rs`) + +This example demonstrates how to perform an over-the-air (OTA) firmware update using AWS IoT Jobs. + +**Requirements:** + +* An AWS account with AWS IoT Core and AWS IoT Jobs configured. +* A device certificate and private key. +* A PKCS #12 (.pfx) file containing the device certificate and private key (see previous example for creation instructions). +* An OTA update job created in your AWS account. + +**To run the example:** + +1. **Create an OTA Job:** Create an OTA update job. You can find instructions on how to do this in the AWS IoT documentation or refer to the `scripts/create_ota.sh` script for inspiration. +2. **Store the Identity File:** Ensure the `identity.pfx` file (containing your device certificate and private key) is located in the `tests/secrets/` directory. +3. **Set Environment Variables:** + * `IDENTITY_PASSWORD`: The password for your `identity.pfx` file. + * `AWS_HOSTNAME`: Your AWS IoT endpoint. + +4. **Run the Test:** + + ```bash + cargo test --test ota_mqtt --features "log,std" + ``` + +### AWS IoT Shadows (`shadows.rs`) + +This example demonstrates how to interact with AWS IoT device shadows. Device shadows allow you to store and retrieve the latest state of your devices even when they are offline. + +**Requirements:** + +* An AWS account with AWS IoT Core and AWS IoT Device Shadows configured. +* A device certificate and private key. +* A PKCS #12 (.pfx) file containing the device certificate and private key (see previous examples for creation instructions). + +**To run the example:** + +1. **Store the Identity File:** Ensure the `claim_identity.pfx` file (containing your device certificate and private key) is in the `tests/secrets/` directory. +2. **Set Environment Variables:** + * `IDENTITY_PASSWORD`: The password for your `claim_identity.pfx` file. + * `AWS_HOSTNAME`: Your AWS IoT endpoint. -pfx identity files can be created from a set of device certificate and private key using OpenSSL as: `openssl pkcs12 -export -out claim_identity.pfx -inkey private.pem.key -in certificate.pem.crt -certfile root-ca.pem` -
-
+3. **Run the Test:** -### AWS IoT OTA (`ota.rs`) \ No newline at end of file + ```bash + cargo test --test shadows --features "log,std" + ``` diff --git a/tests/common/clock.rs b/tests/common/clock.rs deleted file mode 100644 index c486cc1..0000000 --- a/tests/common/clock.rs +++ /dev/null @@ -1,55 +0,0 @@ -use std::time::{SystemTime, UNIX_EPOCH}; - -pub struct SysClock { - start_time: u32, - end_time: Option>, -} - -impl SysClock { - pub fn new() -> Self { - Self { - start_time: Self::epoch(), - end_time: None, - } - } - - pub fn epoch() -> u32 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_millis() as u32 - } - - pub fn now(&self) -> u32 { - Self::epoch() - self.start_time - } -} - -impl fugit_timer::Timer<1000> for SysClock { - type Error = std::convert::Infallible; - - fn now(&mut self) -> fugit_timer::TimerInstantU32<1000> { - fugit_timer::TimerInstantU32::from_ticks(SysClock::now(self)) - } - - fn start(&mut self, duration: fugit_timer::TimerDurationU32<1000>) -> Result<(), Self::Error> { - let now = self.now(); - self.end_time.replace(now + duration); - Ok(()) - } - - fn cancel(&mut self) -> Result<(), Self::Error> { - self.end_time.take(); - Ok(()) - } - - fn wait(&mut self) -> nb::Result<(), Self::Error> { - match self.end_time.map(|end| end <= self.now()) { - Some(true) => { - self.end_time.take(); - Ok(()) - } - _ => Err(nb::Error::WouldBlock), - } - } -} diff --git a/tests/common/credentials.rs b/tests/common/credentials.rs index 6b9c7e3..b1bb872 100644 --- a/tests/common/credentials.rs +++ b/tests/common/credentials.rs @@ -4,6 +4,7 @@ use native_tls::{Certificate, Identity}; use p256::ecdsa::SigningKey; use pkcs8::DecodePrivateKey; +#[allow(dead_code)] pub fn identity() -> (&'static str, Identity) { let thing_name = option_env!("THING_NAME").unwrap_or_else(|| "rustot-test"); let pw = env::var("IDENTITY_PASSWORD").unwrap_or_default(); @@ -13,6 +14,7 @@ pub fn identity() -> (&'static str, Identity) { ) } +#[allow(dead_code)] pub fn claim_identity() -> (&'static str, Identity) { let thing_name = option_env!("THING_NAME").unwrap_or_else(|| "rustot-provision"); let pw = env::var("IDENTITY_PASSWORD").unwrap_or_default(); @@ -27,6 +29,7 @@ pub fn root_ca() -> Certificate { Certificate::from_pem(include_bytes!("../secrets/root-ca.pem")).unwrap() } +#[allow(dead_code)] pub fn signing_key() -> SigningKey { let pw = env::var("IDENTITY_PASSWORD").unwrap_or_default(); SigningKey::from_pkcs8_encrypted_pem(include_str!("../secrets/sign_private.pem"), pw).unwrap() diff --git a/tests/common/file_handler.rs b/tests/common/file_handler.rs index 9cae2e6..d4569f0 100644 --- a/tests/common/file_handler.rs +++ b/tests/common/file_handler.rs @@ -1,91 +1,156 @@ -use rustot::ota::pal::{OtaPal, OtaPalError, PalImageState}; -use std::fs::File; -use std::io::{Cursor, Write}; +use core::ops::Deref; +use embedded_storage_async::nor_flash::{ErrorType, NorFlash, ReadNorFlash}; +use rustot::ota::{ + encoding::json, + pal::{OtaPal, OtaPalError, PalImageState}, +}; +use sha2::{Digest, Sha256}; +use std::{ + convert::Infallible, + io::{Cursor, Write}, +}; + +#[derive(Debug, PartialEq, Eq)] +pub enum State { + Swap, + Boot, +} + +pub struct BlockFile { + filebuf: Cursor>, +} + +impl NorFlash for BlockFile { + const WRITE_SIZE: usize = 1; + + const ERASE_SIZE: usize = 1; + + async fn erase(&mut self, _from: u32, _to: u32) -> Result<(), Self::Error> { + Ok(()) + } + + async fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> { + self.filebuf.set_position(offset as u64); + self.filebuf.write_all(bytes).unwrap(); + Ok(()) + } +} + +impl ReadNorFlash for BlockFile { + const READ_SIZE: usize = 1; + + async fn read(&mut self, _offset: u32, _bytes: &mut [u8]) -> Result<(), Self::Error> { + todo!() + } + + fn capacity(&self) -> usize { + self.filebuf.get_ref().capacity() + } +} + +impl ErrorType for BlockFile { + type Error = Infallible; +} pub struct FileHandler { - filebuf: Option>>, + filebuf: Option, + compare_file_path: String, + pub plateform_state: State, } impl FileHandler { - pub fn new() -> Self { - FileHandler { filebuf: None } + #[allow(dead_code)] + pub fn new(compare_file_path: String) -> Self { + FileHandler { + filebuf: None, + compare_file_path, + plateform_state: State::Boot, + } } } impl OtaPal for FileHandler { - type Error = (); + type BlockWriter = BlockFile; - fn abort( + async fn abort( &mut self, _file: &rustot::ota::encoding::FileContext, - ) -> Result<(), OtaPalError> { + ) -> Result<(), OtaPalError> { Ok(()) } - fn create_file_for_rx( + async fn create_file_for_rx( &mut self, file: &rustot::ota::encoding::FileContext, - ) -> Result<(), OtaPalError> { - self.filebuf = Some(Cursor::new(Vec::with_capacity(file.filesize))); - Ok(()) + ) -> Result<&mut Self::BlockWriter, OtaPalError> { + Ok(self.filebuf.get_or_insert(BlockFile { + filebuf: Cursor::new(Vec::with_capacity(file.filesize)), + })) } - fn get_platform_image_state(&mut self) -> Result> { - Ok(PalImageState::Valid) + async fn get_platform_image_state(&mut self) -> Result { + Ok(match self.plateform_state { + State::Swap => PalImageState::PendingCommit, + State::Boot => PalImageState::Valid, + }) } - fn set_platform_image_state( + async fn set_platform_image_state( &mut self, - _image_state: rustot::ota::pal::ImageState<()>, - ) -> Result<(), OtaPalError> { + image_state: rustot::ota::pal::ImageState, + ) -> Result<(), OtaPalError> { + if matches!(image_state, rustot::ota::pal::ImageState::Accepted) { + self.plateform_state = State::Boot; + } + Ok(()) } - fn reset_device(&mut self) -> Result<(), OtaPalError> { + async fn reset_device(&mut self) -> Result<(), OtaPalError> { Ok(()) } - fn close_file( + async fn close_file( &mut self, file: &rustot::ota::encoding::FileContext, - ) -> Result<(), OtaPalError> { + ) -> Result<(), OtaPalError> { if let Some(ref mut buf) = &mut self.filebuf { log::debug!( "Closing completed file. Len: {}/{} -> {}", - buf.get_ref().len(), + buf.filebuf.get_ref().len(), file.filesize, file.filepath.as_str() ); - let mut file = - File::create(file.filepath.as_str()).map_err(|_| OtaPalError::FileWriteFailed)?; - file.write_all(buf.get_ref()) - .map_err(|_| OtaPalError::FileWriteFailed)?; - Ok(()) - } else { - Err(OtaPalError::BadFileHandle) - } - } + let expected_data = std::fs::read(self.compare_file_path.as_str()).unwrap(); + let mut expected_hasher = ::new(); + expected_hasher.update(&expected_data); + let expected_hash = expected_hasher.finalize(); - fn write_block( - &mut self, - _file: &rustot::ota::encoding::FileContext, - block_offset: usize, - block_payload: &[u8], - ) -> Result> { - if let Some(ref mut buf) = &mut self.filebuf { - buf.set_position(block_offset as u64); - buf.write(block_payload) - .map_err(|_e| OtaPalError::FileWriteFailed)?; - Ok(block_payload.len()) + log::info!( + "Comparing {:?} with {:?}", + self.compare_file_path, + file.filepath.as_str() + ); + assert_eq!(buf.filebuf.get_ref().len(), file.filesize); + + let mut hasher = ::new(); + hasher.update(buf.filebuf.get_ref()); + assert_eq!(hasher.finalize().deref(), expected_hash.deref()); + + // Check file signature + let signature = match file.signature.as_ref() { + Some(json::Signature::Sha256Ecdsa(ref s)) => s.as_str(), + sig => panic!("Unexpected signature format! {:?}", sig), + }; + + assert_eq!(signature, "This is my custom signature\\n"); + + self.plateform_state = State::Swap; + + Ok(()) } else { Err(OtaPalError::BadFileHandle) } } - - fn get_active_firmware_version( - &self, - ) -> Result> { - Ok(rustot::ota::pal::Version::new(0, 1, 0)) - } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index cd087b0..fc87c3f 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,4 +1,3 @@ -pub mod clock; pub mod credentials; pub mod file_handler; pub mod network; diff --git a/tests/common/network.rs b/tests/common/network.rs index 968a093..d693822 100644 --- a/tests/common/network.rs +++ b/tests/common/network.rs @@ -1,252 +1,146 @@ -use embedded_nal::{AddrType, Dns, IpAddr, SocketAddr, TcpClientStack}; -use native_tls::{MidHandshakeTlsStream, TlsConnector, TlsStream}; -use std::io::{Read, Write}; -use std::marker::PhantomData; -use std::net::TcpStream; - -use dns_lookup::{lookup_addr, lookup_host}; - -/// An std::io::Error compatible error type returned when an operation is requested in the wrong -/// sequence (where the "right" is create a socket, connect, any receive/send, and possibly close). -#[derive(Debug)] -struct OutOfOrder; - -impl std::fmt::Display for OutOfOrder { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Out of order operations requested") - } -} +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -impl std::error::Error for OutOfOrder {} +use ::native_tls::Identity; +use embedded_io_adapters::tokio_1::FromTokio; +use embedded_nal_async::{AddrType, Dns, TcpConnect}; +use tokio_native_tls::native_tls; -impl Into> for OutOfOrder { - fn into(self) -> std::io::Result { - Err(std::io::Error::new( - std::io::ErrorKind::NotConnected, - OutOfOrder, - )) - } -} +use super::credentials; -pub struct Network { - tls_connector: Option<(TlsConnector, String)>, - _sec: PhantomData, -} +#[derive(Debug, Clone, Copy)] +pub struct Network; -impl Network> { - pub fn new_tls(tls_connector: TlsConnector, hostname: String) -> Self { - Self { - tls_connector: Some((tls_connector, hostname)), - _sec: PhantomData, - } +impl Network { + #[allow(dead_code)] + pub const fn new() -> Self { + Self } } -impl Network { - pub fn new() -> Self { - Self { - tls_connector: None, - _sec: PhantomData, - } - } -} +impl TcpConnect for Network { + type Error = std::io::Error; -pub(crate) fn to_nb(e: std::io::Error) -> nb::Error { - use std::io::ErrorKind::{TimedOut, WouldBlock}; - match e.kind() { - WouldBlock | TimedOut => nb::Error::WouldBlock, - _ => e.into(), - } -} + type Connection<'a> + = FromTokio + where + Self: 'a; -pub enum TlsState { - MidHandshake(MidHandshakeTlsStream), - Connected(T), + async fn connect<'a>( + &'a self, + remote: SocketAddr, + ) -> Result, Self::Error> { + let stream = tokio::net::TcpStream::connect(format!("{}", remote)).await?; + Ok(FromTokio::new(stream)) + } } -pub struct TcpSocket { - pub stream: Option>, -} +impl Dns for Network { + type Error = std::io::Error; -impl TcpSocket { - pub fn new() -> Self { - TcpSocket { stream: None } + async fn get_host_by_name( + &self, + host: &str, + addr_type: AddrType, + ) -> Result { + for ip in tokio::net::lookup_host(format!("{}:0", host)).await? { + match (&addr_type, ip) { + (AddrType::IPv4 | AddrType::Either, SocketAddr::V4(ip)) => { + return Ok(IpAddr::V4(Ipv4Addr::from(ip.ip().octets()))) + } + (AddrType::IPv6 | AddrType::Either, SocketAddr::V6(ip)) => { + return Ok(IpAddr::V6(Ipv6Addr::from(ip.ip().octets()))) + } + (_, _) => {} + } + } + Err(std::io::Error::new( + std::io::ErrorKind::AddrNotAvailable, + "", + )) } - pub fn get_running(&mut self) -> std::io::Result<&mut T> { - match self.stream { - Some(TlsState::Connected(ref mut s)) => Ok(s), - _ => OutOfOrder.into(), - } + async fn get_host_by_address( + &self, + _addr: IpAddr, + _result: &mut [u8], + ) -> Result { + unimplemented!() } } -impl Dns for Network { - type Error = (); +pub struct TlsNetwork { + identity: Identity, + domain: String, +} - fn get_host_by_address( - &mut self, - ip_addr: IpAddr, - ) -> nb::Result, Self::Error> { - let ip: std::net::IpAddr = format!("{}", ip_addr).parse().unwrap(); - let host = lookup_addr(&ip).unwrap(); - Ok(heapless::String::from(host.as_str())) - } - fn get_host_by_name( - &mut self, - hostname: &str, - _addr_type: AddrType, - ) -> nb::Result { - let ips: Vec = lookup_host(hostname).unwrap(); - let ip = ips - .iter() - .find(|s| matches!(s, std::net::IpAddr::V4(_))) - .unwrap(); - format!("{}", ip).parse().map_err(|_| nb::Error::Other(())) +impl TlsNetwork { + pub const fn new(domain: String, identity: Identity) -> Self { + Self { identity, domain } } } -impl TcpClientStack for Network> { +impl TcpConnect for TlsNetwork { type Error = std::io::Error; - type TcpSocket = TcpSocket>; - fn socket(&mut self) -> Result { - Ok(TcpSocket::new()) - } + type Connection<'a> + = FromTokio> + where + Self: 'a; - fn receive( - &mut self, - network: &mut Self::TcpSocket, - buf: &mut [u8], - ) -> nb::Result { - let socket = network.get_running()?; - socket.read(buf).map_err(to_nb) - } - - fn send( - &mut self, - network: &mut Self::TcpSocket, - buf: &[u8], - ) -> nb::Result { - let socket = network.get_running()?; - socket.write(buf).map_err(|e| { - if !matches!( - e.kind(), - std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut - ) { - log::error!("{:?}", e); - } - to_nb(e) - }) - } - - fn is_connected(&mut self, network: &Self::TcpSocket) -> Result { - Ok(matches!(network.stream, Some(TlsState::Connected(_)))) - } - - fn connect( - &mut self, - network: &mut Self::TcpSocket, + async fn connect<'a>( + &'a self, remote: SocketAddr, - ) -> nb::Result<(), Self::Error> { - let tls_stream = match network.stream.take() { - None => { - let soc = TcpStream::connect(remote.to_string())?; - soc.set_nonblocking(true)?; - - let (connector, hostname) = self.tls_connector.as_ref().unwrap(); - - let mut tls_stream = connector.connect(hostname, soc).map_err(|e| match e { - native_tls::HandshakeError::Failure(_) => nb::Error::Other( - std::io::Error::new(std::io::ErrorKind::Other, "Failed TLS handshake"), - ), - native_tls::HandshakeError::WouldBlock(h) => { - network.stream.replace(TlsState::MidHandshake(h)); - nb::Error::WouldBlock - } - })?; - tls_stream.get_mut().set_nonblocking(true)?; - tls_stream - } - Some(TlsState::MidHandshake(h)) => { - let mut tls_stream = h.handshake().map_err(|e| match e { - native_tls::HandshakeError::Failure(_) => nb::Error::Other( - std::io::Error::new(std::io::ErrorKind::Other, "Failed TLS handshake"), - ), - native_tls::HandshakeError::WouldBlock(h) => { - network.stream.replace(TlsState::MidHandshake(h)); - nb::Error::WouldBlock - } - })?; - tls_stream.get_mut().set_nonblocking(true)?; - tls_stream - } - Some(TlsState::Connected(_)) => return Ok(()), - }; - - network.stream.replace(TlsState::Connected(tls_stream)); - - Ok(()) - } - - fn close(&mut self, _network: Self::TcpSocket) -> Result<(), Self::Error> { - // No-op: Socket gets closed when it is freed - // - // Could wrap it in an Option, but really that'll only make things messier; users will - // probably drop the socket anyway after closing, and can't expect it to be usable with - // this API. - Ok(()) + ) -> Result, Self::Error> { + log::info!("Connecting to {:?}", remote); + let connector = tokio_native_tls::TlsConnector::from( + native_tls::TlsConnector::builder() + .identity(self.identity.clone()) + .add_root_certificate(credentials::root_ca()) + .build() + .unwrap(), + ); + let stream = tokio::net::TcpStream::connect(format!("{}", remote)).await?; + let tls_stream = connector + .connect(self.domain.as_str(), stream) + .await + .unwrap(); + Ok(FromTokio::new(tls_stream)) } } -impl TcpClientStack for Network { +impl Dns for TlsNetwork { type Error = std::io::Error; - type TcpSocket = TcpSocket; - fn socket(&mut self) -> Result { - Ok(TcpSocket::new()) - } - - fn receive( - &mut self, - network: &mut Self::TcpSocket, - buf: &mut [u8], - ) -> nb::Result { - let socket = network.get_running()?; - socket.read(buf).map_err(to_nb) - } - - fn send( - &mut self, - network: &mut Self::TcpSocket, - buf: &[u8], - ) -> nb::Result { - let socket = network.get_running()?; - socket.write(buf).map_err(to_nb) - } - - fn is_connected(&mut self, network: &Self::TcpSocket) -> Result { - Ok(matches!(network.stream, Some(TlsState::Connected(_)))) - } - - fn connect( - &mut self, - network: &mut Self::TcpSocket, - remote: SocketAddr, - ) -> nb::Result<(), Self::Error> { - let soc = TcpStream::connect(format!("{}", remote))?; - soc.set_nonblocking(true)?; - network.stream.replace(TlsState::Connected(soc)); - - Ok(()) + async fn get_host_by_name( + &self, + host: &str, + addr_type: AddrType, + ) -> Result { + log::info!("Looking up {}", host); + for ip in tokio::net::lookup_host(format!("{}:0", host)).await? { + log::info!("Found IP {}", ip); + + match (&addr_type, ip) { + (AddrType::IPv4 | AddrType::Either, SocketAddr::V4(ip)) => { + return Ok(IpAddr::V4(Ipv4Addr::from(ip.ip().octets()))) + } + (AddrType::IPv6 | AddrType::Either, SocketAddr::V6(ip)) => { + return Ok(IpAddr::V6(Ipv6Addr::from(ip.ip().octets()))) + } + (_, _) => {} + } + } + Err(std::io::Error::new( + std::io::ErrorKind::AddrNotAvailable, + "", + )) } - fn close(&mut self, _network: Self::TcpSocket) -> Result<(), Self::Error> { - // No-op: Socket gets closed when it is freed - // - // Could wrap it in an Option, but really that'll only make things messier; users will - // probably drop the socket anyway after closing, and can't expect it to be usable with - // this API. - Ok(()) + async fn get_host_by_address( + &self, + _addr: IpAddr, + _result: &mut [u8], + ) -> Result { + unimplemented!() } } diff --git a/tests/metric.rs b/tests/metric.rs new file mode 100644 index 0000000..76cb1f1 --- /dev/null +++ b/tests/metric.rs @@ -0,0 +1,111 @@ +//! ## Integration test of `AWS IoT Device defender metrics` +//! +//! +//! This test simulates publishing of metrics and expects a accepted response from aws +//! +//! The test runs through the following update sequence: +//! 1. Setup metric state +//! 2. Assert json format +//! 2. Publish metric +//! 3. Assert result from AWS + +mod common; + +use std::str::FromStr; + +use common::credentials; +use common::network::TlsNetwork; +use embassy_futures::select; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embedded_mqtt::{self, transport::embedded_nal::NalTransport, Config, DomainBroker, State}; +use heapless::LinearMap; +use rustot::defender_metrics::{ + data_types::{CustomMetric, Metric}, + MetricHandler, +}; +use static_cell::StaticCell; + +fn assert_json_format(json: &str) { + log::debug!("{json}"); + let format = "{\"hed\":{\"rid\":0,\"v\":\"1.0\"},\"met\":null,\"cmet\":{\"MyMetricOfType_Number\":[{\"number\":1}],\"MyMetricOfType_NumberList\":[{\"number_list\":[1,2,3]}],\"MyMetricOfType_StringList\":[{\"string_list\":[\"value_1\",\"value_2\"]}],\"MyMetricOfType_IpList\":[{\"ip_list\":[\"172.0.0.0\",\"172.0.0.10\"]}]}}"; + + assert_eq!(json, format); +} + +#[tokio::test(flavor = "current_thread")] +async fn test_publish_metric() { + env_logger::init(); + + let (thing_name, identity) = credentials::identity(); + let hostname = credentials::HOSTNAME.unwrap(); + + static NETWORK: StaticCell = StaticCell::new(); + let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), identity)); + + // Create the MQTT stack + let broker = + DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); + + let config = Config::builder() + .client_id(thing_name.try_into().unwrap()) + .keepalive_interval(embassy_time::Duration::from_secs(50)) + .build(); + + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::new()); + let (mut stack, client) = embedded_mqtt::new(state, config); + + // Define metrics + let mut custom_metrics: LinearMap = LinearMap::new(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_Number").unwrap(), + [CustomMetric::Number(1)], + ) + .unwrap(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_NumberList").unwrap(), + [CustomMetric::NumberList(&[1, 2, 3])], + ) + .unwrap(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_StringList").unwrap(), + [CustomMetric::StringList(&["value_1", "value_2"])], + ) + .unwrap(); + + custom_metrics + .insert( + String::from_str("MyMetricOfType_IpList").unwrap(), + [CustomMetric::IpList(&["172.0.0.0", "172.0.0.10"])], + ) + .unwrap(); + + // Build metric + let metric = Metric::builder() + .custom_metrics(custom_metrics) + .header(Default::default()) + .build(); + + // Test the json format + let json = serde_json::to_string(&metric).unwrap(); + + assert_json_format(&json); + + let metric_handler = MetricHandler::new(&client); + + // Publish metric with mqtt + let mqtt_fut = async { assert!(metric_handler.publish_metric(metric, 2000).await.is_ok()) }; + + let mut transport = NalTransport::new(network, broker); + let _ = embassy_time::with_timeout( + embassy_time::Duration::from_secs(60), + select::select(stack.run(&mut transport), mqtt_fut), + ) + .await; +} diff --git a/tests/ota.rs b/tests/ota.rs deleted file mode 100644 index 9e459e6..0000000 --- a/tests/ota.rs +++ /dev/null @@ -1,235 +0,0 @@ -mod common; - -use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification, PublishNotification}; -use native_tls::TlsConnector; -use rustot::ota::state::States; -use serde::Deserialize; -use sha2::{Digest, Sha256}; -use std::{fs::File, io::Read, ops::Deref}; - -use common::{clock::SysClock, credentials, file_handler::FileHandler, network::Network}; -use rustot::{ - jobs::{ - self, - data_types::{DescribeJobExecutionResponse, NextJobExecutionChanged}, - StatusDetails, - }, - ota::{self, agent::OtaAgent, encoding::json::OtaJob}, -}; - -static mut Q: BBBuffer<{ 1024 * 10 }> = BBBuffer::new(); - -#[derive(Debug, Deserialize)] -pub enum Jobs<'a> { - #[serde(rename = "afr_ota")] - #[serde(borrow)] - Ota(OtaJob<'a>), -} - -impl<'a> Jobs<'a> { - pub fn ota_job(self) -> Option> { - match self { - Jobs::Ota(ota_job) => Some(ota_job), - } - } -} - -enum OtaUpdate<'a> { - JobUpdate(&'a str, OtaJob<'a>, Option>), - Data(&'a mut [u8]), -} - -fn handle_ota<'a>(publish: &'a mut PublishNotification) -> Result, ()> { - match jobs::Topic::from_str(publish.topic_name.as_str()) { - Some(jobs::Topic::NotifyNext) => { - let (execution_changed, _) = - serde_json_core::from_slice::>(&publish.payload) - .map_err(drop)?; - let job = execution_changed.execution.ok_or(())?; - let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; - return Ok(OtaUpdate::JobUpdate( - job.job_id, - ota_job, - job.status_details, - )); - } - Some(jobs::Topic::DescribeAccepted(_)) => { - let (execution_changed, _) = - serde_json_core::from_slice::>(&publish.payload) - .map_err(drop)?; - let job = execution_changed.execution.ok_or(())?; - let ota_job = job.job_document.ok_or(())?.ota_job().ok_or(())?; - return Ok(OtaUpdate::JobUpdate( - job.job_id, - ota_job, - job.status_details, - )); - } - _ => {} - } - - match ota::Topic::from_str(publish.topic_name.as_str()) { - Some(ota::Topic::Data(_, _)) => { - return Ok(OtaUpdate::Data(&mut publish.payload)); - } - _ => {} - } - Err(()) -} - -pub struct FileInfo { - pub file_path: String, - pub filesize: usize, - pub signature: ota::encoding::json::Signature, -} - -#[test] -fn test_mqtt_ota() { - // Make sure this times out in case something went wrong setting up the OTA - // job in AWS IoT before starting. - timebomb::timeout_ms(test_mqtt_ota_inner, 100_000) -} - -fn test_mqtt_ota_inner() { - env_logger::init(); - - let (p, c) = unsafe { Q.try_split_framed().unwrap() }; - - log::info!("Starting OTA test..."); - - let hostname = credentials::HOSTNAME.unwrap(); - let (thing_name, identity) = credentials::identity(); - - let connector = TlsConnector::builder() - .identity(identity) - .add_root_certificate(credentials::root_ca()) - .build() - .unwrap(); - - let mut network = Network::new_tls(connector, String::from(hostname)); - - let mut mqtt_eventloop = EventLoop::new( - c, - SysClock::new(), - MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), - ); - - let mqtt_client = mqttrust_core::Client::new(p, thing_name); - - let file_handler = FileHandler::new(); - - let mut ota_agent = - OtaAgent::builder(&mqtt_client, &mqtt_client, SysClock::new(), file_handler) - .request_wait_ms(3000) - .block_size(256) - .build(); - - let mut file_info = None; - - loop { - match mqtt_eventloop.connect(&mut network) { - Ok(true) => { - log::info!("Successfully connected to broker"); - ota_agent.init(); - } - Ok(false) => {} - Err(nb::Error::WouldBlock) => continue, - Err(e) => panic!("{:?}", e), - } - - match mqtt_eventloop.yield_event(&mut network) { - Ok(Notification::Publish(mut publish)) => { - // Check if the received file is a jobs topic, that we - // want to react to. - match handle_ota(&mut publish) { - Ok(OtaUpdate::JobUpdate(job_id, job_doc, status_details)) => { - log::debug!("Received job! Starting OTA! {:?}", job_doc.streamname); - - let file = &job_doc.files[0]; - file_info.replace(FileInfo { - file_path: file.filepath.to_string(), - filesize: file.filesize, - signature: file.signature(), - }); - ota_agent - .job_update(job_id, &job_doc, status_details.as_ref()) - .expect("Failed to start OTA job"); - } - Ok(OtaUpdate::Data(payload)) => { - if ota_agent.handle_message(payload).is_err() { - match ota_agent.state() { - States::CreatingFile => log::info!("State: CreatingFile"), - States::Ready => log::info!("State: Ready"), - States::RequestingFileBlock => { - log::info!("State: RequestingFileBlock") - } - States::RequestingJob => log::info!("State: RequestingJob"), - States::Restarting => log::info!("State: Restarting"), - States::Suspended => log::info!("State: Suspended"), - States::WaitingForFileBlock => { - log::info!("State: WaitingForFileBlock") - } - States::WaitingForJob => log::info!("State: WaitingForJob"), - } - } - } - Err(_) => {} - } - } - Ok(n) => { - log::trace!("{:?}", n); - } - _ => {} - } - - ota_agent.timer_callback().expect("Failed timer callback!"); - - match ota_agent.process_event() { - // Use the restarting state to indicate finished - Ok(States::Restarting) => break, - _ => {} - } - } - - let mut expected_file = File::open("tests/assets/ota_file").unwrap(); - let mut expected_data = Vec::new(); - expected_file.read_to_end(&mut expected_data).unwrap(); - let mut expected_hasher = Sha256::new(); - expected_hasher.update(&expected_data); - let expected_hash = expected_hasher.finalize(); - - let file_info = file_info.unwrap(); - - log::info!( - "Comparing {:?} with {:?}", - "tests/assets/ota_file", - file_info.file_path - ); - let mut file = File::open(file_info.file_path.clone()).unwrap(); - let mut data = Vec::new(); - file.read_to_end(&mut data).unwrap(); - drop(file); - std::fs::remove_file(file_info.file_path).unwrap(); - - assert_eq!(data.len(), file_info.filesize); - - let mut hasher = Sha256::new(); - hasher.update(&data); - assert_eq!(hasher.finalize().deref(), expected_hash.deref()); - - // Check file signature - match file_info.signature { - ota::encoding::json::Signature::Sha1Rsa(_) => { - panic!("Unexpected signature format: Sha1Rsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha256Rsa(_) => { - panic!("Unexpected signature format: Sha256Rsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha1Ecdsa(_) => { - panic!("Unexpected signature format: Sha1Ecdsa. Expected Sha256Ecdsa") - } - ota::encoding::json::Signature::Sha256Ecdsa(sig) => { - assert_eq!(&sig, "This is my custom signature\\n") - } - } -} diff --git a/tests/ota_mqtt.rs b/tests/ota_mqtt.rs new file mode 100644 index 0000000..4fd1be6 --- /dev/null +++ b/tests/ota_mqtt.rs @@ -0,0 +1,196 @@ +#![allow(async_fn_in_trait)] +#![feature(type_alias_impl_trait)] + +mod common; + +use common::credentials; +use common::file_handler::{FileHandler, State as FileHandlerState}; +use common::network::TlsNetwork; +use embassy_futures::select; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embedded_mqtt::transport::embedded_nal::NalTransport; +use embedded_mqtt::{ + Config, DomainBroker, Message, SliceBufferProvider, State, Subscribe, SubscribeTopic, +}; +use serde::Deserialize; +use static_cell::StaticCell; + +use rustot::{ + jobs::{ + self, + data_types::{DescribeJobExecutionResponse, NextJobExecutionChanged}, + }, + ota::{ + self, + encoding::{json::OtaJob, FileContext}, + JobEventData, Updater, + }, +}; + +#[derive(Debug, Deserialize)] +pub enum Jobs<'a> { + #[serde(rename = "afr_ota")] + #[serde(borrow)] + Ota(OtaJob<'a>), +} + +impl<'a> Jobs<'a> { + pub fn ota_job(self) -> Option> { + match self { + Jobs::Ota(ota_job) => Some(ota_job), + } + } +} + +fn handle_ota<'a>( + message: Message<'a, NoopRawMutex, SliceBufferProvider<'a>>, + config: &ota::config::Config, +) -> Option { + let job = match jobs::Topic::from_str(message.topic_name()) { + Some(jobs::Topic::NotifyNext) => { + let (execution_changed, _) = + serde_json_core::from_slice::>(message.payload()) + .ok()?; + execution_changed.execution? + } + Some(jobs::Topic::DescribeAccepted(_)) => { + let (execution_changed, _) = serde_json_core::from_slice::< + DescribeJobExecutionResponse, + >(message.payload()) + .ok()?; + + if execution_changed.execution.is_none() { + panic!("No OTA jobs queued?"); + } + + execution_changed.execution? + } + _ => { + return None; + } + }; + + let ota_job = job.job_document?.ota_job()?; + + FileContext::new_from( + JobEventData { + job_name: job.job_id, + ota_document: ota_job, + status_details: job.status_details, + }, + 0, + config, + ) + .ok() +} + +#[tokio::test(flavor = "current_thread")] +async fn test_mqtt_ota() { + env_logger::init(); + + log::info!("Starting OTA test..."); + + let (thing_name, identity) = credentials::identity(); + + let hostname = credentials::HOSTNAME.unwrap(); + + static NETWORK: StaticCell = StaticCell::new(); + let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), identity)); + + // Create the MQTT stack + let broker = + DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); + let config = Config::builder() + .client_id(thing_name.try_into().unwrap()) + .keepalive_interval(embassy_time::Duration::from_secs(50)) + .build(); + + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::new()); + let (mut stack, client) = embedded_mqtt::new(state, config); + + let mut file_handler = FileHandler::new("tests/assets/ota_file".to_owned()); + + let ota_fut = async { + let mut jobs_subscription = client + .subscribe::<2>( + Subscribe::builder() + .topics(&[ + SubscribeTopic::builder() + .topic_path( + jobs::JobTopic::NotifyNext + .format::<64>(thing_name)? + .as_str(), + ) + .build(), + SubscribeTopic::builder() + .topic_path( + jobs::JobTopic::DescribeAccepted("$next") + .format::<64>(thing_name)? + .as_str(), + ) + .build(), + ]) + .build(), + ) + .await?; + + Updater::check_for_job(&client).await?; + + let config = ota::config::Config::default(); + + let message = jobs_subscription.next_message().await.unwrap(); + + if let Some(mut file_ctx) = handle_ota(message, &config) { + // Nested subscriptions are a problem for embedded-mqtt, so unsubscribe here + jobs_subscription.unsubscribe().await.unwrap(); + + // We have an OTA job, leeeets go! + Updater::perform_ota( + &client, + &client, + file_ctx.clone(), + &mut file_handler, + &config, + ) + .await?; + + assert_eq!(file_handler.plateform_state, FileHandlerState::Swap); + + log::info!("Running OTA handler second time to verify state match..."); + + // Run it twice in this particular integration test, in order to + // simulate image commit after bootloader swap + file_ctx + .status_details + .insert( + heapless::String::try_from("self_test").unwrap(), + heapless::String::try_from("active").unwrap(), + ) + .unwrap(); + + Updater::perform_ota(&client, &client, file_ctx, &mut file_handler, &config).await?; + + return Ok(()); + } + + Ok::<_, ota::error::OtaError>(()) + }; + + let mut transport = NalTransport::new(network, broker); + + match embassy_time::with_timeout( + embassy_time::Duration::from_secs(25), + select::select(stack.run(&mut transport), ota_fut), + ) + .await + .unwrap() + { + select::Either::First(_) => { + unreachable!() + } + select::Either::Second(result) => result.unwrap(), + }; + + assert_eq!(file_handler.plateform_state, FileHandlerState::Boot); +} diff --git a/tests/provisioning.rs b/tests/provisioning.rs index 5fa317d..bb54dc2 100644 --- a/tests/provisioning.rs +++ b/tests/provisioning.rs @@ -1,24 +1,23 @@ -mod common; - -use mqttrust::Mqtt; -use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification, PublishNotification}; +#![allow(async_fn_in_trait)] +#![feature(type_alias_impl_trait)] -use common::clock::SysClock; -use common::network::{Network, TcpSocket}; -use native_tls::{Identity, TlsConnector, TlsStream}; -use p256::ecdsa::signature::Signer; -use rustot::provisioning::{topics::Topic, Credentials, FleetProvisioner, Response}; -use std::net::TcpStream; -use std::ops::DerefMut; +mod common; use common::credentials; - -static mut Q: BBBuffer<{ 1024 * 10 }> = BBBuffer::new(); +use common::network::TlsNetwork; +use ecdsa::Signature; +use embassy_futures::select; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embedded_mqtt::{transport::embedded_nal::NalTransport, Config, DomainBroker, State}; +use p256::{ecdsa::signature::Signer, NistP256}; +use rustot::provisioning::{CredentialHandler, Credentials, Error, FleetProvisioner}; +use serde::{Deserialize, Serialize}; +use static_cell::StaticCell; pub struct OwnedCredentials { - certificate_id: String, - certificate_pem: String, - private_key: Option, + pub certificate_id: String, + pub certificate_pem: String, + pub private_key: Option, } impl<'a> From> for OwnedCredentials { @@ -31,120 +30,105 @@ impl<'a> From> for OwnedCredentials { } } -fn provision_credentials<'a, const L: usize>( - hostname: &'a str, - identity: Identity, - mqtt_eventloop: &mut EventLoop<'a, 'a, TcpSocket>, SysClock, 1000, L>, - mqtt_client: &mqttrust_core::Client, -) -> Result { - let template_name = - std::env::var("TEMPLATE_NAME").unwrap_or_else(|_| "duoProvisioningTemplate".to_string()); - - let connector = TlsConnector::builder() - .identity(identity) - .add_root_certificate(credentials::root_ca()) - .build() - .unwrap(); - - let mut network = Network::new_tls(connector, String::from(hostname)); - - nb::block!(mqtt_eventloop.connect(&mut network)) - .expect("To connect to MQTT with claim credentials"); - - log::info!("Successfully connected to broker with claim credentials"); - - #[cfg(not(feature = "provision_cbor"))] - let mut provisioner = FleetProvisioner::new(mqtt_client, &template_name); - #[cfg(feature = "provision_cbor")] - let mut provisioner = FleetProvisioner::new_cbor(mqtt_client, &template_name); +pub struct CredentialDAO { + pub creds: Option, +} - provisioner - .initialize() - .expect("Failed to initialize FleetProvisioner"); +impl CredentialHandler for CredentialDAO { + async fn store_credentials(&mut self, credentials: Credentials<'_>) -> Result<(), Error> { + log::info!("Provisioned credentials: {:#?}", credentials); - let mut provisioned_credentials: Option = None; + self.creds.replace(credentials.into()); - let signing_key = credentials::signing_key(); - let signature = hex::encode(signing_key.sign(mqtt_client.client_id().as_bytes())); - - let result = loop { - match mqtt_eventloop.yield_event(&mut network) { - Ok(Notification::Publish(mut publish)) if Topic::check(publish.topic_name.as_str()) => { - let PublishNotification { - topic_name, - payload, - .. - } = publish.deref_mut(); - - match provisioner.handle_message::<4>(topic_name.as_str(), payload) { - Ok(Some(Response::Credentials(credentials))) => { - log::info!("Got credentials! {:?}", credentials); - provisioned_credentials = Some(credentials.into()); - - let mut parameters = heapless::LinearMap::new(); - parameters.insert("uuid", mqtt_client.client_id()).unwrap(); - parameters.insert("signature", &signature).unwrap(); - - provisioner - .register_thing::<2>(Some(parameters)) - .expect("To successfully publish to RegisterThing"); - } - Ok(Some(Response::DeviceConfiguration(config))) => { - // Store Device configuration parameters, if any. - - log::info!("Got device config! {:?}", config); - - break Ok(()); - } - Ok(None) => {} - Err(e) => { - log::error!("Got provision error! {:?}", e); - provisioned_credentials = None; - - break Err(()); - } - } - } - Ok(Notification::Suback(_)) => { - log::info!("Starting provisioning"); - provisioner.begin().expect("To begin provisioning"); - } - Ok(n) => { - log::trace!("{:?}", n); - } - _ => {} - } - }; + Ok(()) + } +} - // Disconnect from AWS IoT Core - mqtt_eventloop.disconnect(&mut network); +#[derive(Debug, Serialize)] +struct Parameters<'a> { + uuid: &'a str, + signature: &'a str, +} - result.and_then(|_| provisioned_credentials.ok_or(())) +#[derive(Debug, Deserialize, PartialEq)] +struct DeviceConfig { + #[serde(rename = "SoftwareId")] + software_id: heapless::String<64>, } -#[test] -fn test_provisioning() { +#[tokio::test(flavor = "current_thread")] +async fn test_provisioning() { env_logger::init(); - let (p, c) = unsafe { Q.try_split_framed().unwrap() }; - log::info!("Starting provisioning test..."); let (thing_name, claim_identity) = credentials::claim_identity(); // Connect to AWS IoT Core with provisioning claim credentials let hostname = credentials::HOSTNAME.unwrap(); + let template_name = + std::env::var("TEMPLATE_NAME").unwrap_or_else(|_| "duoProvisioningTemplate".to_string()); - let mut mqtt_eventloop = EventLoop::new( - c, - SysClock::new(), - MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), - ); + static NETWORK: StaticCell = StaticCell::new(); + let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), claim_identity)); + + // Create the MQTT stack + let broker = + DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); + let config = Config::builder() + .client_id(thing_name.try_into().unwrap()) + .keepalive_interval(embassy_time::Duration::from_secs(50)) + .build(); + + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::new()); + let (mut stack, client) = embedded_mqtt::new(state, config); + + let signing_key = credentials::signing_key(); + let signature: Signature = signing_key.sign(thing_name.as_bytes()); + let hex_signature: String = hex::encode(signature.to_bytes()); - let mqtt_client = mqttrust_core::Client::new(p, thing_name); + let parameters = Parameters { + uuid: thing_name, + signature: &hex_signature, + }; - let credentials = - provision_credentials(hostname, claim_identity, &mut mqtt_eventloop, &mqtt_client).unwrap(); + let mut credential_handler = CredentialDAO { creds: None }; - assert!(credentials.certificate_id.len() > 0); + #[cfg(not(feature = "provision_cbor"))] + let provision_fut = FleetProvisioner::provision::( + &client, + &template_name, + Some(parameters), + &mut credential_handler, + ); + #[cfg(feature = "provision_cbor")] + let provision_fut = FleetProvisioner::provision_cbor::( + &client, + &template_name, + Some(parameters), + &mut credential_handler, + ); + + let mut transport = NalTransport::new(network, broker); + + let device_config = match embassy_time::with_timeout( + embassy_time::Duration::from_secs(15), + select::select(stack.run(&mut transport), provision_fut), + ) + .await + .unwrap() + { + select::Either::First(_) => { + unreachable!() + } + select::Either::Second(result) => result.unwrap(), + }; + assert_eq!( + device_config, + Some(DeviceConfig { + software_id: heapless::String::try_from("82b3509e0e924e06ab1bdb1cf1625dcb").unwrap() + }) + ); + assert!(!credential_handler.creds.unwrap().certificate_id.is_empty()); } diff --git a/tests/shadows.rs b/tests/shadows.rs index bb83e9c..853372b 100644 --- a/tests/shadows.rs +++ b/tests/shadows.rs @@ -1,4 +1,3 @@ -//! //! ## Integration test of `AWS IoT Shadows` //! //! @@ -22,481 +21,209 @@ //! 12. Assert on shadow state //! -mod common; +#![allow(async_fn_in_trait)] +#![feature(type_alias_impl_trait)] -use core::fmt::Write; +mod common; -use common::{clock::SysClock, credentials, network::Network}; -use embedded_nal::Ipv4Addr; -use mqttrust::Mqtt; -use mqttrust_core::{bbqueue::BBBuffer, EventLoop, MqttOptions, Notification}; -use native_tls::TlsConnector; -use rustot::shadows::{ - derive::ShadowState, topics::Topic, Patch, Shadow, ShadowPatch, ShadowState, +use common::credentials; +use common::network::TlsNetwork; +use embassy_futures::select; +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embedded_mqtt::{ + self, transport::embedded_nal::NalTransport, Config, DomainBroker, MqttClient, Publish, QoS, + State, Subscribe, SubscribeTopic, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; - -use smlang::statemachine; - -const Q_SIZE: usize = 1024 * 6; -static mut Q: BBBuffer = BBBuffer::new(); - -#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)] -pub struct ConfigId(pub u8); - -impl Serialize for ConfigId { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let mut str = heapless::String::<3>::new(); - write!(str, "{}", self.0).map_err(serde::ser::Error::custom)?; - serializer.serialize_str(&str) - } -} - -impl<'de> Deserialize<'de> for ConfigId { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - heapless::String::<3>::deserialize(deserializer)? - .parse() - .map(ConfigId) - .map_err(serde::de::Error::custom) - } -} - -impl From for ConfigId { - fn from(v: u8) -> Self { - Self(v) - } -} - -#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] -pub struct NetworkMap(heapless::LinearMap>, N>); - -impl NetworkMap -where - K: Eq, -{ - pub fn insert(&mut self, k: impl Into, v: V) -> Result<(), ()> { - self.0.insert(k.into(), Some(Patch::Set(v))).map_err(drop)?; - Ok(()) - } - - pub fn remove(&mut self, k: impl Into) -> Result<(), ()> { - self.0.insert(k.into(), None).map_err(drop)?; - Ok(()) - } -} - -impl ShadowPatch for NetworkMap -where - K: Clone + Default + Eq + Serialize + DeserializeOwned, - V: Clone + Default + Serialize + DeserializeOwned, -{ - type PatchState = NetworkMap; - - fn apply_patch(&mut self, opt: Self::PatchState) { - for (id, network) in opt.0.into_iter() { - match network { - Some(Patch::Set(v)) => { - self.insert(id.clone(), v.clone()).ok(); - } - None | Some(Patch::Unset) => { - self.remove(id.clone()).ok(); - } - } - } - } -} - -const MAX_NETWORKS: usize = 5; -type KnownNetworks = NetworkMap; - -#[derive(Debug, Clone, Default, Serialize, Deserialize, ShadowState)] -#[shadow("wifi")] -pub struct WifiConfig { - pub enabled: bool, - - pub known_networks: KnownNetworks, -} - -#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] -pub struct ConnectionOptions { - pub ssid: heapless::String<64>, - pub password: Option>, - - pub ip: Option, - pub subnet: Option, - pub gateway: Option, -} - -#[derive(Debug, Clone)] -pub enum UpdateAction { - Insert(u8, ConnectionOptions), - Remove(u8), - Enabled(bool), -} - -statemachine! { - transitions: { - *Begin + Delete = DeleteShadow, - DeleteShadow + Get = GetShadow, - GetShadow + Load / load_initial = LoadShadow(Option), - LoadShadow(Option) + CheckInitial / check_initial = Check(Option), - UpdateFromDevice(UpdateAction) + CheckState / check = Check(Option), - UpdateFromCloud(UpdateAction) + Ack = AckUpdate, - AckUpdate + CheckState / check_cloud = Check(Option), - Check(Option) + UpdateStateFromDevice / get_next_device = UpdateFromDevice(UpdateAction), - Check(Option) + UpdateStateFromCloud / get_next_cloud = UpdateFromCloud(UpdateAction), - } -} - -impl core::fmt::Debug for States { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Begin => write!(f, "Self::Begin"), - Self::DeleteShadow => write!(f, "Self::DeleteShadow"), - Self::GetShadow => write!(f, "Self::GetShadow"), - Self::AckUpdate => write!(f, "Self::AckUpdate"), - Self::LoadShadow(t) => write!(f, "Self::LoadShadow({:?})", t), - Self::UpdateFromDevice(t) => write!(f, "Self::UpdateFromDevice({:?})", t), - Self::UpdateFromCloud(t) => write!(f, "Self::UpdateFromCloud({:?})", t), - Self::Check(t) => write!(f, "Self::Check({:?})", t), - } - } -} - -fn asserts(id: usize) -> ConnectionOptions { - match id { - 0 => ConnectionOptions { - ssid: heapless::String::from("MySSID"), - password: None, - ip: None, - subnet: None, - gateway: None, - }, - 1 => ConnectionOptions { - ssid: heapless::String::from("MyProtectedSSID"), - password: Some(heapless::String::from("SecretPass")), - ip: None, - subnet: None, - gateway: None, - }, - 2 => ConnectionOptions { - ssid: heapless::String::from("CloudSSID"), - password: Some(heapless::String::from("SecretCloudPass")), - ip: Some(Ipv4Addr::new(1, 2, 3, 4)), - subnet: None, - gateway: None, - }, - _ => panic!("Unknown assert ID"), - } +use rustot::shadows::{Shadow, ShadowState}; +use rustot_derive::shadow; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use static_cell::StaticCell; + +#[shadow(name = "state")] +#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)] +pub struct TestShadow { + pub foo: u32, } -pub struct TestContext<'a> { - shadow: Shadow<'a, WifiConfig, mqttrust_core::Client<'static, 'static, Q_SIZE>>, - update_cnt: u8, -} - -impl<'a> StateMachineContext for TestContext<'a> { - fn check_initial( - &mut self, - _last_update_action: &Option, - ) -> Option { - self.check(&UpdateAction::Remove(0)) - } - - fn check_cloud(&mut self) -> Option { - self.check(&UpdateAction::Remove(0)) - } - - fn check(&mut self, _last_update_action: &UpdateAction) -> Option { - let mut known_networks = KnownNetworks::default(); - - match self.update_cnt { - 0 => { - // After load_initial - known_networks.insert(0, asserts(0)).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - } - 1 => { - // After get_next_device - known_networks.remove(0).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - } - 2 => { - // After get_next_cloud - known_networks.remove(0).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - known_networks.insert(2, asserts(2)).unwrap(); - } - 3 => { - // After get_next_device - known_networks.insert(0, asserts(0)).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - known_networks.insert(2, asserts(2)).unwrap(); - } - 4 => { - // After get_next_cloud - known_networks.insert(0, asserts(0)).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - known_networks.remove(2).unwrap(); - } - 5 => return None, - _ => {} - } - - Some(known_networks) - } - - fn get_next_device(&mut self, _: &Option) -> UpdateAction { - self.update_cnt += 1; - match self.update_cnt { - 1 => UpdateAction::Remove(0), - 3 => UpdateAction::Insert(0, asserts(0)), - 5 => UpdateAction::Remove(0), - _ => panic!("Unexpected update counter in `get_next_device`"), - } - } - - fn get_next_cloud(&mut self, _: &Option) -> UpdateAction { - self.update_cnt += 1; - - match self.update_cnt { - 2 => UpdateAction::Insert(2, asserts(2)), - 4 => UpdateAction::Remove(2), - _ => panic!("Unexpected update counter in `get_next_cloud`"), - } - } - - fn load_initial(&mut self) -> Option { - let mut known_networks = KnownNetworks::default(); - known_networks.insert(0, asserts(0)).unwrap(); - known_networks.insert(1, asserts(1)).unwrap(); - Some(known_networks) - } +/// Helper function to mimic cloud side updates using MQTT client directly +async fn cloud_update(client: &MqttClient<'static, NoopRawMutex>, payload: &[u8]) { + client + .publish( + Publish::builder() + .topic_name( + rustot::shadows::topics::Topic::Update + .format::<128>(TestShadow::PREFIX, client.client_id(), TestShadow::NAME) + .unwrap() + .as_str(), + ) + .payload(payload) + .qos(QoS::AtLeastOnce) + .build(), + ) + .await + .unwrap(); } -impl<'a> StateMachine> { - pub fn spin( - &mut self, - notification: Notification, - mqtt_client: &mqttrust_core::Client<'static, 'static, Q_SIZE>, - ) -> bool { - log::info!("State: {:?}", self.state()); - match (self.state(), notification) { - (&States::Begin, Notification::Suback(_)) => { - self.process_event(Events::Delete).unwrap(); - } - (&States::DeleteShadow, Notification::Suback(_)) => { - mqtt_client - .publish( - &Topic::Update - .format::<128>( - mqtt_client.client_id(), - ::NAME, - ) - .unwrap(), - b"{\"state\":{\"desired\":null,\"reported\":null}}", - mqttrust::QoS::AtLeastOnce, - ) - .unwrap(); - - self.process_event(Events::Get).unwrap(); - } - (&States::GetShadow, Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/update/accepted" - ) => - { - self.context_mut().shadow.get_shadow().unwrap(); - self.process_event(Events::Load).unwrap(); - } - (&States::LoadShadow(ref initial_map), Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/get/accepted" - ) => - { - let initial_map = initial_map.clone(); - - self.context_mut() - .shadow - .update(|_current, desired| { - desired.known_networks = Some(initial_map.unwrap()); - }) - .unwrap(); - self.process_event(Events::CheckInitial).unwrap(); - } - (&States::UpdateFromDevice(ref update_action), Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/get/accepted" - ) => - { - let action = update_action.clone(); - self.context_mut() - .shadow - .update(|current, desired| match action { - UpdateAction::Insert(id, options) => { - let mut desired_map = current.known_networks.clone(); - desired_map.insert(id, options).unwrap(); - desired.known_networks = Some(desired_map); - } - UpdateAction::Remove(id) => { - let mut desired_map = current.known_networks.clone(); - desired_map.remove(id).unwrap(); - desired.known_networks = Some(desired_map); - } - UpdateAction::Enabled(en) => { - desired.enabled = Some(en); - } - }) - .unwrap(); - self.process_event(Events::CheckState).unwrap(); - } - (&States::UpdateFromCloud(ref update_action), Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/get/accepted" - ) => - { - let desired_known_networks = match update_action { - UpdateAction::Insert(id, options) => format!( - "\"known_networks\": {{\"{}\":{{\"set\":{}}}}}", - id, - serde_json_core::to_string::<_, 256>(options).unwrap() - ), - UpdateAction::Remove(id) => { - format!("\"known_networks\": {{\"{}\":\"unset\"}}", id) - } - &UpdateAction::Enabled(en) => format!("\"enabled\": {}", en), - }; - - let payload = format!( - "{{\"state\":{{\"desired\":{{{}}}, \"reported\":{}}}}}", - desired_known_networks, - serde_json_core::to_string::<_, 512>(self.context().shadow.get()).unwrap() - ); - - log::debug!("Update from cloud: {:?}", payload); - - mqtt_client - .publish( - &Topic::Update - .format::<128>( - mqtt_client.client_id(), - ::NAME, - ) - .unwrap(), - payload.as_bytes(), - mqttrust::QoS::AtLeastOnce, +/// Helper function to assert on the current shadow state +async fn assert_shadow(client: &MqttClient<'static, NoopRawMutex>, expected: serde_json::Value) { + let mut get_shadow_sub = client + .subscribe::<1>( + Subscribe::builder() + .topics(&[SubscribeTopic::builder() + .topic_path( + rustot::shadows::topics::Topic::GetAccepted + .format::<128>(TestShadow::PREFIX, client.client_id(), TestShadow::NAME) + .unwrap() + .as_str(), ) - .unwrap(); - self.process_event(Events::Ack).unwrap(); - } - (&States::AckUpdate, Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/update/delta" - ) => - { - self.context_mut() - .shadow - .handle_message(&publish.topic_name, &publish.payload) - .unwrap(); + .build()]) + .build(), + ) + .await + .unwrap(); - self.process_event(Events::CheckState).unwrap(); - } - (&States::Check(ref expected_map), Notification::Publish(publish)) - if matches!( - publish.topic_name.as_str(), - "$aws/things/rustot-test/shadow/name/wifi/update/accepted" - | "$aws/things/rustot-test/shadow/name/wifi/update/delta" - ) => - { - let expected = expected_map.clone(); - self.context_mut() - .shadow - .handle_message(&publish.topic_name, &publish.payload) - .unwrap(); + client + .publish( + Publish::builder() + .topic_name( + rustot::shadows::topics::Topic::Get + .format::<128>(TestShadow::PREFIX, client.client_id(), TestShadow::NAME) + .unwrap() + .as_str(), + ) + .payload(b"") + .build(), + ) + .await + .unwrap(); - match expected { - Some(expected_map) => { - assert_eq!(self.context().shadow.get().known_networks, expected_map); - self.context_mut().shadow.get_shadow().unwrap(); - let event = if self.context().update_cnt % 2 == 0 { - Events::UpdateStateFromDevice - } else { - Events::UpdateStateFromCloud - }; - self.process_event(event).unwrap(); - } - None => return true, - } - } - (_, Notification::Publish(publish)) => { - log::warn!("TOPIC: {}", publish.topic_name); - self.context_mut() - .shadow - .handle_message(&publish.topic_name, &publish.payload) - .unwrap(); - } - _ => {} - } + let current_shadow = get_shadow_sub.next_message().await.unwrap(); - false - } + assert_eq!( + serde_json::from_slice::(current_shadow.payload()) + .unwrap() + .get("state") + .unwrap(), + &expected, + ); } -#[test] -fn test_shadows() { +#[tokio::test(flavor = "current_thread")] +async fn test_shadow_update_from_device() { env_logger::init(); - let (p, c) = unsafe { Q.try_split_framed().unwrap() }; - - log::info!("Starting shadows test..."); + const DESIRED_1: &str = r#"{ + "state": { + "desired": { + "foo": 42 + } + } + }"#; - let hostname = credentials::HOSTNAME.unwrap(); let (thing_name, identity) = credentials::identity(); + let hostname = credentials::HOSTNAME.unwrap(); - let connector = TlsConnector::builder() - .identity(identity) - .add_root_certificate(credentials::root_ca()) - .build() - .unwrap(); - - let mut network = Network::new_tls(connector, std::string::String::from(hostname)); - - let mut mqtt_eventloop = EventLoop::new( - c, - SysClock::new(), - MqttOptions::new(thing_name, hostname.into(), 8883).set_clean_session(true), - ); - - let mqtt_client = mqttrust_core::Client::new(p, thing_name); - - let mut test_state = StateMachine::new(TestContext { - shadow: Shadow::new(WifiConfig::default(), &mqtt_client, true).unwrap(), - update_cnt: 0, - }); - - loop { - if nb::block!(mqtt_eventloop.connect(&mut network)).expect("to connect to mqtt") { - log::info!("Successfully connected to broker"); - } - - match mqtt_eventloop.yield_event(&mut network) { - Ok(notification) => { - if test_state.spin(notification, &mqtt_client) { - break; + static NETWORK: StaticCell = StaticCell::new(); + let network = NETWORK.init(TlsNetwork::new(hostname.to_owned(), identity)); + + // Create the MQTT stack + let broker = + DomainBroker::<_, 128>::new(format!("{}:8883", hostname).as_str(), network).unwrap(); + + let config = Config::builder() + .client_id(thing_name.try_into().unwrap()) + .keepalive_interval(embassy_time::Duration::from_secs(50)) + .build(); + + static STATE: StaticCell> = StaticCell::new(); + let state = STATE.init(State::new()); + let (mut stack, client) = embedded_mqtt::new(state, config); + + // Create the shadow + let mut shadow = Shadow::new(TestShadow::default(), &client); + + let mqtt_fut = async { + // 1. Setup clean starting point (`desired = null, reported = null`) + cloud_update( + &client, + r#"{"state": {"desired": null, "reported": null} }"#.as_bytes(), + ) + .await; + + // 2. Do a `GetShadow` request to sync empty state + let _ = shadow.get_shadow().await.unwrap(); + + // 3. Update to initial shadow state from the device + shadow.report().await.unwrap(); + + // 4. Assert on the initial state + assert_shadow( + &client, + json!({ + "reported": { + "foo": 0 } - } - Err(_) => {} - } - } + }), + ) + .await; + + // 5. Update state from device + // 6. Assert on shadow state + // 7. Update state from cloud + cloud_update(&client, DESIRED_1.as_bytes()).await; + + // 8. Assert on shadow state + // 9. Update state from device + + // 10. Assert on shadow state + assert_shadow( + &client, + json!({ + "reported": { + "foo": 0 + }, + "desired": { + "foo": 42 + }, + "delta": { + "foo": 42 + } + }), + ) + .await; + + // 11. Update desired state from cloud + cloud_update( + &client, + r#"{"state": {"desired": {"bar": true}}}"#.as_bytes(), + ) + .await; + + // 12. Assert on shadow state + assert_shadow( + &client, + json!({ + "reported": { + "foo": 0 + }, + "desired": { + "foo": 42, + "bar": true + }, + "delta": { + "foo": 42, + "bar": true + } + }), + ) + .await; + }; + + let mut transport = NalTransport::new(network, broker); + let _ = embassy_time::with_timeout( + embassy_time::Duration::from_secs(60), + select::select(stack.run(&mut transport), mqtt_fut), + ) + .await; }