Skip to content
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ If you want to set the `kid` parameter or change the algorithm for example:
```rust
let mut header = Header::new(Algorithm::HS512);
header.kid = Some("blabla".to_owned());

let mut extras = HashMap::with_capacity(1);
extras.insert("custom".to_string(), "header".to_string());
header.extras = Some(extras);

let token = encode(&header, &my_claims, &EncodingKey::from_secret("secret".as_ref()))?;
```
Look at `examples/custom_header.rs` for a full working example.
Expand Down
15 changes: 14 additions & 1 deletion benches/jwt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
struct Claims {
Expand All @@ -17,6 +18,18 @@ fn bench_encode(c: &mut Criterion) {
});
}

fn bench_encode_custom_extra_headers(c: &mut Criterion) {
let claim = Claims { sub: "[email protected]".to_owned(), company: "ACME".to_owned() };
let key = EncodingKey::from_secret("secret".as_ref());
let mut extras = HashMap::with_capacity(1);
extras.insert("custom".to_string(), "header".to_string());
let header = &Header { extras, ..Default::default() };

c.bench_function("bench_encode", |b| {
b.iter(|| encode(black_box(header), black_box(&claim), black_box(&key)))
});
}

fn bench_decode(c: &mut Criterion) {
let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ";
let key = DecodingKey::from_secret("secret".as_ref());
Expand All @@ -32,5 +45,5 @@ fn bench_decode(c: &mut Criterion) {
});
}

criterion_group!(benches, bench_encode, bench_decode);
criterion_group!(benches, bench_encode, bench_encode_custom_extra_headers, bench_decode);
criterion_main!(benches);
12 changes: 10 additions & 2 deletions examples/custom_header.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use jsonwebtoken::errors::ErrorKind;
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
Expand All @@ -15,8 +16,15 @@ fn main() {
Claims { sub: "[email protected]".to_owned(), company: "ACME".to_owned(), exp: 10000000000 };
let key = b"secret";

let header =
Header { kid: Some("signing_key".to_owned()), alg: Algorithm::HS512, ..Default::default() };
let mut extras = HashMap::with_capacity(1);
extras.insert("custom".to_string(), "header".to_string());

let header = Header {
kid: Some("signing_key".to_owned()),
alg: Algorithm::HS512,
extras,
..Default::default()
};

let token = match encode(&header, &my_claims, &EncodingKey::from_secret(key)) {
Ok(t) => t,
Expand Down
10 changes: 9 additions & 1 deletion src/header.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::result;

use base64::{engine::general_purpose::STANDARD, Engine};
Expand All @@ -10,7 +11,7 @@ use crate::serialization::b64_decode;

/// A basic JWT header, the alg defaults to HS256 and typ is automatically
/// set to `JWT`. All the other fields are optional.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Header {
/// The type of JWS: it can only be "JWT" here
///
Expand Down Expand Up @@ -64,6 +65,12 @@ pub struct Header {
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "x5t#S256")]
pub x5t_s256: Option<String>,

/// Any additional non-standard headers not defined in [RFC7515#4.1](https://datatracker.ietf.org/doc/html/rfc7515#section-4.1).
/// Once serialized, all keys will be converted to fields at the root level of the header payload
/// Ex: Dict("custom" -> "header") will be converted to "{"typ": "JWT", ..., "custom": "header"}"
#[serde(flatten)]
pub extras: HashMap<String, String>,
}

