From 7b0a8fc353b34329d4afab8fe1e493620de866ae Mon Sep 17 00:00:00 2001 From: Dwayne Sykes Date: Wed, 8 Oct 2025 20:09:09 -0500 Subject: [PATCH 1/3] feat: custom header support Add true support for custom headers with full backwards compatibility. --- examples/custom_header.rs | 26 ++++++++++---- src/decoding.rs | 73 +++++++++++++++++++++++---------------- src/encoding.rs | 14 +++++--- src/header.rs | 43 +++++++++++++++++++---- src/jws.rs | 7 ++-- src/lib.rs | 6 ++-- tests/header/mod.rs | 33 +++++++++++++++++- 7 files changed, 151 insertions(+), 51 deletions(-) diff --git a/examples/custom_header.rs b/examples/custom_header.rs index 7c343d8f..35359235 100644 --- a/examples/custom_header.rs +++ b/examples/custom_header.rs @@ -2,7 +2,9 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use jsonwebtoken::errors::ErrorKind; -use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode}; +use jsonwebtoken::{ + Algorithm, DecodingKey, EncodingKey, Validation, decode_with_custom_header, encode, header, +}; #[derive(Debug, Serialize, Deserialize, Clone)] struct Claims { @@ -11,6 +13,19 @@ struct Claims { exp: u64, } +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +struct CustomHeader { + alg: Algorithm, + custom: String, + another_custom_field: Option, +} +impl header::FromEncoded for CustomHeader {} +impl header::Alg for CustomHeader { + fn alg(&self) -> &Algorithm { + &self.alg + } +} + fn main() { let my_claims = Claims { sub: "b@b.com".to_owned(), company: "ACME".to_owned(), exp: 10000000000 }; @@ -19,11 +34,10 @@ fn main() { let mut extras = HashMap::with_capacity(1); extras.insert("custom".to_string(), "header".to_string()); - let header = Header { - kid: Some("signing_key".to_owned()), + let header = CustomHeader { alg: Algorithm::HS512, - extras, - ..Default::default() + custom: "custom".into(), + another_custom_field: 42.into(), }; let token = match encode(&header, &my_claims, &EncodingKey::from_secret(key)) { @@ -32,7 +46,7 @@ fn main() { }; println!("{:?}", token); - let token_data = match decode::( + let token_data = match decode_with_custom_header::( &token, &DecodingKey::from_secret(key), &Validation::new(Algorithm::HS512), diff --git a/src/decoding.rs b/src/decoding.rs index 6d1fc42f..c566a10e 100644 --- a/src/decoding.rs +++ b/src/decoding.rs @@ -7,7 +7,7 @@ use crate::Algorithm; use crate::algorithms::AlgorithmFamily; use crate::crypto::JwtVerifier; use crate::errors::{ErrorKind, Result, new_error}; -use crate::header::Header; +use crate::header::{Alg, FromEncoded, Header}; use crate::jwk::{AlgorithmParameters, Jwk}; #[cfg(feature = "use_pem")] use crate::pem::decoder::PemEncodedKey; @@ -37,15 +37,16 @@ use crate::crypto::rust_crypto::{ /// The return type of a successful call to [decode](fn.decode.html). #[derive(Debug)] -pub struct TokenData { +pub struct TokenData { /// The decoded JWT header - pub header: Header, + pub header: H, /// The decoded JWT claims pub claims: T, } -impl Clone for TokenData +impl Clone for TokenData where + H: Clone, T: Clone, { fn clone(&self) -> Self { @@ -281,21 +282,40 @@ pub fn decode( token: impl AsRef<[u8]>, key: &DecodingKey, validation: &Validation, -) -> Result> { +) -> Result> { + decode_with_custom_header(token, key, validation) +} + +/// Decode and validate a JWT with a custom header +/// +/// If the token or its signature is invalid, or the claims fail validation, this will return an +/// error. +pub fn decode_with_custom_header( + token: impl AsRef<[u8]>, + key: &DecodingKey, + validation: &Validation, +) -> Result> +where + H: DeserializeOwned + Clone + FromEncoded + Alg, + T: DeserializeOwned + Clone, +{ let token = token.as_ref(); - let header = decode_header(token)?; - if validation.validate_signature && !validation.algorithms.contains(&header.alg) { + let (signature, message) = expect_two!(token.rsplitn(2, |b| *b == b'.')); + let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.')); + let header = H::from_encoded(header)?; + + if validation.validate_signature && !validation.algorithms.contains(header.alg()) { return Err(new_error(ErrorKind::InvalidAlgorithm)); } - let verifying_provider = jwt_verifier_factory(&header.alg, key)?; + let verifying_provider = jwt_verifier_factory(header.alg(), key)?; + verify_signature_body(message, signature, &header, validation, verifying_provider)?; - let (header, claims) = verify_signature(token, validation, verifying_provider)?; + let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(payload)?; + validate(decoded_claims.deserialize()?, validation)?; - let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(claims)?; let claims = decoded_claims.deserialize()?; - validate(decoded_claims.deserialize()?, validation)?; Ok(TokenData { header, claims }) } @@ -357,10 +377,21 @@ pub fn decode_header(token: impl AsRef<[u8]>) -> Result
{ Header::from_encoded(header) } +/// Decode only the custom header of a JWT without decoding or validating the payload +pub fn decode_custom_header(token: impl AsRef<[u8]>) -> Result +where + H: DeserializeOwned + Clone + Alg + FromEncoded, +{ + let token = token.as_ref(); + let (_, message) = expect_two!(token.rsplitn(2, |b| *b == b'.')); + let (_, header) = expect_two!(message.rsplitn(2, |b| *b == b'.')); + H::from_encoded(header) +} + pub(crate) fn verify_signature_body( message: &[u8], signature: &[u8], - header: &Header, + header: &impl Alg, validation: &Validation, verifying_provider: Box, ) -> Result<()> { @@ -376,7 +407,7 @@ pub(crate) fn verify_signature_body( } } - if validation.validate_signature && !validation.algorithms.contains(&header.alg) { + if validation.validate_signature && !validation.algorithms.contains(header.alg()) { return Err(new_error(ErrorKind::InvalidAlgorithm)); } @@ -388,19 +419,3 @@ pub(crate) fn verify_signature_body( Ok(()) } - -/// Verify the signature of a JWT, and return a header object and raw payload. -/// -/// If the token or its signature is invalid, it will return an error. -fn verify_signature<'a>( - token: &'a [u8], - validation: &Validation, - verifying_provider: Box, -) -> Result<(Header, &'a [u8])> { - let (signature, message) = expect_two!(token.rsplitn(2, |b| *b == b'.')); - let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.')); - let header = Header::from_encoded(header)?; - verify_signature_body(message, signature, &header, validation, verifying_provider)?; - - Ok((header, payload)) -} diff --git a/src/encoding.rs b/src/encoding.rs index 30a31953..0fc98b56 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -10,7 +10,7 @@ use crate::Algorithm; use crate::algorithms::AlgorithmFamily; use crate::crypto::JwtSigner; use crate::errors::{ErrorKind, Result, new_error}; -use crate::header::Header; +use crate::header::Alg; #[cfg(feature = "use_pem")] use crate::pem::decoder::PemEncodedKey; use crate::serialization::{b64_encode, b64_encode_part}; @@ -171,14 +171,18 @@ impl Debug for EncodingKey { /// // This will create a JWT using HS256 as algorithm /// let token = encode(&Header::default(), &my_claims, &EncodingKey::from_secret("secret".as_ref())).unwrap(); /// ``` -pub fn encode(header: &Header, claims: &T, key: &EncodingKey) -> Result { - if key.family != header.alg.family() { +pub fn encode( + header: &H, + claims: &T, + key: &EncodingKey, +) -> Result { + if key.family != header.alg().family() { return Err(new_error(ErrorKind::InvalidAlgorithm)); } - let signing_provider = jwt_signer_factory(&header.alg, key)?; + let signing_provider = jwt_signer_factory(header.alg(), key)?; - if signing_provider.algorithm() != header.alg { + if signing_provider.algorithm() != *header.alg() { return Err(new_error(ErrorKind::InvalidAlgorithm)); } diff --git a/src/header.rs b/src/header.rs index 76b744a2..dfb288c8 100644 --- a/src/header.rs +++ b/src/header.rs @@ -1,3 +1,4 @@ +//! Traits and datastructures for JWT Headers use std::collections::HashMap; use std::result; @@ -25,12 +26,19 @@ const ENC_A256GCM: &str = "A256GCM"; #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[allow(clippy::upper_case_acronyms, non_camel_case_types)] pub enum Enc { + /// HMAC-256 A128CBC_HS256, + /// HMAC-384 A192CBC_HS384, + /// HMAC-512 A256CBC_HS512, + /// AES-GCM 128 A128GCM, + /// AES-GCM 192 A192GCM, + /// AES-GCM 256 A256GCM, + /// Other encryption type Other(String), } @@ -76,7 +84,9 @@ impl<'de> Deserialize<'de> for Enc { /// Defined in [RFC7516#4.1.3](https://datatracker.ietf.org/doc/html/rfc7516#section-4.1.3). #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Zip { + /// Basic Deflate Compression Deflate, + /// Other Compression Other(String), } @@ -106,6 +116,25 @@ impl<'de> Deserialize<'de> for Zip { } } +/// Getter for `alg` attribute of a JWT Header +/// This must be implemented by custom header structs +pub trait Alg { + /// Getter for `alg` + fn alg(&self) -> &Algorithm; +} + +/// Decodes a JWT part from b64 +pub trait FromEncoded { + /// Converts an encoded JWT part into the Header struct if possible + fn from_encoded>(encoded_part: T) -> Result + where + Self: Sized + serde::de::DeserializeOwned, + { + let decoded = b64_decode(encoded_part)?; + Ok(serde_json::from_slice(&decoded)?) + } +} + /// A basic JWT header, the alg defaults to HS256 and typ is automatically /// set to `JWT`. All the other fields are optional. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -213,12 +242,6 @@ impl Header { } } - /// Converts an encoded part into the Header struct if possible - pub(crate) fn from_encoded>(encoded_part: T) -> Result { - let decoded = b64_decode(encoded_part)?; - Ok(serde_json::from_slice(&decoded)?) - } - /// Decodes the X.509 certificate chain into ASN.1 DER format. pub fn x5c_der(&self) -> Result>>> { Ok(self @@ -237,3 +260,11 @@ impl Default for Header { Header::new(Algorithm::default()) } } + +impl Alg for Header { + fn alg(&self) -> &Algorithm { + &self.alg + } +} + +impl FromEncoded for Header {} diff --git a/src/jws.rs b/src/jws.rs index 57dc02a2..fb1da764 100644 --- a/src/jws.rs +++ b/src/jws.rs @@ -5,7 +5,10 @@ use crate::crypto::sign; use crate::errors::{ErrorKind, Result, new_error}; use crate::serialization::{DecodedJwtPartClaims, b64_encode_part}; use crate::validation::validate; -use crate::{DecodingKey, EncodingKey, Header, TokenData, Validation}; +use crate::{ + DecodingKey, EncodingKey, TokenData, Validation, + header::{FromEncoded, Header}, +}; use crate::decoding::{jwt_verifier_factory, verify_signature_body}; use serde::de::DeserializeOwned; @@ -63,7 +66,7 @@ pub fn decode( jws: &Jws, key: &DecodingKey, validation: &Validation, -) -> Result> { +) -> Result> { let header = Header::from_encoded(&jws.protected)?; let message = [jws.protected.as_str(), jws.payload.as_str()].join("."); diff --git a/src/lib.rs b/src/lib.rs index 920b996b..5a247ca3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,9 @@ compile_error!( compile_error!("at least one of the features \"rust_crypto\" or \"aws_lc_rs\" must be enabled"); pub use algorithms::Algorithm; -pub use decoding::{DecodingKey, TokenData, decode, decode_header}; +pub use decoding::{ + DecodingKey, TokenData, decode, decode_custom_header, decode_header, decode_with_custom_header, +}; pub use encoding::{EncodingKey, encode}; pub use header::Header; pub use validation::{Validation, get_current_timestamp}; @@ -31,7 +33,7 @@ mod decoding; mod encoding; /// All the errors that can be encountered while encoding/decoding JWTs pub mod errors; -mod header; +pub mod header; pub mod jwk; pub mod jws; #[cfg(feature = "use_pem")] diff --git a/tests/header/mod.rs b/tests/header/mod.rs index c3eb2576..15c97452 100644 --- a/tests/header/mod.rs +++ b/tests/header/mod.rs @@ -1,7 +1,10 @@ use base64::{Engine, engine::general_purpose::STANDARD}; use wasm_bindgen_test::wasm_bindgen_test; -use jsonwebtoken::Header; +use jsonwebtoken::{ + Algorithm, + header::{Alg, FromEncoded, Header}, +}; static CERT_CHAIN: [&str; 3] = include!("cert_chain.json"); @@ -38,3 +41,31 @@ fn x5c_der_invalid_chain() { assert!(header.x5c_der().is_err()); } + +#[test] +#[wasm_bindgen_test] +fn decode_custom_header() { + #[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize)] + struct CustomHeader { + alg: Algorithm, + typ: String, + nonstandard_header: String, + } + impl Alg for CustomHeader { + fn alg(&self) -> &Algorithm { + &self.alg + } + } + impl FromEncoded for CustomHeader {} + + let expected = CustomHeader { + alg: Algorithm::HS256, + typ: "JWT".into(), + nonstandard_header: "traits are awesome".into(), + }; + + let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsIm5vbnN0YW5kYXJkX2hlYWRlciI6InRyYWl0cyBhcmUgYXdlc29tZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNzU5OTY4MTQ1fQ.c2VjcmV0"; + + let header = jsonwebtoken::decode_custom_header::(token).unwrap(); + assert_eq!(header, expected); +} From a68540f69b94368be517e6da0a18f4f03df91250 Mon Sep 17 00:00:00 2001 From: Dwayne Sykes Date: Wed, 8 Oct 2025 20:15:33 -0500 Subject: [PATCH 2/3] feat: add `BasicHeader` with `Hash` impl Add `BasicHeader` as a backwards-compatible fix for the regression introduced in #420 which dropped the `Hash` derive trait from `Header`. Resolves #439 --- src/header.rs | 117 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/src/header.rs b/src/header.rs index dfb288c8..2191f71a 100644 --- a/src/header.rs +++ b/src/header.rs @@ -268,3 +268,120 @@ impl Alg for Header { } impl FromEncoded for Header {} + +/// A truly basic JWT header, the alg defaults to HS256 and typ is automatically +/// set to `JWT`. All the other fields are optional. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct BasicHeader { + /// The type of JWS: it can only be "JWT" here + /// + /// Defined in [RFC7515#4.1.9](https://tools.ietf.org/html/rfc7515#section-4.1.9). + #[serde(skip_serializing_if = "Option::is_none")] + pub typ: Option, + /// The algorithm used + /// + /// Defined in [RFC7515#4.1.1](https://tools.ietf.org/html/rfc7515#section-4.1.1). + pub alg: Algorithm, + /// Content type + /// + /// Defined in [RFC7519#5.2](https://tools.ietf.org/html/rfc7519#section-5.2). + #[serde(skip_serializing_if = "Option::is_none")] + pub cty: Option, + /// JSON Key URL + /// + /// Defined in [RFC7515#4.1.2](https://tools.ietf.org/html/rfc7515#section-4.1.2). + #[serde(skip_serializing_if = "Option::is_none")] + pub jku: Option, + /// JSON Web Key + /// + /// Defined in [RFC7515#4.1.3](https://tools.ietf.org/html/rfc7515#section-4.1.3). + #[serde(skip_serializing_if = "Option::is_none")] + pub jwk: Option, + /// Key ID + /// + /// Defined in [RFC7515#4.1.4](https://tools.ietf.org/html/rfc7515#section-4.1.4). + #[serde(skip_serializing_if = "Option::is_none")] + pub kid: Option, + /// X.509 URL + /// + /// Defined in [RFC7515#4.1.5](https://tools.ietf.org/html/rfc7515#section-4.1.5). + #[serde(skip_serializing_if = "Option::is_none")] + pub x5u: Option, + /// X.509 certificate chain. A Vec of base64 encoded ASN.1 DER certificates. + /// + /// Defined in [RFC7515#4.1.6](https://tools.ietf.org/html/rfc7515#section-4.1.6). + #[serde(skip_serializing_if = "Option::is_none")] + pub x5c: Option>, + /// X.509 SHA1 certificate thumbprint + /// + /// Defined in [RFC7515#4.1.7](https://tools.ietf.org/html/rfc7515#section-4.1.7). + #[serde(skip_serializing_if = "Option::is_none")] + pub x5t: Option, + /// X.509 SHA256 certificate thumbprint + /// + /// Defined in [RFC7515#4.1.8](https://tools.ietf.org/html/rfc7515#section-4.1.8). + /// + /// This will be serialized/deserialized as "x5t#S256", as defined by the RFC. + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "x5t#S256")] + pub x5t_s256: Option, + /// Critical - indicates header fields that must be understood by the receiver. + /// + /// Defined in [RFC7515#4.1.6](https://tools.ietf.org/html/rfc7515#section-4.1.6). + #[serde(skip_serializing_if = "Option::is_none")] + pub crit: Option>, + /// See `Enc` for description. + #[serde(skip_serializing_if = "Option::is_none")] + pub enc: Option, + /// See `Zip` for description. + #[serde(skip_serializing_if = "Option::is_none")] + pub zip: Option, + /// ACME: The URL to which this JWS object is directed + /// + /// Defined in [RFC8555#6.4](https://datatracker.ietf.org/doc/html/rfc8555#section-6.4). + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, + /// ACME: Random data for preventing replay attacks. + /// + /// Defined in [RFC8555#6.5.2](https://datatracker.ietf.org/doc/html/rfc8555#section-6.5.2). + #[serde(skip_serializing_if = "Option::is_none")] + pub nonce: Option, +} + +impl BasicHeader { + /// Returns a JWT header with the algorithm given + pub fn new(algorithm: Algorithm) -> Self { + Self { + typ: Some("JWT".to_string()), + alg: algorithm, + cty: None, + jku: None, + jwk: None, + kid: None, + x5u: None, + x5c: None, + x5t: None, + x5t_s256: None, + crit: None, + enc: None, + zip: None, + url: None, + nonce: None, + } + } +} + +impl Default for BasicHeader { + /// Returns a JWT header using the default Algorithm, HS256 + fn default() -> Self { + Self::new(Algorithm::default()) + } +} + +impl Alg for BasicHeader { + fn alg(&self) -> &Algorithm { + &self.alg + } +} + +impl FromEncoded for BasicHeader {} From a606eac7fffe3f78c1e91c438d58ae9a3fa4ac4d Mon Sep 17 00:00:00 2001 From: Dwayne Sykes Date: Thu, 9 Oct 2025 23:01:51 -0500 Subject: [PATCH 3/3] feat: add macros for custom headers and claims Add `Header` derive macro, `header` attribute macro, and `claims` attribute macro. `Header` derive macro implements required traits for custom headers `header` and `claims` are convenience attribute macros that add the required derive macros. --- Cargo.toml | 3 + examples/custom_header.rs | 15 ++--- examples/extras_header.rs | 44 ++++++++++++ jsonwebtoken-proc-macros/Cargo.toml | 11 +++ jsonwebtoken-proc-macros/src/lib.rs | 100 ++++++++++++++++++++++++++++ src/decoding.rs | 2 +- src/encoding.rs | 5 +- src/header.rs | 37 ++-------- src/jwk.rs | 2 +- src/lib.rs | 17 ++++- 10 files changed, 188 insertions(+), 48 deletions(-) create mode 100644 examples/extras_header.rs create mode 100644 jsonwebtoken-proc-macros/Cargo.toml create mode 100644 jsonwebtoken-proc-macros/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 441cc736..eec28dfb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,9 @@ rand = { version = "0.8.5", optional = true, features = ["std"], default-feature rsa = { version = "0.9.6", optional = true } sha2 = { version = "0.10.7", optional = true, features = ["oid"] } +# proc-macros (e.g. Header, claims) +jsonwebtoken-proc-macros = { path = "jsonwebtoken-proc-macros" } + [target.'cfg(target_arch = "wasm32")'.dependencies] js-sys = "0.3" getrandom = "0.2" diff --git a/examples/custom_header.rs b/examples/custom_header.rs index 35359235..f2a14354 100644 --- a/examples/custom_header.rs +++ b/examples/custom_header.rs @@ -1,30 +1,25 @@ -use serde::{Deserialize, Serialize}; use std::collections::HashMap; use jsonwebtoken::errors::ErrorKind; +use jsonwebtoken::macros::{claims, header}; use jsonwebtoken::{ - Algorithm, DecodingKey, EncodingKey, Validation, decode_with_custom_header, encode, header, + Algorithm, DecodingKey, EncodingKey, Validation, decode_with_custom_header, encode, }; -#[derive(Debug, Serialize, Deserialize, Clone)] +#[claims] struct Claims { sub: String, company: String, exp: u64, } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +#[header] +#[derive(PartialEq, Eq)] // only required for assertions in tests, not required by jsonwebtoken struct CustomHeader { alg: Algorithm, custom: String, another_custom_field: Option, } -impl header::FromEncoded for CustomHeader {} -impl header::Alg for CustomHeader { - fn alg(&self) -> &Algorithm { - &self.alg - } -} fn main() { let my_claims = diff --git a/examples/extras_header.rs b/examples/extras_header.rs new file mode 100644 index 00000000..32a1aa9b --- /dev/null +++ b/examples/extras_header.rs @@ -0,0 +1,44 @@ +use std::collections::HashMap; + +use jsonwebtoken::errors::ErrorKind; +use jsonwebtoken::{ + Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode, macros::claims, +}; + +#[claims] +struct Claims { + sub: String, + company: String, + exp: u64, +} + +fn main() { + let my_claims = + Claims { sub: "b@b.com".to_owned(), company: "ACME".to_owned(), exp: 10000000000 }; + let key = b"secret"; + + let mut extras = HashMap::with_capacity(1); + extras.insert("custom".to_string(), "header".to_string()); + + let header = Header { alg: Algorithm::HS512, extras, ..Default::default() }; + + let token = match encode(&header, &my_claims, &EncodingKey::from_secret(key)) { + Ok(t) => t, + Err(_) => panic!(), // in practice you would return the error + }; + println!("{:?}", token); + + let token_data = match decode::( + &token, + &DecodingKey::from_secret(key), + &Validation::new(Algorithm::HS512), + ) { + Ok(c) => c, + Err(err) => match *err.kind() { + ErrorKind::InvalidToken => panic!(), // Example on how to handle a specific error + _ => panic!(), + }, + }; + println!("{:?}", token_data.claims); + println!("{:?}", token_data.header); +} diff --git a/jsonwebtoken-proc-macros/Cargo.toml b/jsonwebtoken-proc-macros/Cargo.toml new file mode 100644 index 00000000..bf643720 --- /dev/null +++ b/jsonwebtoken-proc-macros/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "jsonwebtoken-proc-macros" +version = "0.1.0" +edition = "2024" + +[lib] +proc-macro = true + +[dependencies] +quote = "1.0.41" +syn = { version = "2.0.106", features = ["full"] } diff --git a/jsonwebtoken-proc-macros/src/lib.rs b/jsonwebtoken-proc-macros/src/lib.rs new file mode 100644 index 00000000..2654d5c8 --- /dev/null +++ b/jsonwebtoken-proc-macros/src/lib.rs @@ -0,0 +1,100 @@ +//! This library provides convenient derive and attribute macros for custom Header and Claim +//! structs for `jsonwebtoken`. +//! +//! Example +//! ```rust +//! use jsonwebtoken::Algorithm; +//! use jsonwebtoken::macros::header; +//! #[header] +//! struct CustomJwtHeader { +//! // `alg` is the only required struct field +//! alg: Algorithm, +//! custom_header: Option, +//! another_custom_header: String, +//! } +//! ```` +extern crate proc_macro; + +use proc_macro::TokenStream; +use quote::quote; +use syn::{DeriveInput, Item, ItemStruct, parse_macro_input}; + +/// Convenience macro for JWT header structs +/// +/// Adds the following derive macros: +/// ```rust +/// #[derive( +/// Debug, +/// Clone, +/// Default, +/// serde::Serialize, +/// serde::Deserialize +/// )] +/// ``` +#[proc_macro_attribute] +pub fn claims(_attr: TokenStream, input: TokenStream) -> TokenStream { + let mut item = parse_macro_input!(input as Item); + + match &mut item { + Item::Struct(ItemStruct { attrs, .. }) => { + attrs.push( + syn::parse_quote!(#[derive(Debug, Clone, Default, jsonwebtoken::serde::Serialize, jsonwebtoken::serde::Deserialize)]), + ); + quote!(#item).into() + } + _ => syn::Error::new_spanned(&item, "#[header] can only be used on structs") + .to_compile_error() + .into(), + } +} + +/// Convenience macro for JWT header structs +/// +/// Adds the following derive macros: +/// ```rust +/// #[derive( +/// Debug, +/// Clone, +/// Default, +/// jsonwebtoken::macros::Header, +/// serde::Serialize, +/// serde::Deserialize +/// )] +/// ``` +#[proc_macro_attribute] +pub fn header(_attr: TokenStream, input: TokenStream) -> TokenStream { + let mut item = parse_macro_input!(input as Item); + + match &mut item { + Item::Struct(ItemStruct { attrs, .. }) => { + attrs.push(syn::parse_quote!(#[derive(Debug, Clone, Default, jsonwebtoken::serde::Serialize, jsonwebtoken::serde::Deserialize, jsonwebtoken::macros::Header)])); + quote!(#item).into() + } + _ => syn::Error::new_spanned(&item, "#[header] can only be used on structs") + .to_compile_error() + .into(), + } +} + +/// Derive macro required for custom JWT headers used with `jsonwebtoken` +/// +/// Requires an `alg: jsonwebtoken::Algorithm` field exists in the struct +#[proc_macro_derive(Header)] +pub fn derive_header(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + let name = &input.ident; + let generics = &input.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let expanded = quote! { + impl #impl_generics ::jsonwebtoken::header::FromEncoded for #name #ty_generics #where_clause {} + impl #impl_generics ::jsonwebtoken::header::Alg for #name #ty_generics #where_clause { + fn alg(&self) -> &::jsonwebtoken::Algorithm { + &self.alg + } + } + }; + + TokenStream::from(expanded) +} diff --git a/src/decoding.rs b/src/decoding.rs index c566a10e..4c7725f3 100644 --- a/src/decoding.rs +++ b/src/decoding.rs @@ -325,7 +325,7 @@ where /// DANGER: This performs zero validation on the JWT pub fn insecure_decode( token: impl AsRef<[u8]>, -) -> Result> { +) -> Result> { let token = token.as_ref(); let (_, message) = expect_two!(token.rsplitn(2, |b| *b == b'.')); diff --git a/src/encoding.rs b/src/encoding.rs index 0fc98b56..66c8b9a7 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -153,10 +153,9 @@ impl Debug for EncodingKey { /// If the algorithm given is RSA or EC, the key needs to be in the PEM format. /// /// ```rust -/// use serde::{Deserialize, Serialize}; -/// use jsonwebtoken::{encode, Algorithm, Header, EncodingKey}; +/// use jsonwebtoken::{encode, macros::claims, Algorithm, Header, EncodingKey}; /// -/// #[derive(Debug, Serialize, Deserialize)] +/// #[claims] /// struct Claims { /// sub: String, /// company: String diff --git a/src/header.rs b/src/header.rs index 2191f71a..fb7e08cb 100644 --- a/src/header.rs +++ b/src/header.rs @@ -3,11 +3,13 @@ use std::collections::HashMap; use std::result; use base64::{Engine, engine::general_purpose::STANDARD}; +use jsonwebtoken_proc_macros::header; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::algorithms::Algorithm; use crate::errors::Result; use crate::jwk::Jwk; +use crate::macros::Header; use crate::serialization::b64_decode; const ZIP_SERIAL_DEFLATE: &str = "DEF"; @@ -137,7 +139,7 @@ pub trait FromEncoded { /// A basic JWT header, the alg defaults to HS256 and typ is automatically /// set to `JWT`. All the other fields are optional. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Header, Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct Header { /// The type of JWS: it can only be "JWT" here /// @@ -254,24 +256,10 @@ impl Header { } } -impl Default for Header { - /// Returns a JWT header using the default Algorithm, HS256 - fn default() -> Self { - Header::new(Algorithm::default()) - } -} - -impl Alg for Header { - fn alg(&self) -> &Algorithm { - &self.alg - } -} - -impl FromEncoded for Header {} - /// A truly basic JWT header, the alg defaults to HS256 and typ is automatically /// set to `JWT`. All the other fields are optional. -#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[header] +#[derive(Hash, PartialEq, Eq)] pub struct BasicHeader { /// The type of JWS: it can only be "JWT" here /// @@ -370,18 +358,3 @@ impl BasicHeader { } } } - -impl Default for BasicHeader { - /// Returns a JWT header using the default Algorithm, HS256 - fn default() -> Self { - Self::new(Algorithm::default()) - } -} - -impl Alg for BasicHeader { - fn alg(&self) -> &Algorithm { - &self.alg - } -} - -impl FromEncoded for BasicHeader {} diff --git a/src/jwk.rs b/src/jwk.rs index 31f944d2..83d48de2 100644 --- a/src/jwk.rs +++ b/src/jwk.rs @@ -596,7 +596,7 @@ impl Jwk { /// Compute the thumbprint of the JWK. /// - /// Per [RFC-7638](https://datatracker.ietf.org/doc/html/rfc7638) + /// Per (RFC-7638)[] pub fn thumbprint(&self, hash_function: ThumbprintHash) -> String { let pre = match &self.algorithm { AlgorithmParameters::EllipticCurve(a) => match a.curve { diff --git a/src/lib.rs b/src/lib.rs index 5a247ca3..72014dd9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,14 @@ compile_error!( #[cfg(not(any(feature = "rust_crypto", feature = "aws_lc_rs")))] compile_error!("at least one of the features \"rust_crypto\" or \"aws_lc_rs\" must be enabled"); +// hidden export of `self` as `jsonwebtoken` required by proc macros +#[doc(hidden)] +extern crate self as jsonwebtoken; + +// hidden re-export of serde for proc macros +#[doc(hidden)] +pub use serde; + pub use algorithms::Algorithm; pub use decoding::{ DecodingKey, TokenData, decode, decode_custom_header, decode_header, decode_with_custom_header, @@ -33,10 +41,17 @@ mod decoding; mod encoding; /// All the errors that can be encountered while encoding/decoding JWTs pub mod errors; -pub mod header; pub mod jwk; pub mod jws; #[cfg(feature = "use_pem")] mod pem; mod serialization; mod validation; + +#[doc(hidden)] +pub mod header; + +/// Derive macros for custom JWT Header and Claims +pub mod macros { + pub use jsonwebtoken_proc_macros::{Header, claims, header}; +}