diff --git a/httpsig_test.go b/httpsig_test.go index dca48e7..9800cfa 100644 --- a/httpsig_test.go +++ b/httpsig_test.go @@ -42,7 +42,6 @@ type httpsigTest struct { expectedSignatureAlgorithm string expectedAlgorithm Algorithm expectErrorSigningResponse bool - expectRequestPath bool expectedDigest string } @@ -247,7 +246,6 @@ func init() { expectedAlgorithm: RSA_SHA512, expectedSignatureAlgorithm: "hs2019", expectErrorSigningResponse: true, - expectRequestPath: true, }, } @@ -390,6 +388,84 @@ func TestSignerResponse(t *testing.T) { } } +func TestNewRequestTargetHeader(t *testing.T) { + tests := []struct { + name string + prefs []Algorithm + digestAlg DigestAlgorithm + headers []string + scheme SignatureScheme + privKey crypto.PrivateKey + pubKeyId string + pubKey crypto.PublicKey + expectedAlgorithm Algorithm + expectedSignatureAlgorithm string + url string + }{ + { + name: "root path without /", + prefs: []Algorithm{RSA_SHA512}, + digestAlg: DigestSha256, + headers: []string{"Date", RequestTarget}, + scheme: Signature, + privKey: privKey, + pubKeyId: "pubKeyId", + pubKey: privKey.Public(), + expectedSignatureAlgorithm: "hs2019", + expectedAlgorithm: RSA_SHA512, + url: "", // empty url / path + }, + { + name: "root path with /", + prefs: []Algorithm{RSA_SHA512}, + digestAlg: DigestSha256, + headers: []string{"Date", RequestTarget}, + scheme: Signature, + privKey: privKey, + pubKeyId: "pubKeyId", + pubKey: privKey.Public(), + expectedSignatureAlgorithm: "hs2019", + expectedAlgorithm: RSA_SHA512, + url: "/", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test := test + s, a, err := NewSigner(test.prefs, test.digestAlg, test.headers, test.scheme, 0) + if err != nil { + t.Fatalf("%s", err) + } + if a != test.expectedAlgorithm { + t.Fatalf("got %s, want %s", a, test.expectedAlgorithm) + } + req, err := http.NewRequest(testMethod, test.url, nil) + if err != nil { + t.Fatalf("%s", err) + } + req.Header.Set("Date", testDate) + err = s.SignRequest(test.privKey, test.pubKeyId, req, nil) + if err != nil { + t.Fatalf("signing failed: %s", err) + } + + requestTarget := req.Header.Get(RequestTarget) + parts := strings.Split(requestTarget, requestTargetSeparator) + if len(parts) != 2 { + t.Fatalf("request target should contain two parts divided by '%s'", requestTargetSeparator) + } + + if parts[0] != strings.ToLower(testMethod) { + t.Fatalf("request target method is '%s' expected '%s'", parts[0], strings.ToLower(testMethod)) + } + + if parts[1] != "/" { + t.Fatalf("request target path is '%s' expected '%s'", parts[1], "/") + } + }) + } +} + func TestNewSignerRequestMissingHeaders(t *testing.T) { failingTests := []struct { name string diff --git a/signing.go b/signing.go index e18db41..ea566c3 100644 --- a/signing.go +++ b/signing.go @@ -59,6 +59,7 @@ func (m *macSigner) SignRequest(pKey crypto.PrivateKey, pubKeyId string, r *http return err } } + s, err := m.signatureString(r) if err != nil { return err @@ -198,6 +199,7 @@ func setSignatureHeader(h http.Header, targetHeader, prefix, pubKeyId, algo, enc headers = defaultHeaders } var b bytes.Buffer + // KeyId b.WriteString(prefix) if len(prefix) > 0 { @@ -209,6 +211,7 @@ func setSignatureHeader(h http.Header, targetHeader, prefix, pubKeyId, algo, enc b.WriteString(pubKeyId) b.WriteString(parameterValueDelimiter) b.WriteString(parameterSeparater) + // Algorithm b.WriteString(algorithmParameter) b.WriteString(parameterKVSeparater) @@ -256,6 +259,7 @@ func setSignatureHeader(h http.Header, targetHeader, prefix, pubKeyId, algo, enc } b.WriteString(parameterValueDelimiter) b.WriteString(parameterSeparater) + // Signature b.WriteString(signatureParameter) b.WriteString(parameterKVSeparater) @@ -271,17 +275,24 @@ func requestTargetNotPermitted(b *bytes.Buffer) error { func addRequestTarget(r *http.Request) func(b *bytes.Buffer) error { return func(b *bytes.Buffer) error { - b.WriteString(RequestTarget) - b.WriteString(headerFieldDelimiter) - b.WriteString(strings.ToLower(r.Method)) - b.WriteString(requestTargetSeparator) - b.WriteString(r.URL.Path) - + var value strings.Builder + value.WriteString(strings.ToLower(r.Method)) + value.WriteString(requestTargetSeparator) + value.WriteString(r.URL.Path) + if !strings.HasSuffix(r.URL.Path, "/") { + value.WriteString("/") + } if r.URL.RawQuery != "" { - b.WriteString("?") - b.WriteString(r.URL.RawQuery) + value.WriteString("?") + value.WriteString(r.URL.RawQuery) } + r.Header.Set(RequestTarget, value.String()) + + b.WriteString(RequestTarget) + b.WriteString(headerFieldDelimiter) + b.WriteString(value.String()) + return nil } }