impl Header {
Expand All @@ -80,6 +87,7 @@ impl Header {
x5c: None,
x5t: None,
x5t_s256: None,
extras: Default::default(),
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/jwk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl<'de> Deserialize<'de> for PublicKeyUse {
D: Deserializer<'de>,
{
struct PublicKeyUseVisitor;
impl<'de> de::Visitor<'de> for PublicKeyUseVisitor {
impl de::Visitor<'_> for PublicKeyUseVisitor {
type Value = PublicKeyUse;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -116,7 +116,7 @@ impl<'de> Deserialize<'de> for KeyOperations {
D: Deserializer<'de>,
{
struct KeyOperationsVisitor;
impl<'de> de::Visitor<'de> for KeyOperationsVisitor {
impl de::Visitor<'_> for KeyOperationsVisitor {
type Value = KeyOperations;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down
2 changes: 2 additions & 0 deletions src/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use serde::{Deserialize, Serialize};

use crate::errors::Result;

#[inline]
pub(crate) fn b64_encode<T: AsRef<[u8]>>(input: T) -> String {
URL_SAFE_NO_PAD.encode(input)
}

#[inline]
pub(crate) fn b64_decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>> {
URL_SAFE_NO_PAD.decode(input).map_err(|e| e.into())
}
Expand Down
16 changes: 8 additions & 8 deletions src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,20 @@ where
{
struct NumericType(PhantomData<fn() -> TryParse<u64>>);

impl<'de> Visitor<'de> for NumericType {
impl Visitor<'_> for NumericType {
type Value = TryParse<u64>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("A NumericType that can be reasonably coerced into a u64")
}

fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
Ok(TryParse::Parsed(value))
}

fn visit_f64<E>(self, value: f64) -> std::result::Result<Self::Value, E>
where
E: de::Error,
Expand All @@ -354,13 +361,6 @@ where
Err(serde::de::Error::custom("NumericType must be representable as a u64"))
}
}

fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
Ok(TryParse::Parsed(value))
}
}

match deserializer.deserialize_any(NumericType(PhantomData)) {
Expand Down
70 changes: 70 additions & 0 deletions tests/hmac.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use jsonwebtoken::{
decode, decode_header, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use time::OffsetDateTime;
use wasm_bindgen_test::wasm_bindgen_test;

Expand Down Expand Up @@ -51,6 +52,56 @@ fn encode_with_custom_header() {
.unwrap();
assert_eq!(my_claims, token_data.claims);
assert_eq!("kid", token_data.header.kid.unwrap());
assert!(token_data.header.extras.is_empty());
}

#[test]
#[wasm_bindgen_test]
fn encode_with_extra_custom_header() {
let my_claims = Claims {
sub: "[email protected]".to_string(),
company: "ACME".to_string(),
exp: OffsetDateTime::now_utc().unix_timestamp() + 10000,
};
let mut extras = HashMap::with_capacity(1);
extras.insert("custom".to_string(), "header".to_string());
let header = Header { kid: Some("kid".to_string()), extras, ..Default::default() };
let token = encode(&header, &my_claims, &EncodingKey::from_secret(b"secret")).unwrap();
let token_data = decode::<Claims>(
&token,
&DecodingKey::from_secret(b"secret"),
&Validation::new(Algorithm::HS256),
)
.unwrap();
assert_eq!(my_claims, token_data.claims);
assert_eq!("kid", token_data.header.kid.unwrap());
assert_eq!("header", token_data.header.extras.get("custom").unwrap().as_str());
}

#[test]
#[wasm_bindgen_test]
fn encode_with_multiple_extra_custom_headers() {
let my_claims = Claims {
sub: "[email protected]".to_string(),
company: "ACME".to_string(),
exp: OffsetDateTime::now_utc().unix_timestamp() + 10000,
};
let mut extras = HashMap::with_capacity(2);
extras.insert("custom1".to_string(), "header1".to_string());
extras.insert("custom2".to_string(), "header2".to_string());
let header = Header { kid: Some("kid".to_string()), extras, ..Default::default() };
let token = encode(&header, &my_claims, &EncodingKey::from_secret(b"secret")).unwrap();
let token_data = decode::<Claims>(
&token,
&DecodingKey::from_secret(b"secret"),
&Validation::new(Algorithm::HS256),
)
.unwrap();
assert_eq!(my_claims, token_data.claims);
assert_eq!("kid", token_data.header.kid.unwrap());
let extras = token_data.header.extras;
assert_eq!("header1", extras.get("custom1").unwrap().as_str());
assert_eq!("header2", extras.get("custom2").unwrap().as_str());
}

#[test]
Expand Down Expand Up @@ -86,6 +137,25 @@ fn decode_token() {
claims.unwrap();
}

#[test]
#[wasm_bindgen_test]
fn decode_token_custom_headers() {
let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsImN1c3RvbTEiOiJoZWFkZXIxIiwiY3VzdG9tMiI6ImhlYWRlcjIifQ.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.FtOHsoKcNH3SriK3tnR-uWJg4UV4FkOzvq_JCfLngfU";
let claims = decode::<Claims>(
token,
&DecodingKey::from_secret(b"secret"),
&Validation::new(Algorithm::HS256),
)
.unwrap();
let my_claims =
Claims { sub: "[email protected]".to_string(), company: "ACME".to_string(), exp: 2532524891 };
assert_eq!(my_claims, claims.claims);
assert_eq!("kid", claims.header.kid.unwrap());
let extras = claims.header.extras;
assert_eq!("header1", extras.get("custom1").unwrap().as_str());
assert_eq!("header2", extras.get("custom2").unwrap().as_str());
}

#[test]
#[wasm_bindgen_test]
#[should_panic(expected = "InvalidToken")]
Expand Down
Loading