Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ rand = { version = "0.8.5", optional = true, features = ["std"], default-feature
rsa = { version = "0.9.6", optional = true }
sha2 = { version = "0.10.7", optional = true, features = ["oid"] }

# proc-macros (e.g. Header, claims)
jsonwebtoken-proc-macros = { path = "jsonwebtoken-proc-macros" }

[target.'cfg(target_arch = "wasm32")'.dependencies]
js-sys = "0.3"
getrandom = "0.2"
Expand Down
25 changes: 17 additions & 8 deletions examples/custom_header.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use jsonwebtoken::errors::ErrorKind;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use jsonwebtoken::macros::{claims, header};
use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Validation, decode_with_custom_header, encode,
};

#[derive(Debug, Serialize, Deserialize, Clone)]
#[claims]
struct Claims {
sub: String,
company: String,
exp: u64,
}

#[header]
#[derive(PartialEq, Eq)] // only required for assertions in tests, not required by jsonwebtoken
struct CustomHeader {
alg: Algorithm,
custom: String,
another_custom_field: Option<usize>,
}

fn main() {
let my_claims =
Claims { sub: "[email protected]".to_owned(), company: "ACME".to_owned(), exp: 10000000000 };
Expand All @@ -19,11 +29,10 @@ fn main() {
let mut extras = HashMap::with_capacity(1);
extras.insert("custom".to_string(), "header".to_string());

let header = Header {
kid: Some("signing_key".to_owned()),
let header = CustomHeader {
alg: Algorithm::HS512,
extras,
..Default::default()
custom: "custom".into(),
another_custom_field: 42.into(),
};

let token = match encode(&header, &my_claims, &EncodingKey::from_secret(key)) {
Expand All @@ -32,7 +41,7 @@ fn main() {
};
println!("{:?}", token);

let token_data = match decode::<Claims>(
let token_data = match decode_with_custom_header::<CustomHeader, Claims>(
&token,
&DecodingKey::from_secret(key),
&Validation::new(Algorithm::HS512),
Expand Down
44 changes: 44 additions & 0 deletions examples/extras_header.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use std::collections::HashMap;

use jsonwebtoken::errors::ErrorKind;
use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode, macros::claims,
};

#[claims]
struct Claims {
sub: String,
company: String,
exp: u64,
}

fn main() {
let my_claims =
Claims { sub: "[email protected]".to_owned(), company: "ACME".to_owned(), exp: 10000000000 };
let key = b"secret";

let mut extras = HashMap::with_capacity(1);
extras.insert("custom".to_string(), "header".to_string());

let header = Header { alg: Algorithm::HS512, extras, ..Default::default() };

let token = match encode(&header, &my_claims, &EncodingKey::from_secret(key)) {
Ok(t) => t,
Err(_) => panic!(), // in practice you would return the error
};
println!("{:?}", token);

let token_data = match decode::<Claims>(
&token,
&DecodingKey::from_secret(key),
&Validation::new(Algorithm::HS512),
) {
Ok(c) => c,
Err(err) => match *err.kind() {
ErrorKind::InvalidToken => panic!(), // Example on how to handle a specific error
_ => panic!(),
},
};
println!("{:?}", token_data.claims);
println!("{:?}", token_data.header);
}
11 changes: 11 additions & 0 deletions jsonwebtoken-proc-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "jsonwebtoken-proc-macros"
version = "0.1.0"
edition = "2024"

[lib]
proc-macro = true

[dependencies]
quote = "1.0.41"
syn = { version = "2.0.106", features = ["full"] }
100 changes: 100 additions & 0 deletions jsonwebtoken-proc-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
//! This library provides convenient derive and attribute macros for custom Header and Claim
//! structs for `jsonwebtoken`.
//!
//! Example
//! ```rust
//! use jsonwebtoken::Algorithm;
//! use jsonwebtoken::macros::header;
//! #[header]
//! struct CustomJwtHeader {
//! // `alg` is the only required struct field
//! alg: Algorithm,
//! custom_header: Option<String>,
//! another_custom_header: String,
//! }
//! ````
extern crate proc_macro;

use proc_macro::TokenStream;
use quote::quote;
use syn::{DeriveInput, Item, ItemStruct, parse_macro_input};

/// Convenience macro for JWT header structs
///
/// Adds the following derive macros:
/// ```rust
/// #[derive(
/// Debug,
/// Clone,
/// Default,
/// serde::Serialize,
/// serde::Deserialize
/// )]
/// ```
#[proc_macro_attribute]
pub fn claims(_attr: TokenStream, input: TokenStream) -> TokenStream {
let mut item = parse_macro_input!(input as Item);

match &mut item {
Item::Struct(ItemStruct { attrs, .. }) => {
attrs.push(
syn::parse_quote!(#[derive(Debug, Clone, Default, jsonwebtoken::serde::Serialize, jsonwebtoken::serde::Deserialize)]),
);
quote!(#item).into()
}
_ => syn::Error::new_spanned(&item, "#[header] can only be used on structs")
.to_compile_error()
.into(),
}
}

/// Convenience macro for JWT header structs
///
/// Adds the following derive macros:
/// ```rust
/// #[derive(
/// Debug,
/// Clone,
/// Default,
/// jsonwebtoken::macros::Header,
/// serde::Serialize,
/// serde::Deserialize
/// )]
/// ```
#[proc_macro_attribute]
pub fn header(_attr: TokenStream, input: TokenStream) -> TokenStream {
let mut item = parse_macro_input!(input as Item);

match &mut item {
Item::Struct(ItemStruct { attrs, .. }) => {
attrs.push(syn::parse_quote!(#[derive(Debug, Clone, Default, jsonwebtoken::serde::Serialize, jsonwebtoken::serde::Deserialize, jsonwebtoken::macros::Header)]));
quote!(#item).into()
}
_ => syn::Error::new_spanned(&item, "#[header] can only be used on structs")
.to_compile_error()
.into(),
}
}

/// Derive macro required for custom JWT headers used with `jsonwebtoken`
///
/// Requires an `alg: jsonwebtoken::Algorithm` field exists in the struct
#[proc_macro_derive(Header)]
pub fn derive_header(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);

let name = &input.ident;
let generics = &input.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let expanded = quote! {
impl #impl_generics ::jsonwebtoken::header::FromEncoded for #name #ty_generics #where_clause {}
impl #impl_generics ::jsonwebtoken::header::Alg for #name #ty_generics #where_clause {
fn alg(&self) -> &::jsonwebtoken::Algorithm {
&self.alg
}
}
};

TokenStream::from(expanded)
}
75 changes: 45 additions & 30 deletions src/decoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::Algorithm;
use crate::algorithms::AlgorithmFamily;
use crate::crypto::JwtVerifier;
use crate::errors::{ErrorKind, Result, new_error};
use crate::header::Header;
use crate::header::{Alg, FromEncoded, Header};
use crate::jwk::{AlgorithmParameters, Jwk};
#[cfg(feature = "use_pem")]
use crate::pem::decoder::PemEncodedKey;
Expand Down Expand Up @@ -37,15 +37,16 @@ use crate::crypto::rust_crypto::{

/// The return type of a successful call to [decode](fn.decode.html).
#[derive(Debug)]
pub struct TokenData<T> {
pub struct TokenData<H, T> {
/// The decoded JWT header
pub header: Header,
pub header: H,
/// The decoded JWT claims
pub claims: T,
}

impl<T> Clone for TokenData<T>
impl<H, T> Clone for TokenData<H, T>
where
H: Clone,
T: Clone,
{
fn clone(&self) -> Self {
Expand Down Expand Up @@ -281,21 +282,40 @@ pub fn decode<T: DeserializeOwned + Clone>(
token: impl AsRef<[u8]>,
key: &DecodingKey,
validation: &Validation,
) -> Result<TokenData<T>> {
) -> Result<TokenData<Header, T>> {
decode_with_custom_header(token, key, validation)
}

/// Decode and validate a JWT with a custom header
///
/// If the token or its signature is invalid, or the claims fail validation, this will return an
/// error.
pub fn decode_with_custom_header<H, T>(
token: impl AsRef<[u8]>,
key: &DecodingKey,
validation: &Validation,
) -> Result<TokenData<H, T>>
where
H: DeserializeOwned + Clone + FromEncoded + Alg,
T: DeserializeOwned + Clone,
{
let token = token.as_ref();
let header = decode_header(token)?;

if validation.validate_signature && !validation.algorithms.contains(&header.alg) {
let (signature, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
let header = H::from_encoded(header)?;

if validation.validate_signature && !validation.algorithms.contains(header.alg()) {
return Err(new_error(ErrorKind::InvalidAlgorithm));
}

let verifying_provider = jwt_verifier_factory(&header.alg, key)?;
let verifying_provider = jwt_verifier_factory(header.alg(), key)?;
verify_signature_body(message, signature, &header, validation, verifying_provider)?;

let (header, claims) = verify_signature(token, validation, verifying_provider)?;
let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(payload)?;
validate(decoded_claims.deserialize()?, validation)?;

let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(claims)?;
let claims = decoded_claims.deserialize()?;
validate(decoded_claims.deserialize()?, validation)?;

Ok(TokenData { header, claims })
}
Expand All @@ -305,7 +325,7 @@ pub fn decode<T: DeserializeOwned + Clone>(
/// DANGER: This performs zero validation on the JWT
pub fn insecure_decode<T: DeserializeOwned + Clone>(
token: impl AsRef<[u8]>,
) -> Result<TokenData<T>> {
) -> Result<TokenData<Header, T>> {
let token = token.as_ref();

let (_, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
Expand Down Expand Up @@ -357,10 +377,21 @@ pub fn decode_header(token: impl AsRef<[u8]>) -> Result<Header> {
Header::from_encoded(header)
}

/// Decode only the custom header of a JWT without decoding or validating the payload
pub fn decode_custom_header<H>(token: impl AsRef<[u8]>) -> Result<H>
where
H: DeserializeOwned + Clone + Alg + FromEncoded,
{
let token = token.as_ref();
let (_, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
let (_, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
H::from_encoded(header)
}

pub(crate) fn verify_signature_body(
message: &[u8],
signature: &[u8],
header: &Header,
header: &impl Alg,
validation: &Validation,
verifying_provider: Box<dyn JwtVerifier>,
) -> Result<()> {
Expand All @@ -376,7 +407,7 @@ pub(crate) fn verify_signature_body(
}
}

if validation.validate_signature && !validation.algorithms.contains(&header.alg) {
if validation.validate_signature && !validation.algorithms.contains(header.alg()) {
return Err(new_error(ErrorKind::InvalidAlgorithm));
}

Expand All @@ -388,19 +419,3 @@ pub(crate) fn verify_signature_body(

Ok(())
}

/// Verify the signature of a JWT, and return a header object and raw payload.
///
/// If the token or its signature is invalid, it will return an error.
fn verify_signature<'a>(
token: &'a [u8],
validation: &Validation,
verifying_provider: Box<dyn JwtVerifier>,
) -> Result<(Header, &'a [u8])> {
let (signature, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
let header = Header::from_encoded(header)?;
verify_signature_body(message, signature, &header, validation, verifying_provider)?;

Ok((header, payload))
}
Loading