-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmiddleware.go
More file actions
242 lines (215 loc) · 7.74 KB
/
middleware.go
File metadata and controls
242 lines (215 loc) · 7.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
// Copyright © 2021 Luther Systems, Ltd. All right reserved.
package midware
import (
"bytes"
"fmt"
"net/http"
"sort"
"strings"
"github.com/google/uuid"
"github.com/luthersystems/svc/static"
)
// Header names used for request tracing.
var (
	// DefaultTraceHeader is the header used by TraceHeaders when it is given
	// an empty header name instead of a valid one.
	DefaultTraceHeader = "X-Request-Id"

	// DefaultAzureHeader is the Azure header carrying a unique guid that the
	// application gateway generates for each client request and forwards to
	// the backend pool member.
	DefaultAzureHeader = "X-Appgw-Trace-Id"

	// DefaultAWSHeader is the AWS header that can be used for request tracing
	// to follow HTTP requests from clients to targets or other services.
	DefaultAWSHeader = "X-Amzn-Trace-Id"
)
// PathOverrides is middleware which overrides handling for a specified set of
// http request paths. Each key in a PathOverrides map is an http request path
// and the associated handler will be used to serve that path instead of
// allowing the middleware's "natural" inner handler to serve the request.
//
// A key that does not end with '/' overrides only the exact, rooted path it
// names. A key that ends with '/' additionally overrides the whole subtree
// beneath it: any request path having that key as a prefix is served by its
// handler. An exact match always wins over a subtree match, and among subtree
// keys the longest matching one wins. Unlike http.ServeMux, no path cleaning
// or redirection is performed; matching is done on r.URL.Path as received.
//
// Wrap panics if a key is nested beneath static.PublicPathPrefix (other than
// that prefix itself), because every request under the public file system
// must be served by the single public handler.
type PathOverrides map[string]http.Handler
// Wrap implements the Middleware interface. The returned handler serves a
// request with the override registered for its path, falling back to next
// when no override applies.
//
// Keys ending in '/' are collected as subtree prefixes and ordered longest
// first so that the most specific prefix is tried first.
//
// Wrap panics if m registers a path nested beneath static.PublicPathPrefix
// (other than the prefix itself): the public file system may contain nested
// directories, and requests for them must still reach the public handler.
func (m PathOverrides) Wrap(next http.Handler) http.Handler {
	var subtrees []string
	for key := range m {
		nestedPublic := strings.HasPrefix(key, static.PublicPathPrefix) && key != static.PublicPathPrefix
		if nestedPublic {
			panic(fmt.Sprintf("PathOverride conflict: disallowed registration of nested public route: %s", key))
		}
		if !strings.HasSuffix(key, "/") {
			continue
		}
		subtrees = append(subtrees, key)
	}
	sort.Slice(subtrees, func(a, b int) bool {
		return len(subtrees[a]) > len(subtrees[b])
	})
	handler := &pathOverridesHandler{
		next:     next,
		m:        m,
		prefixes: subtrees,
	}
	return handler
}
// pathOverridesHandler is the http.Handler built by PathOverrides.Wrap.
type pathOverridesHandler struct {
	m        PathOverrides // exact-path and subtree overrides
	prefixes []string      // keys of m ending in '/', longest first
	next     http.Handler  // fallback when no override matches
}

// ServeHTTP picks the handler for r.URL.Path and delegates to it. An exact
// key in h.m is preferred; otherwise the first (i.e. longest) matching prefix
// in h.prefixes is used; otherwise the request goes to h.next.
func (h *pathOverridesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	p := r.URL.Path
	target, found := h.m[p]
	if !found {
		target = h.next
		for _, pre := range h.prefixes {
			if strings.HasPrefix(p, pre) {
				target = h.m[pre]
				break
			}
		}
	}
	target.ServeHTTP(w, r)
}
// ServerResponseHeader returns a middleware that renders the given sequence of
// server components (presumably in "software[/version]" format) and includes
// them in the Server response header. Any secondary components which are
// supplied in addition to primary will be rendered in sequence and delimited
// by a single space; they are evaluated on each request. Any component which
// renders an empty string or one consisting solely of whitespace is ignored,
// and other values will have leading and trailing whitespace trimmed.
// ServerResponseHeader overwrites any Server header that was set earlier (by
// another middleware).
//
// ServerResponseHeader panics immediately if primary is empty or consists
// solely of whitespace. It does NOT check that primary is a valid RFC 2616
// product token, so it is recommended that the primary component be the
// result of ServerFixed called with a const, non-empty name argument.
//
// The secondary slice is copied, so mutating a slice passed as `fns...` after
// this call does not affect the returned middleware.
//
// BUG: Neither ServerResponseHeader nor its returned middleware check
// components for invalid control characters. Because of this it is important
// that application end users and unchecked code not be permitted to inject
// content into server response header components.
func ServerResponseHeader(primary string, secondary ...func() string) Middleware {
	primary = strings.TrimSpace(primary)
	if primary == "" {
		panic("http server header primary component is invalid")
	}
	// A variadic argument passed with `fns...` aliases the caller's slice;
	// copy it so later caller mutations cannot race with request handling.
	components := append([]func() string(nil), secondary...)
	return Func(func(next http.Handler) http.Handler {
		return &serverListHandler{p: primary, s: components, next: next}
	})
}
// serverListHandler is the http.Handler installed by ServerResponseHeader. It
// sets the Server response header and then delegates to next.
type serverListHandler struct {
	p    string          // primary component; ServerResponseHeader trims it and rejects empty values
	s    []func() string // secondary components, evaluated on each request
	next http.Handler
}

