Skip to content

Commit 378bc1f

Browse files
authored
Optimise builder (#47)
1 parent 1b483c8 commit 378bc1f

10 files changed

+117
-46
lines changed

algo.go

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
// Signer is used to sign tokens.
99
type Signer interface {
1010
Algorithm() Algorithm
11+
SignSize() int
1112
Sign(payload []byte) ([]byte, error)
1213
}
1314

@@ -20,6 +21,8 @@ type Verifier interface {
2021
// Algorithm for signing and verifying.
2122
type Algorithm string
2223

24+
func (a Algorithm) String() string { return string(a) }
25+
2326
// Algorithm names for signing and verifying.
2427
const (
2528
EdDSA Algorithm = "EdDSA"

algo_eddsa.go

+4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ func (h edDSAAlg) Algorithm() Algorithm {
3636
return h.alg
3737
}
3838

39+
func (h edDSAAlg) SignSize() int {
40+
return ed25519.SignatureSize
41+
}
42+
3943
func (h edDSAAlg) Sign(payload []byte) ([]byte, error) {
4044
return ed25519.Sign(h.privateKey, payload), nil
4145
}

algo_es.go

+4
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ func (h esAlg) Algorithm() Algorithm {
6969
return h.alg
7070
}
7171

72+
func (h esAlg) SignSize() int {
73+
return 2 * h.curveBits
74+
}
75+
7276
func (h esAlg) Sign(payload []byte) ([]byte, error) {
7377
signed, err := h.sign(payload)
7478
if err != nil {

algo_hs.go

+4
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ func (h hsAlg) Algorithm() Algorithm {
6969
return h.alg
7070
}
7171

72+
func (h hsAlg) SignSize() int {
73+
return h.hash.Size()
74+
}
75+
7276
func (h hsAlg) Sign(payload []byte) ([]byte, error) {
7377
return h.sign(payload)
7478
}

algo_ps.go

+4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ type psAlg struct {
7878
opts *rsa.PSSOptions
7979
}
8080

81+
func (h psAlg) SignSize() int {
82+
return h.privateKey.Size()
83+
}
84+
8185
func (h psAlg) Algorithm() Algorithm {
8286
return h.alg
8387
}

algo_rs.go

+4
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ func (h rsAlg) Algorithm() Algorithm {
6262
return h.alg
6363
}
6464

65+
func (h rsAlg) SignSize() int {
66+
return h.privateKey.Size()
67+
}
68+
6569
func (h rsAlg) Sign(payload []byte) ([]byte, error) {
6670
signed, err := h.sign(payload)
6771
if err != nil {

build.go

+83-38
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ var (
1212

1313
// Builder is used to create a new token.
1414
type Builder struct {
15-
signer Signer
16-
header Header
15+
signer Signer
16+
header Header
17+
headerRaw []byte
1718
}
1819

1920
// BuildBytes is used to create and encode JWT with a provided claims.
@@ -30,12 +31,12 @@ func Build(signer Signer, claims interface{}) (*Token, error) {
3031
func NewBuilder(signer Signer) *Builder {
3132
b := &Builder{
3233
signer: signer,
33-
3434
header: Header{
3535
Algorithm: signer.Algorithm(),
3636
Type: "JWT",
3737
},
3838
}
39+
b.headerRaw = encodeHeader(&b.header)
3940
return b
4041
}
4142

@@ -49,72 +50,116 @@ func (b *Builder) BuildBytes(claims interface{}) ([]byte, error) {
4950
}
5051

5152
// Build used to create and encode JWT with a provided claims.
53+
// If claims param is of type []byte then we treat it as a marshaled JSON.
54+
// In other words you can pass already marshaled claims.
5255
func (b *Builder) Build(claims interface{}) (*Token, error) {
53-
rawClaims, encodedClaims, err := encodeClaims(claims)
56+
rawClaims, err := encodeClaims(claims)
5457
if err != nil {
5558
return nil, err
5659
}
5760

58-
encodedHeader := encodeHeader(&b.header)
59-
payload := concatParts(encodedHeader, encodedClaims)
61+
lenH := len(b.headerRaw)
62+
lenC := base64EncodedLen(len(rawClaims))
63+
lenS := base64EncodedLen(b.signer.SignSize())
64+
65+
raw := make([]byte, lenH+1+lenC+1+lenS)
66+
idx := 0
67+
idx += copy(raw[idx:], b.headerRaw)
68+
raw[idx] = '.'
69+
idx++
70+
base64Encode(raw[idx:], rawClaims)
71+
idx += lenC
6072

61-
raw, signature, err := signPayload(b.signer, payload)
73+
signature, err := b.signer.Sign(raw[:idx])
6274
if err != nil {
6375
return nil, err
6476
}
77+
raw[idx] = '.'
78+
idx++
79+
base64Encode(raw[idx:], signature)
80+
idx += lenS
6581

6682
token := &Token{
6783
raw: raw,
68-
payload: payload,
84+
payload: raw[:lenH+1+lenC],
6985
signature: signature,
7086
header: b.header,
7187
claims: rawClaims,
7288
}
7389
return token, nil
7490
}
7591

76-
func encodeClaims(claims interface{}) (raw, encoded []byte, err error) {
77-
raw, err = json.Marshal(claims)
78-
if err != nil {
79-
return nil, nil, err
92+
func encodeClaims(claims interface{}) ([]byte, error) {
93+
switch claims := claims.(type) {
94+
case []byte:
95+
return claims, nil
96+
default:
97+
return json.Marshal(claims)
8098
}
81-
82-
encoded = make([]byte, base64EncodedLen(len(raw)))
83-
base64Encode(encoded, raw)
84-
85-
return raw, encoded, nil
8699
}
87100

88101
func encodeHeader(header *Header) []byte {
102+
if header.Type == "JWT" && header.ContentType == "" {
103+
switch header.Algorithm {
104+
case EdDSA:
105+
return []byte(encHeaderEdDSA)
106+
107+
case HS256:
108+
return []byte(encHeaderHS256)
109+
case HS384:
110+
return []byte(encHeaderHS384)
111+
case HS512:
112+
return []byte(encHeaderHS512)
113+
114+
case RS256:
115+
return []byte(encHeaderRS256)
116+
case RS384:
117+
return []byte(encHeaderRS384)
118+
case RS512:
119+
return []byte(encHeaderRS512)
120+
121+
case ES256:
122+
return []byte(encHeaderES256)
123+
case ES384:
124+
return []byte(encHeaderES384)
125+
case ES512:
126+
return []byte(encHeaderES512)
127+
128+
case PS256:
129+
return []byte(encHeaderPS256)
130+
case PS384:
131+
return []byte(encHeaderPS384)
132+
case PS512:
133+
return []byte(encHeaderPS512)
134+
135+
default:
136+
// another algorithm? encode below
137+
}
138+
}
89139
// returned err is always nil, see *Header.MarshalJSON
90-
buf, _ := header.MarshalJSON()
140+
buf, _ := json.Marshal(header)
91141

92142
encoded := make([]byte, base64EncodedLen(len(buf)))
93143
base64Encode(encoded, buf)
94-
95144
return encoded
96145
}
97146

98-
func signPayload(signer Signer, payload []byte) (signed, signature []byte, err error) {
99-
signature, err = signer.Sign(payload)
100-
if err != nil {
101-
return nil, nil, err
102-
}
147+
const (
148+
encHeaderEdDSA = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9"
103149

104-
encodedSignature := make([]byte, base64EncodedLen(len(signature)))
105-
base64Encode(encodedSignature, signature)
150+
encHeaderHS256 = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
151+
encHeaderHS384 = "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9"
152+
encHeaderHS512 = "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9"
106153

107-
signed = concatParts(payload, encodedSignature)
108-
109-
return signed, signature, nil
110-
}
154+
encHeaderRS256 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
155+
encHeaderRS384 = "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9"
156+
encHeaderRS512 = "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9"
111157

112-
func concatParts(a, b []byte) []byte {
113-
buf := make([]byte, len(a)+1+len(b))
114-
buf[len(a)] = '.'
158+
encHeaderES256 = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9"
159+
encHeaderES384 = "eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCJ9"
160+
encHeaderES512 = "eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9"
115161

116-
copy(buf[:len(a)], a)
117-
copy(buf[len(a)+1:], b)
118-
119-
return buf
120-
}
162+
encHeaderPS256 = "eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9"
163+
encHeaderPS384 = "eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9"
164+
encHeaderPS512 = "eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9"
165+
)

build_test.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func TestBuild(t *testing.T) {
1717

1818
raw := string(token)
1919
if raw != want {
20-
t.Errorf("want %v, got %v", want, raw)
20+
t.Errorf("want %v,\n got %v", want, raw)
2121
}
2222
}
2323

@@ -43,7 +43,7 @@ func TestBuildHeader(t *testing.T) {
4343
want = toBase64(want)
4444
raw := string(token.RawHeader())
4545
if raw != want {
46-
t.Errorf("want %v, got %v", want, raw)
46+
t.Errorf("\nwant %v,\n got %v", want, raw)
4747
}
4848
}
4949

@@ -109,6 +109,9 @@ func toBase64(s string) string {
109109

110110
type badSigner struct{}
111111

112+
func (badSigner) SignSize() int {
113+
return 0
114+
}
112115
func (badSigner) Algorithm() Algorithm {
113116
return "bad"
114117
}

jwt.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ type Header struct {
6767
}
6868

6969
// MarshalJSON implements the json.Marshaler interface.
70-
func (h *Header) MarshalJSON() (data []byte, err error) {
70+
func (h *Header) MarshalJSON() ([]byte, error) {
7171
buf := bytes.Buffer{}
7272
buf.WriteString(`{"alg":"`)
7373
buf.WriteString(string(h.Algorithm))

jwt_bench_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
"github.com/cristalhq/jwt/v2"
1414
)
1515

16-
func BenchmarkEDSA(b *testing.B) {
16+
func BenchmarkAlgEDSA(b *testing.B) {
1717
pubKey, privKey, keyErr := ed25519.GenerateKey(rand.Reader)
1818
if keyErr != nil {
1919
b.Fatal(keyErr)
@@ -35,7 +35,7 @@ func BenchmarkEDSA(b *testing.B) {
3535
})
3636
}
3737

38-
func BenchmarkES(b *testing.B) {
38+
func BenchmarkAlgES(b *testing.B) {
3939
esAlgos := map[jwt.Algorithm]elliptic.Curve{
4040
jwt.ES256: elliptic.P256(),
4141
jwt.ES384: elliptic.P384(),
@@ -64,7 +64,7 @@ func BenchmarkES(b *testing.B) {
6464
}
6565
}
6666

67-
func BenchmarkPS(b *testing.B) {
67+
func BenchmarkAlgPS(b *testing.B) {
6868
psAlgos := []jwt.Algorithm{jwt.PS256, jwt.PS384, jwt.PS512}
6969
for _, algo := range psAlgos {
7070
key, keyErr := rsa.GenerateKey(rand.Reader, 2048)
@@ -89,7 +89,7 @@ func BenchmarkPS(b *testing.B) {
8989
}
9090
}
9191

92-
func BenchmarkRS(b *testing.B) {
92+
func BenchmarkAlgRS(b *testing.B) {
9393
rsAlgos := []jwt.Algorithm{jwt.RS256, jwt.RS384, jwt.RS512}
9494
for _, algo := range rsAlgos {
9595
key, keyErr := rsa.GenerateKey(rand.Reader, 2048)
@@ -114,7 +114,7 @@ func BenchmarkRS(b *testing.B) {
114114
}
115115
}
116116

117-
func BenchmarkHS(b *testing.B) {
117+
func BenchmarkAlgHS(b *testing.B) {
118118
key := []byte("12345")
119119
hsAlgos := []jwt.Algorithm{jwt.HS256, jwt.HS384, jwt.HS512}
120120
for _, algo := range hsAlgos {

0 commit comments

Comments
 (0)