diff --git a/ssz/Cargo.toml b/ssz/Cargo.toml
index 0d298dc..d6e74cc 100644
--- a/ssz/Cargo.toml
+++ b/ssz/Cargo.toml
@@ -14,7 +14,7 @@ categories = ["cryptography::cryptocurrencies"]
 name = "ssz"
 
 [dev-dependencies]
-ethereum_ssz_derive = { version = "0.5.4", path = "../ssz_derive" }
+ethereum_ssz_derive = { path = "../ssz_derive" }
 
 [dependencies]
 ethereum-types = "0.14.1"
diff --git a/ssz/src/decode/impls.rs b/ssz/src/decode/impls.rs
index 3317bd2..895fe8a 100644
--- a/ssz/src/decode/impls.rs
+++ b/ssz/src/decode/impls.rs
@@ -249,14 +249,22 @@ impl Decode for NonZeroUsize {
 
+/// Decodes an `Option<T>` from the bare encoding of `T`; no union selector byte is read.
+///
+/// Empty input decodes to `None`; any other input is decoded as `T` and wrapped in `Some`.
+/// A `Some` value whose own encoding is empty therefore decodes back as `None`.
 impl<T: Decode> Decode for Option<T> {
     fn is_ssz_fixed_len() -> bool {
-        false
+        T::is_ssz_fixed_len()
     }
+
+    fn ssz_fixed_len() -> usize {
+        T::ssz_fixed_len()
+    }
+
     fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
-        let (selector, body) = split_union_bytes(bytes)?;
-        match selector.into() {
-            0u8 => Ok(None),
-            1u8 => <T as Decode>::from_ssz_bytes(body).map(Option::Some),
-            other => Err(DecodeError::UnionSelectorInvalid(other)),
+        if bytes.is_empty() {
+            Ok(None)
+        } else {
+            T::from_ssz_bytes(bytes).map(Some)
         }
     }
 }
diff --git a/ssz/src/encode/impls.rs b/ssz/src/encode/impls.rs
index c4bf0c0..b5605cd 100644
--- a/ssz/src/encode/impls.rs
+++ b/ssz/src/encode/impls.rs
@@ -206,28 +206,25 @@ impl_encode_for_tuples! {
 
+/// Encodes an `Option<T>` as the bare encoding of `T`; no union selector byte is written.
+///
+/// `None` contributes zero bytes and `Some(x)` is encoded exactly as `x`. A `Some` value whose
+/// own encoding is empty (e.g. `Some(vec![])`) is therefore indistinguishable from `None`.
+/// `is_ssz_fixed_len`/`ssz_fixed_len` report the values for `T`, even though `None` is empty.
 impl<T: Encode> Encode for Option<T> {
     fn is_ssz_fixed_len() -> bool {
-        false
+        T::is_ssz_fixed_len()
+    }
+
+    fn ssz_fixed_len() -> usize {
+        T::ssz_fixed_len()
     }
+
     fn ssz_append(&self, buf: &mut Vec<u8>) {
-        match self {
-            Option::None => {
-                let union_selector: u8 = 0u8;
-                buf.push(union_selector);
-            }
-            Option::Some(ref inner) => {
-                let union_selector: u8 = 1u8;
-                buf.push(union_selector);
-                inner.ssz_append(buf);
-            }
+        if let Some(inner) = self {
+            inner.ssz_append(buf);
         }
     }
+
     fn ssz_bytes_len(&self) -> usize {
-        match self {
-            Option::None => 1usize,
-            Option::Some(ref inner) => inner
-                .ssz_bytes_len()
-                .checked_add(1)
-                .expect("encoded length must be less than usize::max_value"),
-        }
+        self.as_ref().map_or(0, T::ssz_bytes_len)
     }
 }
@@ -607,9 +604,9 @@ mod tests {
     #[test]
     fn ssz_encode_option_u8() {
         let opt: Option<u8> = None;
-        assert_eq!(opt.as_ssz_bytes(), vec![0]);
+        assert_eq!(opt.as_ssz_bytes(), Vec::<u8>::new());
         let opt: Option<u8> = Some(2);
-        assert_eq!(opt.as_ssz_bytes(), vec![1, 2]);
+        assert_eq!(opt.as_ssz_bytes(), vec![2]);
     }
 
     #[test]
diff --git a/ssz/tests/tests.rs b/ssz/tests/tests.rs
index f52d2c5..446372a 100644
--- a/ssz/tests/tests.rs
+++ b/ssz/tests/tests.rs
@@ -57,7 +57,9 @@ mod round_trip {
     fn option_vec_h256() {
         let items: Vec<Option<Vec<H256>>> = vec![
             None,
-            Some(vec![]),
+            // `Some(vec![])` encodes to zero bytes, exactly like `None`, so it cannot round-trip
+            // and is left out; this ambiguity is inherent to the selector-less `Option` encoding.
+            Some(vec![H256::zero()]),
             Some(vec![H256::zero(), H256::from([1; 32]), H256::random()]),
         ];
 
diff --git a/ssz_derive/Cargo.toml b/ssz_derive/Cargo.toml
index b89582c..c846482 100644
--- a/ssz_derive/Cargo.toml
+++ b/ssz_derive/Cargo.toml
@@ -19,6 +19,7 @@ syn = "1.0.42"
 proc-macro2 = "1.0.23"
 quote = "1.0.7"
 darling = "0.13.0"
+ssz_types = { git = "https://github.com/macladson/ssz_types", branch = "stable-container" }
 
 [dev-dependencies]
-ethereum_ssz = { path = "../ssz" }
+ethereum_ssz = { git = "https://github.com/macladson/ethereum_ssz", branch = "stable-container" }
diff --git a/ssz_derive/src/lib.rs b/ssz_derive/src/lib.rs
index d6e44c0..b50b5b0 100644
--- a/ssz_derive/src/lib.rs
+++ b/ssz_derive/src/lib.rs
@@ -169,7 +169,7 @@ use darling::{FromDeriveInput, FromMeta};
 use proc_macro::TokenStream;
 use quote::quote;
 use std::convert::TryInto;
-use syn::{parse_macro_input, DataEnum, DataStruct, DeriveInput, Ident, Index};
+use syn::{parse_macro_input, DataEnum, DataStruct, DeriveInput, Expr, Ident, Index};
 
 /// The highest possible union selector value (higher values are reserved for backwards compatible
 /// extensions).
@@ -185,6 +185,8 @@ struct StructOpts {
     enum_behaviour: Option<String>,
     #[darling(default)]
     struct_behaviour: Option<String>,
+    #[darling(default)]
+    max_fields: Option<String>,
 }
 
 /// Field-level configuration.
@@ -203,6 +205,13 @@ enum Procedure<'a> {
         data: &'a syn::DataStruct,
         behaviour: StructBehaviour,
     },
+    StableStruct {
+        data: &'a syn::DataStruct,
+        max_fields: proc_macro2::TokenStream,
+    },
+    ProfileStruct {
+        data: &'a syn::DataStruct,
+    },
     Enum {
         data: &'a syn::DataEnum,
         behaviour: EnumBehaviour,
@@ -235,12 +244,27 @@ impl<'a> Procedure<'a> {
                         data,
                         behaviour: StructBehaviour::Container,
                     },
+                    Some("stable_container") => {
+                        // `max_fields` becomes the `N` of the container's `BitVector<N>`.
+                        let max_fields_string = opts
+                            .max_fields
+                            .expect("\"stable_container\" requires \"max_fields\"");
+                        let max_fields_ty: Expr = syn::parse_str(&max_fields_string)
+                            .expect("\"max_fields\" is not a valid type.");
+
+                        Procedure::StableStruct {
+                            data,
+                            max_fields: quote! { #max_fields_ty },
+                        }
+                    }
+                    Some("profile") => Procedure::ProfileStruct { data },
                     Some("transparent") => Procedure::Struct {
                         data,
                         behaviour: StructBehaviour::Transparent,
                     },
                     Some(other) => panic!(
-                        "{} is not a valid struct behaviour, use \"container\" or \"transparent\"",
+                        "{} is not a valid struct behaviour, use \"container\", \
+                         \"stable_container\", \"profile\" or \"transparent\"",
                         other
                     ),
                 }
@@ -319,6 +343,10 @@ pub fn ssz_encode_derive(input: TokenStream) -> TokenStream {
             StructBehaviour::Transparent => ssz_encode_derive_struct_transparent(&item, data),
             StructBehaviour::Container => ssz_encode_derive_struct(&item, data),
         },
+        Procedure::StableStruct { data, max_fields } => {
+            ssz_encode_derive_stable_container(&item, data, max_fields)
+        }
+        Procedure::ProfileStruct { data } => ssz_encode_derive_profile_container(&item, data),
         Procedure::Enum { data, behaviour } => match behaviour {
             EnumBehaviour::Transparent => ssz_encode_derive_enum_transparent(&item, data),
             EnumBehaviour::Union => ssz_encode_derive_enum_union(&item, data),
@@ -442,79 +470,143 @@ fn ssz_encode_derive_struct(derive_input: &DeriveInput, struct_data: &DataStruct
     output.into()
 }
 
-/// Derive `ssz::Encode`
"transparently" for a struct which has exactly one non-skipped field. -/// -/// The single field is encoded directly, making the outermost `struct` transparent. -/// -/// ## Field attributes -/// -/// - `#[ssz(skip_serializing)]`: the field will not be serialized. -fn ssz_encode_derive_struct_transparent( +/// Derive ssz::Encode for a struct as a StableContainer as per EIP-7495. +fn ssz_encode_derive_stable_container( derive_input: &DeriveInput, struct_data: &DataStruct, + max_fields: proc_macro2::TokenStream, ) -> TokenStream { let name = &derive_input.ident; let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl(); - let ssz_fields = parse_ssz_fields(struct_data); - let num_fields = ssz_fields - .iter() - .filter(|(_, _, field_opts)| !field_opts.skip_deserializing) - .count(); - if num_fields != 1 { - panic!( - "A \"transparent\" struct must have exactly one non-skipped field ({} fields found)", - num_fields - ); - } + let field_is_ssz_fixed_len = &mut vec![]; + let field_fixed_len = &mut vec![]; + let field_ssz_bytes_len = &mut vec![]; + let field_encoder_append = &mut vec![]; - let (index, (ty, ident, _field_opts)) = ssz_fields - .iter() - .enumerate() - .find(|(_, (_, _, field_opts))| !field_opts.skip_deserializing) - .expect("\"transparent\" struct must have at least one non-skipped field"); + let mut struct_fields_vec: Vec<&Ident> = vec![]; + for (ty, ident, field_opts) in parse_ssz_fields(struct_data) { + if field_opts.skip_serializing { + continue; + } - // Remove the `_usize` suffix from the value to avoid a compiler warning. - let index = Index::from(index); + let ident = match ident { + Some(ref ident) => ident, + _ => panic!( + "#[ssz(struct_behaviour = \"stable_container\")] only supports named struct fields." + ), + }; - let output = if let Some(field_name) = ident { - quote! 
{ - impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { - fn is_ssz_fixed_len() -> bool { - <#ty as ssz::Encode>::is_ssz_fixed_len() - } + struct_fields_vec.push(ident); - fn ssz_fixed_len() -> usize { - <#ty as ssz::Encode>::ssz_fixed_len() - } + if let Some(module) = field_opts.with { + let module = quote! { #module::encode }; + field_is_ssz_fixed_len.push(quote! { #module::is_ssz_fixed_len() }); + field_fixed_len.push(quote! { #module::ssz_fixed_len() }); + field_ssz_bytes_len.push(quote! { #module::ssz_bytes_len(&self.#ident) }); + field_encoder_append.push(quote! { + encoder.append_parameterized( + #module::is_ssz_fixed_len(), + |buf| #module::ssz_append(&self.#ident, buf) + ) + }); + } else { + field_is_ssz_fixed_len.push(quote! { <#ty as ssz::Encode>::is_ssz_fixed_len() }); + field_fixed_len.push(quote! { <#ty as ssz::Encode>::ssz_fixed_len() }); + field_ssz_bytes_len.push(quote! { self.#ident.ssz_bytes_len() }); + field_encoder_append.push(quote! { encoder.append(&self.#ident) }); + } + } - fn ssz_bytes_len(&self) -> usize { - self.#field_name.ssz_bytes_len() - } + let output = quote! { + impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { + fn is_ssz_fixed_len() -> bool { + #( + #field_is_ssz_fixed_len && + )* + true + } - fn ssz_append(&self, buf: &mut Vec) { - self.#field_name.ssz_append(buf) + fn ssz_fixed_len() -> usize { + if ::is_ssz_fixed_len() { + let mut len: usize = 0; + #( + len = len + .checked_add(#field_fixed_len) + .expect("encode ssz_fixed_len length overflow"); + )* + len + } else { + ssz::BYTES_PER_LENGTH_OFFSET } } - } - } else { - quote! 
{ - impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { - fn is_ssz_fixed_len() -> bool { - <#ty as ssz::Encode>::is_ssz_fixed_len() - } - fn ssz_fixed_len() -> usize { - <#ty as ssz::Encode>::ssz_fixed_len() - } + fn ssz_bytes_len(&self) -> usize { + if ::is_ssz_fixed_len() { + ::ssz_fixed_len() + } else { + let mut len: usize = 0; + #( + if #field_is_ssz_fixed_len { + len = len + .checked_add(#field_fixed_len) + .expect("encode ssz_bytes_len length overflow"); + } else { + len = len + .checked_add(ssz::BYTES_PER_LENGTH_OFFSET) + .expect("encode ssz_bytes_len length overflow for offset"); + len = len + .checked_add(#field_ssz_bytes_len) + .expect("encode ssz_bytes_len length overflow for bytes"); + } + )* - fn ssz_bytes_len(&self) -> usize { - self.#index.ssz_bytes_len() + len } + } - fn ssz_append(&self, buf: &mut Vec) { - self.#index.ssz_append(buf) - } + fn ssz_append(&self, buf: &mut Vec) { + let mut offset: usize = 0; + #( + if self.#struct_fields_vec.is_some() { + offset = offset + .checked_add(#field_fixed_len) + .expect("encode ssz_append offset overflow"); + } + )* + + let mut encoder = ssz::SszEncoder::container(buf, offset); + + #( + #field_encoder_append; + )* + + encoder.finalize(); + } + + // Custom ssz_bytes implementation so that we prepend the BitVector. + fn as_ssz_bytes(&self) -> Vec { + let mut active_fields = BitVector::<#max_fields>::new(); + + let mut working_field: usize = 0; + #( + if self.#struct_fields_vec.is_some() { + active_fields.set(working_field, true).expect("Should not be out of bounds"); + } + working_field += 1; + )* + + let mut bitvector = active_fields.as_ssz_bytes(); + + + // We need to ensure the bitvector is not taken into account when computing + // offsets. So finalize the ssz struct before prepending. 
+ let mut buf = vec![]; + self.ssz_append(&mut buf); + + bitvector.append(&mut buf); + + bitvector } } }; @@ -522,74 +614,318 @@ fn ssz_encode_derive_struct_transparent( output.into() } -/// Derive `ssz::Encode` for an enum in the "transparent" method. -/// -/// The "transparent" method is distinct from the "union" method specified in the SSZ specification. -/// When using "transparent", the enum will be ignored and the contained field will be serialized as -/// if the enum does not exist. -/// -/// ## Limitations -/// -/// Only supports: -/// - Enums with a single field per variant, where -/// - All fields are variably sized from an SSZ-perspective (not fixed size). -/// -/// ## Panics -/// -/// Will panic at compile-time if the single field requirement isn't met, but will panic *at run -/// time* if the variable-size requirement isn't met. -fn ssz_encode_derive_enum_transparent( +/// Derive ssz::Encode for a struct as a Profile[B] as per EIP-7495. +fn ssz_encode_derive_profile_container( derive_input: &DeriveInput, - enum_data: &DataEnum, + struct_data: &DataStruct, ) -> TokenStream { let name = &derive_input.ident; let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl(); - let (patterns, assert_exprs): (Vec<_>, Vec<_>) = enum_data - .variants - .iter() - .map(|variant| { - let variant_name = &variant.ident; + let field_is_ssz_fixed_len = &mut vec![]; + let field_fixed_len = &mut vec![]; + let field_ssz_bytes_len = &mut vec![]; + let field_encoder_append = &mut vec![]; - if variant.fields.len() != 1 { - panic!("ssz::Encode can only be derived for enums with 1 field per variant"); + let mut optional_field_names: Vec<&Ident> = vec![]; + let mut optional_count: usize = 0; + + for (ty, ident, field_opts) in parse_ssz_fields(struct_data) { + if field_opts.skip_serializing { + continue; + } + + let ident = match ident { + Some(ref ident) => ident, + _ => { + panic!("#[ssz(struct_behaviour = \"profile\")] only supports named struct 
fields.") } + }; - let pattern = quote! { - #name::#variant_name(ref inner) - }; + // Check if field is an Option; + if ty_inner_type("Option", ty).is_some() { + optional_field_names.push(ident); + optional_count += 1; + } - let ty = &(&variant.fields).into_iter().next().unwrap().ty; - let type_assert = quote! { - !<#ty as ssz::Encode>::is_ssz_fixed_len() - }; - (pattern, type_assert) - }) - .unzip(); + if let Some(module) = field_opts.with { + let module = quote! { #module::encode }; + field_is_ssz_fixed_len.push(quote! { #module::is_ssz_fixed_len() }); + field_fixed_len.push(quote! { #module::ssz_fixed_len() }); + field_ssz_bytes_len.push(quote! { #module::ssz_bytes_len(&self.#ident) }); + field_encoder_append.push(quote! { + encoder.append_parameterized( + #module::is_ssz_fixed_len(), + |buf| #module::ssz_append(&self.#ident, buf) + ) + }); + } else { + field_is_ssz_fixed_len.push(quote! { <#ty as ssz::Encode>::is_ssz_fixed_len() }); + field_fixed_len.push(quote! { <#ty as ssz::Encode>::ssz_fixed_len() }); + field_ssz_bytes_len.push(quote! { self.#ident.ssz_bytes_len() }); + field_encoder_append.push(quote! { encoder.append(&self.#ident) }); + } + } + + // We can infer the typenum required for the BitVector from the number of optional fields. + let typenum = format!("typenum::U{optional_count}"); + let max_optional_fields: syn::Expr = syn::parse_str(&typenum).unwrap(); let output = quote! 
{ impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { fn is_ssz_fixed_len() -> bool { - assert!( - #( - #assert_exprs && - )* true, - "not all enum variants are variably-sized" - ); - false + #( + #field_is_ssz_fixed_len && + )* + true } - fn ssz_bytes_len(&self) -> usize { - match self { + fn ssz_fixed_len() -> usize { + if ::is_ssz_fixed_len() { + let mut len: usize = 0; #( - #patterns => inner.ssz_bytes_len(), + len = len + .checked_add(#field_fixed_len) + .expect("encode ssz_fixed_len length overflow"); )* + len + } else { + ssz::BYTES_PER_LENGTH_OFFSET } } - fn ssz_append(&self, buf: &mut Vec) { - match self { - #( + fn ssz_bytes_len(&self) -> usize { + if ::is_ssz_fixed_len() { + ::ssz_fixed_len() + } else { + let mut len: usize = 0; + #( + if #field_is_ssz_fixed_len { + len = len + .checked_add(#field_fixed_len) + .expect("encode ssz_bytes_len length overflow"); + } else { + len = len + .checked_add(ssz::BYTES_PER_LENGTH_OFFSET) + .expect("encode ssz_bytes_len length overflow for offset"); + len = len + .checked_add(#field_ssz_bytes_len) + .expect("encode ssz_bytes_len length overflow for bytes"); + } + )* + + len + } + } + + fn ssz_append(&self, buf: &mut Vec) { + let mut offset: usize = 0; + + #( + offset = offset + .checked_add(#field_fixed_len) + .expect("encode ssz_append offset overflow"); + )* + + let mut encoder = ssz::SszEncoder::container(buf, offset); + + #( + #field_encoder_append; + )* + + encoder.finalize(); + } + + // Custom ssz_bytes implementation so that we prepend the BitVector. + fn as_ssz_bytes(&self) -> Vec { + if #optional_count == 0 { + let mut buf = vec![]; + self.ssz_append(&mut buf); + return buf + } + + // Construct the BitVector. This should only contain the bits of Optional values. A + // `true` value indicates the Optional is `Some`. A `false ` value indicates the + // Optional is `None`. 
+ let mut optional_fields = ssz_types::BitVector::<#max_optional_fields>::new(); + + // Iterate through the list of optional fields and check if they are Some. + // If it is, set the appropriate bit in the bitvector to true. + // Otherwise it is None and therefore stays false. + // This assumes the field names in `optional_field_names` are in order. + let mut working_index: usize = 0; + + #( + if self.#optional_field_names.is_some() { + optional_fields.set(working_index, true).expect("Should not be out of bounds"); + } + working_index += 1; + )* + + let mut bitvector = optional_fields.as_ssz_bytes(); + + // We need to ensure the bitvector is not taken into account when computing + // offsets. So finalize the ssz struct before prepending. + let mut buf = vec![]; + self.ssz_append(&mut buf); + + bitvector.append(&mut buf); + bitvector + } + } + }; + + output.into() +} + +/// Derive `ssz::Encode` "transparently" for a struct which has exactly one non-skipped field. +/// +/// The single field is encoded directly, making the outermost `struct` transparent. +/// +/// ## Field attributes +/// +/// - `#[ssz(skip_serializing)]`: the field will not be serialized. 
+fn ssz_encode_derive_struct_transparent( + derive_input: &DeriveInput, + struct_data: &DataStruct, +) -> TokenStream { + let name = &derive_input.ident; + let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl(); + let ssz_fields = parse_ssz_fields(struct_data); + let num_fields = ssz_fields + .iter() + .filter(|(_, _, field_opts)| !field_opts.skip_deserializing) + .count(); + + if num_fields != 1 { + panic!( + "A \"transparent\" struct must have exactly one non-skipped field ({} fields found)", + num_fields + ); + } + + let (index, (ty, ident, _field_opts)) = ssz_fields + .iter() + .enumerate() + .find(|(_, (_, _, field_opts))| !field_opts.skip_deserializing) + .expect("\"transparent\" struct must have at least one non-skipped field"); + + // Remove the `_usize` suffix from the value to avoid a compiler warning. + let index = Index::from(index); + + let output = if let Some(field_name) = ident { + quote! { + impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { + fn is_ssz_fixed_len() -> bool { + <#ty as ssz::Encode>::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + <#ty as ssz::Encode>::ssz_fixed_len() + } + + fn ssz_bytes_len(&self) -> usize { + self.#field_name.ssz_bytes_len() + } + + fn ssz_append(&self, buf: &mut Vec) { + self.#field_name.ssz_append(buf) + } + } + } + } else { + quote! { + impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { + fn is_ssz_fixed_len() -> bool { + <#ty as ssz::Encode>::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + <#ty as ssz::Encode>::ssz_fixed_len() + } + + fn ssz_bytes_len(&self) -> usize { + self.#index.ssz_bytes_len() + } + + fn ssz_append(&self, buf: &mut Vec) { + self.#index.ssz_append(buf) + } + } + } + }; + + output.into() +} + +/// Derive `ssz::Encode` for an enum in the "transparent" method. +/// +/// The "transparent" method is distinct from the "union" method specified in the SSZ specification. 
+/// When using "transparent", the enum will be ignored and the contained field will be serialized as +/// if the enum does not exist. +/// +/// ## Limitations +/// +/// Only supports: +/// - Enums with a single field per variant, where +/// - All fields are variably sized from an SSZ-perspective (not fixed size). +/// +/// ## Panics +/// +/// Will panic at compile-time if the single field requirement isn't met, but will panic *at run +/// time* if the variable-size requirement isn't met. +fn ssz_encode_derive_enum_transparent( + derive_input: &DeriveInput, + enum_data: &DataEnum, +) -> TokenStream { + let name = &derive_input.ident; + let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl(); + + let (patterns, assert_exprs): (Vec<_>, Vec<_>) = enum_data + .variants + .iter() + .map(|variant| { + let variant_name = &variant.ident; + + if variant.fields.len() != 1 { + panic!("ssz::Encode can only be derived for enums with 1 field per variant"); + } + + let pattern = quote! { + #name::#variant_name(ref inner) + }; + + let ty = &(&variant.fields).into_iter().next().unwrap().ty; + let type_assert = quote! { + !<#ty as ssz::Encode>::is_ssz_fixed_len() + }; + (pattern, type_assert) + }) + .unzip(); + + let output = quote! { + impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { + fn is_ssz_fixed_len() -> bool { + assert!( + #( + #assert_exprs && + )* true, + "not all enum variants are variably-sized" + ); + false + } + + fn ssz_bytes_len(&self) -> usize { + match self { + #( + #patterns => inner.ssz_bytes_len(), + )* + } + } + + fn ssz_append(&self, buf: &mut Vec) { + match self { + #( #patterns => inner.ssz_append(buf), )* } @@ -683,68 +1019,318 @@ fn ssz_encode_derive_enum_union(derive_input: &DeriveInput, enum_data: &DataEnum panic!("ssz::Encode can only be derived for enums with 1 field per variant"); } - let pattern = quote! 
{ - #name::#variant_name(ref inner) - }; - pattern - }) - .collect(); - - let union_selectors = compute_union_selectors(patterns.len()); + let pattern = quote! { + #name::#variant_name(ref inner) + }; + pattern + }) + .collect(); + + let union_selectors = compute_union_selectors(patterns.len()); + + let output = quote! { + impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { + fn is_ssz_fixed_len() -> bool { + false + } + + fn ssz_bytes_len(&self) -> usize { + match self { + #( + #patterns => inner + .ssz_bytes_len() + .checked_add(1) + .expect("encoded length must be less than usize::max_value"), + )* + } + } + + fn ssz_append(&self, buf: &mut Vec) { + match self { + #( + #patterns => { + let union_selector: u8 = #union_selectors; + debug_assert!(union_selector <= ssz::MAX_UNION_SELECTOR); + buf.push(union_selector); + inner.ssz_append(buf) + }, + )* + } + } + } + }; + output.into() +} + +/// Derive `ssz::Decode` for a struct or enum. +#[proc_macro_derive(Decode, attributes(ssz))] +pub fn ssz_decode_derive(input: TokenStream) -> TokenStream { + let item = parse_macro_input!(input as DeriveInput); + let procedure = Procedure::read(&item); + + match procedure { + Procedure::Struct { data, behaviour } => match behaviour { + StructBehaviour::Transparent => ssz_decode_derive_struct_transparent(&item, data), + StructBehaviour::Container => ssz_decode_derive_struct(&item, data), + }, + Procedure::StableStruct { data, max_fields} => ssz_decode_derive_stable_container(&item, data, max_fields), + Procedure::ProfileStruct { data } => ssz_decode_derive_profile_container(&item, data), + Procedure::Enum { data, behaviour } => match behaviour { + EnumBehaviour::Union => ssz_decode_derive_enum_union(&item, data), + EnumBehaviour::Tag => ssz_decode_derive_enum_tag(&item, data), + EnumBehaviour::Transparent => ssz_decode_derive_enum_transparent(&item, data), + }, + } +} + +/// Implements `ssz::Decode` for some `struct`. 
+/// +/// Fields are decoded in the order they are defined. +/// +/// ## Field attributes +/// +/// - `#[ssz(skip_deserializing)]`: during de-serialization the field will be instantiated from a +/// `Default` implementation. The decoder will assume that the field was not serialized at all +/// (e.g., if it has been serialized, an error will be raised instead of `Default` overriding it). +fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> TokenStream { + let name = &item.ident; + let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); + + let mut register_types = vec![]; + let mut field_names = vec![]; + let mut fixed_decodes = vec![]; + let mut decodes = vec![]; + let mut is_fixed_lens = vec![]; + let mut fixed_lens = vec![]; + + for (ty, ident, field_opts) in parse_ssz_fields(struct_data) { + let ident = match ident { + Some(ref ident) => ident, + _ => panic!( + "#[ssz(struct_behaviour = \"container\")] only supports named struct fields." + ), + }; + + field_names.push(quote! { + #ident + }); + + // Field should not be deserialized; use a `Default` impl to instantiate. + if field_opts.skip_deserializing { + decodes.push(quote! { + let #ident = <_>::default(); + }); + + fixed_decodes.push(quote! { + let #ident = <_>::default(); + }); + + continue; + } + + let is_ssz_fixed_len; + let ssz_fixed_len; + let from_ssz_bytes; + if let Some(module) = field_opts.with { + let module = quote! { #module::decode }; + + is_ssz_fixed_len = quote! { #module::is_ssz_fixed_len() }; + ssz_fixed_len = quote! { #module::ssz_fixed_len() }; + from_ssz_bytes = quote! { #module::from_ssz_bytes(slice) }; + + register_types.push(quote! { + builder.register_type_parameterized(#is_ssz_fixed_len, #ssz_fixed_len)?; + }); + decodes.push(quote! { + let #ident = decoder.decode_next_with(|slice| #module::from_ssz_bytes(slice))?; + }); + } else { + is_ssz_fixed_len = quote! { <#ty as ssz::Decode>::is_ssz_fixed_len() }; + ssz_fixed_len = quote! 
{ <#ty as ssz::Decode>::ssz_fixed_len() }; + from_ssz_bytes = quote! { <#ty as ssz::Decode>::from_ssz_bytes(slice) }; + + register_types.push(quote! { + builder.register_type::<#ty>()?; + }); + decodes.push(quote! { + let #ident = decoder.decode_next()?; + }); + } + + fixed_decodes.push(quote! { + let #ident = { + start = end; + end = end + .checked_add(#ssz_fixed_len) + .ok_or_else(|| ssz::DecodeError::OutOfBoundsByte { + i: usize::max_value() + })?; + let slice = bytes.get(start..end) + .ok_or_else(|| ssz::DecodeError::InvalidByteLength { + len: bytes.len(), + expected: end + })?; + #from_ssz_bytes? + }; + }); + is_fixed_lens.push(is_ssz_fixed_len); + fixed_lens.push(ssz_fixed_len); + } + + let output = quote! { + impl #impl_generics ssz::Decode for #name #ty_generics #where_clause { + fn is_ssz_fixed_len() -> bool { + #( + #is_fixed_lens && + )* + true + } + + fn ssz_fixed_len() -> usize { + if ::is_ssz_fixed_len() { + let mut len: usize = 0; + #( + len = len + .checked_add(#fixed_lens) + .expect("decode ssz_fixed_len overflow"); + )* + len + } else { + ssz::BYTES_PER_LENGTH_OFFSET + } + } + + fn from_ssz_bytes(bytes: &[u8]) -> std::result::Result { + if ::is_ssz_fixed_len() { + if bytes.len() != ::ssz_fixed_len() { + return Err(ssz::DecodeError::InvalidByteLength { + len: bytes.len(), + expected: ::ssz_fixed_len(), + }); + } + + let mut start: usize = 0; + let mut end = start; + + #( + #fixed_decodes + )* + + Ok(Self { + #( + #field_names, + )* + }) + } else { + let mut builder = ssz::SszDecoderBuilder::new(bytes); + + #( + #register_types + )* + + let mut decoder = builder.build()?; + + #( + #decodes + )* + + + Ok(Self { + #( + #field_names, + )* + }) + } + } + } + }; + output.into() +} + +/// Implements `ssz::Decode` "transparently" for a `struct` with exactly one non-skipped field. +/// +/// The bytes will be decoded as if they are the inner field, without the outermost struct. The +/// outermost struct will then be applied artificially. 
+/// +/// ## Field attributes +/// +/// - `#[ssz(skip_deserializing)]`: during de-serialization the field will be instantiated from a +/// `Default` implementation. The decoder will assume that the field was not serialized at all +/// (e.g., if it has been serialized, an error will be raised instead of `Default` overriding it). +fn ssz_decode_derive_struct_transparent( + item: &DeriveInput, + struct_data: &DataStruct, +) -> TokenStream { + let name = &item.ident; + let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); + let ssz_fields = parse_ssz_fields(struct_data); + let num_fields = ssz_fields + .iter() + .filter(|(_, _, field_opts)| !field_opts.skip_deserializing) + .count(); + + if num_fields != 1 { + panic!( + "A \"transparent\" struct must have exactly one non-skipped field ({} fields found)", + num_fields + ); + } + + let mut fields = vec![]; + let mut wrapped_type = None; + + for (i, (ty, ident, field_opts)) in ssz_fields.into_iter().enumerate() { + if let Some(name) = ident { + if field_opts.skip_deserializing { + fields.push(quote! { + #name: <_>::default(), + }); + } else { + fields.push(quote! { + #name: <_>::from_ssz_bytes(bytes)?, + }); + wrapped_type = Some(ty); + } + } else { + let index = syn::Index::from(i); + if field_opts.skip_deserializing { + fields.push(quote! { + #index:<_>::default(), + }); + } else { + fields.push(quote! { + #index:<_>::from_ssz_bytes(bytes)?, + }); + wrapped_type = Some(ty); + } + } + } + + let ty = wrapped_type.unwrap(); let output = quote! 
{ - impl #impl_generics ssz::Encode for #name #ty_generics #where_clause { + impl #impl_generics ssz::Decode for #name #ty_generics #where_clause { fn is_ssz_fixed_len() -> bool { - false + <#ty as ssz::Decode>::is_ssz_fixed_len() } - fn ssz_bytes_len(&self) -> usize { - match self { - #( - #patterns => inner - .ssz_bytes_len() - .checked_add(1) - .expect("encoded length must be less than usize::max_value"), - )* - } + fn ssz_fixed_len() -> usize { + <#ty as ssz::Decode>::ssz_fixed_len() } - fn ssz_append(&self, buf: &mut Vec) { - match self { + fn from_ssz_bytes(bytes: &[u8]) -> std::result::Result { + Ok(Self { #( - #patterns => { - let union_selector: u8 = #union_selectors; - debug_assert!(union_selector <= ssz::MAX_UNION_SELECTOR); - buf.push(union_selector); - inner.ssz_append(buf) - }, + #fields )* - } + + }) } } }; output.into() } -/// Derive `ssz::Decode` for a struct or enum. -#[proc_macro_derive(Decode, attributes(ssz))] -pub fn ssz_decode_derive(input: TokenStream) -> TokenStream { - let item = parse_macro_input!(input as DeriveInput); - let procedure = Procedure::read(&item); - - match procedure { - Procedure::Struct { data, behaviour } => match behaviour { - StructBehaviour::Transparent => ssz_decode_derive_struct_transparent(&item, data), - StructBehaviour::Container => ssz_decode_derive_struct(&item, data), - }, - Procedure::Enum { data, behaviour } => match behaviour { - EnumBehaviour::Union => ssz_decode_derive_enum_union(&item, data), - EnumBehaviour::Tag => ssz_decode_derive_enum_tag(&item, data), - EnumBehaviour::Transparent => ssz_decode_derive_enum_transparent(&item, data), - }, - } -} - /// Implements `ssz::Decode` for some `struct`. /// /// Fields are decoded in the order they are defined. @@ -754,7 +1340,11 @@ pub fn ssz_decode_derive(input: TokenStream) -> TokenStream { /// - `#[ssz(skip_deserializing)]`: during de-serialization the field will be instantiated from a /// `Default` implementation. 
The decoder will assume that the field was not serialized at all /// (e.g., if it has been serialized, an error will be raised instead of `Default` overriding it). -fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> TokenStream { +fn ssz_decode_derive_stable_container( + item: &DeriveInput, + struct_data: &DataStruct, + max_fields: proc_macro2::TokenStream, +) -> TokenStream { let name = &item.ident; let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); @@ -765,11 +1355,13 @@ fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> Tok let mut is_fixed_lens = vec![]; let mut fixed_lens = vec![]; + let mut working_index: usize = 0; + for (ty, ident, field_opts) in parse_ssz_fields(struct_data) { let ident = match ident { Some(ref ident) => ident, _ => panic!( - "#[ssz(struct_behaviour = \"container\")] only supports named struct fields." + "#[ssz(struct_behaviour = \"stable_container\")] only supports named struct fields." ), }; @@ -812,15 +1404,21 @@ fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> Tok from_ssz_bytes = quote! { <#ty as ssz::Decode>::from_ssz_bytes(slice) }; register_types.push(quote! { - builder.register_type::<#ty>()?; + if bitvector.get(#working_index).unwrap() { + builder.register_type::<#ty>()?; + } }); decodes.push(quote! { - let #ident = decoder.decode_next()?; + let #ident = if bitvector.get(#working_index).unwrap() { + decoder.decode_next()? + } else { + None + }; }); } fixed_decodes.push(quote! { - let #ident = { + let #ident = if bitvector.get(#working_index).unwrap() { start = end; end = end .checked_add(#ssz_fixed_len) @@ -833,10 +1431,14 @@ fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> Tok expected: end })?; #from_ssz_bytes? + } else { + None }; }); is_fixed_lens.push(is_ssz_fixed_len); fixed_lens.push(ssz_fixed_len); + + working_index += 1; } let output = quote! 
{ @@ -863,13 +1465,13 @@ fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> Tok } fn from_ssz_bytes(bytes: &[u8]) -> std::result::Result { + // Decode the leading BitVector first. + let bitvector_length: usize = (#max_fields::to_usize() + 7) / 8; + let bitvector = BitVector::<#max_fields>::from_ssz_bytes(&bytes[0..bitvector_length]).unwrap(); + + let bytes = &bytes[bitvector_length..]; + if ::is_ssz_fixed_len() { - if bytes.len() != ::ssz_fixed_len() { - return Err(ssz::DecodeError::InvalidByteLength { - len: bytes.len(), - expected: ::ssz_fixed_len(), - }); - } let mut start: usize = 0; let mut end = start; @@ -896,7 +1498,6 @@ fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> Tok #decodes )* - Ok(Self { #( #field_names, @@ -909,84 +1510,218 @@ fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> Tok output.into() } -/// Implements `ssz::Decode` "transparently" for a `struct` with exactly one non-skipped field. -/// -/// The bytes will be decoded as if they are the inner field, without the outermost struct. The -/// outermost struct will then be applied artificially. -/// -/// ## Field attributes -/// -/// - `#[ssz(skip_deserializing)]`: during de-serialization the field will be instantiated from a -/// `Default` implementation. The decoder will assume that the field was not serialized at all -/// (e.g., if it has been serialized, an error will be raised instead of `Default` overriding it). 
-fn ssz_decode_derive_struct_transparent( +fn ssz_decode_derive_profile_container( item: &DeriveInput, struct_data: &DataStruct, ) -> TokenStream { let name = &item.ident; let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); - let ssz_fields = parse_ssz_fields(struct_data); - let num_fields = ssz_fields - .iter() - .filter(|(_, _, field_opts)| !field_opts.skip_deserializing) - .count(); - - if num_fields != 1 { - panic!( - "A \"transparent\" struct must have exactly one non-skipped field ({} fields found)", - num_fields - ); - } - let mut fields = vec![]; - let mut wrapped_type = None; + let mut register_types = vec![]; + let mut field_names = vec![]; + let mut fixed_decodes = vec![]; + let mut decodes = vec![]; + let mut is_fixed_lens = vec![]; + let mut fixed_lens = vec![]; + let mut optional_field_names: Vec<&Ident> = vec![]; + // Since we use a truncated bitvector, we need to keep track of which optional field we are up + // to. + let mut working_optional_index: usize = 0; - for (i, (ty, ident, field_opts)) in ssz_fields.into_iter().enumerate() { - if let Some(name) = ident { - if field_opts.skip_deserializing { - fields.push(quote! { - #name: <_>::default(), - }); - } else { - fields.push(quote! { - #name: <_>::from_ssz_bytes(bytes)?, - }); - wrapped_type = Some(ty); + for (ty, ident, field_opts) in parse_ssz_fields(struct_data) { + let mut is_optional = false; + let ident = match ident { + Some(ref ident) => ident, + _ => { + panic!("#[ssz(struct_behaviour = \"profile\")] only supports named struct fields.") } + }; + + field_names.push(quote! { + #ident + }); + + // Field should not be deserialized; use a `Default` impl to instantiate. + if field_opts.skip_deserializing { + decodes.push(quote! { + let #ident = <_>::default(); + }); + + fixed_decodes.push(quote! { + let #ident = <_>::default(); + }); + + continue; + } + + // Check if field is optional. 
+ if ty_inner_type("Option", ty).is_some() { + optional_field_names.push(ident); + is_optional = true; + } + + let is_ssz_fixed_len; + let ssz_fixed_len; + let from_ssz_bytes; + if let Some(module) = field_opts.with { + let module = quote! { #module::decode }; + + is_ssz_fixed_len = quote! { #module::is_ssz_fixed_len() }; + ssz_fixed_len = quote! { #module::ssz_fixed_len() }; + from_ssz_bytes = quote! { #module::from_ssz_bytes(slice) }; + + register_types.push(quote! { + builder.register_type_parameterized(#is_ssz_fixed_len, #ssz_fixed_len)?; + }); + decodes.push(quote! { + let #ident = decoder.decode_next_with(|slice| #module::from_ssz_bytes(slice))?; + }); } else { - let index = syn::Index::from(i); - if field_opts.skip_deserializing { - fields.push(quote! { - #index:<_>::default(), + is_ssz_fixed_len = quote! { <#ty as ssz::Decode>::is_ssz_fixed_len() }; + ssz_fixed_len = quote! { <#ty as ssz::Decode>::ssz_fixed_len() }; + from_ssz_bytes = quote! { <#ty as ssz::Decode>::from_ssz_bytes(slice) }; + + register_types.push(quote! { + builder.register_type::<#ty>()?; + }); + if is_optional { + decodes.push(quote! { + let #ident = if bitvector.get(#working_optional_index).unwrap() { + decoder.decode_next()? + } else { + <_>::default() + }; }); } else { - fields.push(quote! { - #index:<_>::from_ssz_bytes(bytes)?, + decodes.push(quote! { + let #ident = decoder.decode_next()?; }); - wrapped_type = Some(ty); } } + + // If the field is optional, we need to check the bitvector before decoding. + if is_optional { + fixed_decodes.push(quote! { + let #ident = { + if bitvector.get(#working_optional_index).unwrap() { + start = end; + end = end + .checked_add(#ssz_fixed_len) + .ok_or_else(|| ssz::DecodeError::OutOfBoundsByte { + i: usize::max_value() + })?; + let slice = bytes.get(start..end) + .ok_or_else(|| ssz::DecodeError::InvalidByteLength { + len: bytes.len(), + expected: end + })?; + #from_ssz_bytes? + } else { + // Value is None so just decode an Option::default(). 
+ <_>::default() + } + }; + }); + is_fixed_lens.push(is_ssz_fixed_len); + fixed_lens.push(ssz_fixed_len); + } else { + fixed_decodes.push(quote! { + let #ident = { + start = end; + end = end + .checked_add(#ssz_fixed_len) + .ok_or_else(|| ssz::DecodeError::OutOfBoundsByte { + i: usize::max_value() + })?; + let slice = bytes.get(start..end) + .ok_or_else(|| ssz::DecodeError::InvalidByteLength { + len: bytes.len(), + expected: end + })?; + #from_ssz_bytes? + }; + }); + is_fixed_lens.push(is_ssz_fixed_len); + fixed_lens.push(ssz_fixed_len); + } + + // Increment the working index so we check the next field of the bitvector. + if is_optional { + working_optional_index += 1 + }; } - let ty = wrapped_type.unwrap(); + // We can infer the typenum required for the BitVector from the number of optional fields. + let typenum = format!("typenum::U{working_optional_index}"); + let max_optional_fields: Expr = syn::parse_str(&typenum).unwrap(); let output = quote! { impl #impl_generics ssz::Decode for #name #ty_generics #where_clause { fn is_ssz_fixed_len() -> bool { - <#ty as ssz::Decode>::is_ssz_fixed_len() + #( + #is_fixed_lens && + )* + true } fn ssz_fixed_len() -> usize { - <#ty as ssz::Decode>::ssz_fixed_len() + if ::is_ssz_fixed_len() { + let mut len: usize = 0; + #( + len = len + .checked_add(#fixed_lens) + .expect("decode ssz_fixed_len overflow"); + )* + len + } else { + ssz::BYTES_PER_LENGTH_OFFSET + } } fn from_ssz_bytes(bytes: &[u8]) -> std::result::Result { - Ok(Self { + // Decode the leading BitVector first. 
+ let bitvector_length: usize = (#max_optional_fields::to_usize() + 7) / 8; + let bitvector = if bitvector_length == 0 { + ssz_types::BitVector::<#max_optional_fields>::new() + } else { + ssz_types::BitVector::<#max_optional_fields>::from_ssz_bytes(&bytes[0..bitvector_length]).unwrap() + }; + + let bytes = &bytes[bitvector_length..]; + + if ::is_ssz_fixed_len() { + + let mut start: usize = 0; + let mut end = start; + #( - #fields + #fixed_decodes )* - }) + Ok(Self { + #( + #field_names, + )* + }) + } else { + let mut builder = ssz::SszDecoderBuilder::new(bytes); + + #( + #register_types + )* + + let mut decoder = builder.build()?; + + #( + #decodes + )* + + Ok(Self { + #( + #field_names, + )* + }) + } } } }; @@ -1180,3 +1915,23 @@ fn compute_union_selectors(num_variants: usize) -> Vec { union_selectors } + +fn ty_inner_type<'a>(wrapper: &str, ty: &'a syn::Type) -> Option<&'a syn::Type> { + if let syn::Type::Path(ref p) = ty { + if p.path.segments.len() != 1 || p.path.segments[0].ident != wrapper { + return None; + } + + if let syn::PathArguments::AngleBracketed(ref inner_ty) = p.path.segments[0].arguments { + if inner_ty.args.len() != 1 { + return None; + } + + let inner_ty = inner_ty.args.first().unwrap(); + if let syn::GenericArgument::Type(ref t) = inner_ty { + return Some(t); + } + } + } + None +} diff --git a/ssz_derive/tests/tests.rs b/ssz_derive/tests/tests.rs index 8db0381..6f7eb0b 100644 --- a/ssz_derive/tests/tests.rs +++ b/ssz_derive/tests/tests.rs @@ -1,5 +1,9 @@ use ssz::{Decode, DecodeError, Encode}; use ssz_derive::{Decode, Encode}; +use ssz_types::{ + typenum::{self, Unsigned}, + BitVector, +}; use std::fmt::Debug; use std::marker::PhantomData; @@ -257,3 +261,103 @@ fn transparent_struct_newtype_skipped_field_reverse() { &vec![42_u8].as_ssz_bytes(), ); } + +// Shape tests from EIP. 
+#[derive(PartialEq, Debug, Encode, Decode)] +#[ssz(struct_behaviour = "stable_container")] +#[ssz(max_fields = "typenum::U8")] +struct Shape { + side: Option, + color: Option, + radius: Option, +} + +#[derive(PartialEq, Debug, Encode, Decode)] +#[ssz(struct_behaviour = "profile")] +struct Square { + side: u16, + #[ssz(skip_serializing, skip_deserializing)] + skip: Vec, + color: u8, +} + +#[derive(PartialEq, Debug, Encode, Decode)] +#[ssz(struct_behaviour = "profile")] +struct Circle { + color: u8, + radius: u16, +} + +#[derive(PartialEq, Debug, Encode, Decode)] +#[ssz(struct_behaviour = "stable_container")] +#[ssz(max_fields = "typenum::U8")] +struct ShapeVec { + side: Option, + color: Option, + #[ssz(skip_serializing, skip_deserializing)] + skip: Vec, + radius: Option>, +} + +#[test] +/// Shape(side=0x42, color=1, radius=None) +/// 03420001 +fn shape_1() { + let shape = Shape { + side: Some(42), + color: Some(1), + radius: None, + }; + + assert_encode_decode(&shape, &vec![3, 42, 0, 1]); +} + +#[test] +///Shape(side=None, color=1, radius=0x42) +/// 06014200 +fn shape_2() { + let shape = Shape { + side: None, + color: Some(1), + radius: Some(42), + }; + + assert_encode_decode(&shape, &vec![6, 1, 42, 0]); +} + +#[test] +/// Square(side=0x42, color=1) +/// 420001 +fn square() { + let square = Square { + side: 42, + skip: vec![], + color: 1, + }; + + assert_encode_decode(&square, &vec![42, 0, 1]); +} + +#[test] +//Circle(radius=0x42, color=1) +//014200 +fn circle() { + let circle = Circle { + radius: 42, + color: 1, + }; + + assert_encode_decode(&circle, &vec![1, 42, 0]) +} + +#[test] +fn shape_3() { + let shape = ShapeVec { + side: None, + color: Some(1), + skip: vec![], + radius: Some(vec![1, 2, 3, 4].into()), + }; + + assert_encode_decode(&shape, &vec![6, 1, 5, 0, 0, 0, 1, 2, 3, 4]); +}