Skip to content

Commit 3286114

Browse files
committed
fix: rule.SetOnResponse not work
should handle response headers after next plugin and before write response statusCode
1 parent 7451e58 commit 3286114

File tree

2 files changed

+139
-10
lines changed

2 files changed

+139
-10
lines changed

htransformation.go

+71-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package htransformation
22

33
import (
4+
"bufio"
45
"context"
56
"fmt"
7+
"net"
68
"net/http"
79

810
"github.com/tomMoulard/htransformation/pkg/handler/deleter"
@@ -15,9 +17,10 @@ import (
1517

1618
// HeadersTransformation holds the necessary components of a Traefik plugin.
1719
type HeadersTransformation struct {
18-
name string
19-
next http.Handler
20-
handlers []types.Handler
20+
name string
21+
next http.Handler
22+
reqHandlers []types.Handler
23+
respHandlers []types.Handler
2124
}
2225

2326
// Config holds configuration to be passed to the plugin.
@@ -42,7 +45,8 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
4245
types.Set: set.New,
4346
}
4447

45-
handlers := make([]types.Handler, 0, len(config.Rules))
48+
reqHandlers := make([]types.Handler, 0, len(config.Rules))
49+
respHandlers := make([]types.Handler, 0, len(config.Rules))
4650

4751
for _, rule := range config.Rules {
4852
newHandler, ok := handlerBuilder[rule.Type]
@@ -59,22 +63,79 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
5963
return nil, fmt.Errorf("%w: %s", err, rule.Name)
6064
}
6165

62-
handlers = append(handlers, h)
66+
if rule.SetOnResponse {
67+
respHandlers = append(respHandlers, h)
68+
} else {
69+
reqHandlers = append(reqHandlers, h)
70+
}
6371
}
6472

6573
return &HeadersTransformation{
66-
name: name,
67-
next: next,
68-
handlers: handlers,
74+
name: name,
75+
next: next,
76+
reqHandlers: reqHandlers,
77+
respHandlers: respHandlers,
6978
}, nil
7079
}
7180

7281
// Iterate over every header to match the ones specified in the config and
7382
// return nothing if regexp failed.
7483
func (u *HeadersTransformation) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
75-
for _, handler := range u.handlers {
84+
for _, handler := range u.reqHandlers {
7685
handler.Handle(responseWriter, request)
7786
}
7887

79-
u.next.ServeHTTP(responseWriter, request)
88+
wrappedResponseWriter := newWrappedResponseWriter(responseWriter, func(rw http.ResponseWriter) {
89+
for _, handler := range u.respHandlers {
90+
handler.Handle(rw, request)
91+
}
92+
})
93+
94+
u.next.ServeHTTP(wrappedResponseWriter, request)
95+
}
96+
97+
type wrappedResponseWriter struct {
98+
rw http.ResponseWriter
99+
handler func(http.ResponseWriter)
100+
headerSent bool
101+
}
102+
103+
func newWrappedResponseWriter(rw http.ResponseWriter, handler func(http.ResponseWriter)) http.ResponseWriter {
104+
return &wrappedResponseWriter{
105+
rw: rw,
106+
handler: handler,
107+
headerSent: false,
108+
}
109+
}
110+
111+
func (wrw *wrappedResponseWriter) handleResponseHeader() {
112+
if wrw.headerSent {
113+
return
114+
}
115+
116+
wrw.headerSent = true
117+
wrw.handler(wrw.rw)
118+
}
119+
120+
func (wrw *wrappedResponseWriter) Header() http.Header {
121+
return wrw.rw.Header()
122+
}
123+
124+
func (wrw *wrappedResponseWriter) Write(p []byte) (int, error) {
125+
wrw.handleResponseHeader()
126+
return wrw.rw.Write(p)
127+
}
128+
129+
func (wrw *wrappedResponseWriter) WriteHeader(statusCode int) {
130+
wrw.handleResponseHeader()
131+
wrw.rw.WriteHeader(statusCode)
132+
}
133+
134+
func (wrw *wrappedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
135+
hijacker, ok := wrw.rw.(http.Hijacker)
136+
if !ok {
137+
return nil, nil, fmt.Errorf("%T is not an http.Hijacker", wrw.rw)
138+
}
139+
140+
return hijacker.Hijack()
80141
}

htransformation_test.go

+68
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,71 @@ func TestHeaderRules(t *testing.T) {
192192
})
193193
}
194194
}
195+
196+
func TestSetOnResponse(t *testing.T) {
197+
testCases := []struct {
198+
name string
199+
headerName string
200+
headerValue string
201+
rule types.Rule
202+
expectedNewValue string
203+
}{
204+
{
205+
name: "set rule",
206+
headerName: "Header-A",
207+
headerValue: "valueA",
208+
rule: types.Rule{
209+
Name: "set rule",
210+
Header: "Header-A",
211+
Value: "newValue",
212+
Type: types.Set,
213+
SetOnResponse: true,
214+
},
215+
expectedNewValue: "newValue",
216+
},
217+
{
218+
name: "rewrite rule",
219+
headerName: "Header-A",
220+
headerValue: "valueAA",
221+
rule: types.Rule{
222+
Name: "rewrite rule",
223+
Header: "Header-A",
224+
Value: `value([\w\W]+)`,
225+
ValueReplace: "newValue-$1",
226+
Type: types.RewriteValueRule,
227+
SetOnResponse: true,
228+
},
229+
expectedNewValue: "newValue-AA",
230+
},
231+
}
232+
233+
for _, test := range testCases {
234+
t.Run(test.name, func(t *testing.T) {
235+
cfg := plug.CreateConfig()
236+
cfg.Rules = []types.Rule{test.rule}
237+
238+
ctx := context.Background()
239+
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
240+
rw.Header().Add(test.headerName, test.headerValue)
241+
rw.WriteHeader(200)
242+
})
243+
244+
handler, err := plug.New(ctx, next, cfg, "demo-plugin")
245+
require.NoError(t, err)
246+
247+
recorder := httptest.NewRecorder()
248+
249+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
250+
require.NoError(t, err)
251+
252+
handler.ServeHTTP(recorder, req)
253+
resp := recorder.Result()
254+
statusCode := resp.StatusCode
255+
require.NoError(t, resp.Body.Close())
256+
257+
assert.Equal(t, http.StatusOK, statusCode)
258+
259+
assert.Equal(t, test.expectedNewValue, resp.Header.Get(test.rule.Header))
260+
})
261+
}
262+
}

0 commit comments

Comments
 (0)