// ServeHTTP sets the Server header on w and then calls h.next.
func (h *serverListHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	s := h.header()
	// NOTE: s cannot be empty in any allowed construction of h but we include
	// this branch which cannot panic just to protect against subtle future
	// bugs.
	if s == "" {
		// The RFC2616 grammar for Server dictates that it must contain a
		// nonempty token. The application expects a Server header to be
		// injected here and we don't want to crash an inflexible http client
		// library by injecting an invalid header so we inject something
		// generic that is still valid according to the RFC.
		s = "server"
	}
	w.Header().Set("Server", s)
	h.next.ServeHTTP(w, r)
}

// header renders the Server header value: the primary component followed by
// each non-blank secondary component, trimmed and separated by single spaces.
func (h *serverListHandler) header() string {
	if len(h.s) == 0 {
		return h.p
	}
	var b bytes.Buffer
	b.WriteString(h.p) // space has already been trimmed
	for _, component := range h.s {
		// Evaluate each component exactly once. Calling it a second time could
		// observe a different (possibly empty) value, repeat side effects, and
		// would bypass the trimming applied here.
		c := strings.TrimSpace(component())
		if c == "" {
			continue
		}
		b.WriteByte(' ')
		b.WriteString(c)
	}
	return b.String()
}
// ServerFixed returns a string intended to be used as the primary component in
// ServerResponseHeader. ServerFixed ignores any leading and trailing
// whitespace in its arguments; in particular, a version consisting solely of
// whitespace is treated as empty. If the trimmed version is non-empty the
// server header component will render the two trimmed strings joined by a
// slash, like the following:
//
//	fmt.Sprintf("%s/%s", name, version)
//
// The name argument of ServerFixed should be non-empty but that is not
// enforced. If passed two empty strings ServerFixed will return an empty
// string.
func ServerFixed(name string, version string) string {
	// Trim before testing for emptiness so that a blank version does not
	// produce a dangling "name/" component.
	name = strings.TrimSpace(name)
	version = strings.TrimSpace(version)
	if version == "" {
		return name
	}
	return name + "/" + version
}
// ServerFixedFunc returns a function usable as a secondary component in
// ServerResponseHeader when the software's name and version are known ahead
// of time. ServerFixed(name, version) is evaluated once, when ServerFixedFunc
// is called, and the returned function yields that same string every time.
// Because ServerFixed depends only on its arguments, this is equivalent to
// the closure:
//
//	func() string {
//		return ServerFixed(name, version)
//	}
func ServerFixedFunc(name string, version string) func() string {
	component := ServerFixed(name, version)
	get := func() string {
		return component
	}
	return get
}
// TraceHeaders returns middleware which makes sure every incoming http request
// carries an identifying header for tracing, and which echoes a matching
// header in the http response.
//
// When allow is true, requests may supply their own ids, which are assumed to
// be unique. In that case DefaultTraceHeader, DefaultAzureHeader and
// DefaultAWSHeader (checked in that order) take precedence over header. When
// allow is false, any existing value of header is overwritten with a freshly
// generated id before the inner http handler runs.
//
// An empty header selects DefaultTraceHeader (as it is set when TraceHeaders
// is called).
func TraceHeaders(header string, allow bool) Middleware {
	name := header
	if len(name) == 0 {
		name = DefaultTraceHeader
	}
	wrap := func(next http.Handler) http.Handler {
		return &traceRequestHeader{
			next:   next,
			allow:  allow,
			header: name,
		}
	}
	return Func(wrap)
}
// traceRequestHeader is the http.Handler installed by TraceHeaders.
//
// Its fields are configuration only and must not be modified while serving:
// a single value is shared by every request the handler serves, and net/http
// may run those requests concurrently.
type traceRequestHeader struct {
	header string       // configured trace header name
	allow  bool         // whether client-supplied trace ids are accepted
	next   http.Handler // inner handler
}

// ServeHTTP determines the trace id for r, records it on the request and the
// response, and then calls h.next.
//
// With h.allow set, an id supplied in DefaultTraceHeader, DefaultAzureHeader
// or DefaultAWSHeader (first non-empty wins) takes precedence over h.header,
// and the winning header name is used for this request. If no id is found
// (or h.allow is false) a new UUID is generated and set on the request.
// DefaultTraceHeader is always set on both request and response.
func (h *traceRequestHeader) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Per-request header name. This must be a local: assigning to h.header
	// would race across concurrent requests and would leak one request's
	// header choice into every later request served by this handler.
	header := h.header
	var reqid string
	if h.allow {
		reqid = r.Header.Get(header)
		precedenceHeaders := []string{DefaultTraceHeader, DefaultAzureHeader, DefaultAWSHeader}
		for _, candidate := range precedenceHeaders {
			if v := r.Header.Get(candidate); v != "" {
				header = candidate
				reqid = v
				break
			}
		}
	}
	if reqid == "" {
		reqid = uuid.New().String()
		r.Header.Set(header, reqid)
	}
	// Always set DefaultTraceHeader on request and response since a lot of
	// logging is hard-coded to use this header.
	if header != DefaultTraceHeader {
		r.Header.Set(DefaultTraceHeader, reqid)
		w.Header().Set(DefaultTraceHeader, reqid)
	}
	w.Header().Set(header, reqid)
	h.next.ServeHTTP(w, r)
}