1
1
package htransformation
2
2
3
3
import (
4
+ "bufio"
4
5
"context"
5
6
"fmt"
7
+ "net"
6
8
"net/http"
7
9
8
10
"github.com/tomMoulard/htransformation/pkg/handler/deleter"
@@ -15,9 +17,10 @@ import (
15
17
16
18
// HeadersTransformation holds the necessary components of a Traefik plugin.
17
19
type HeadersTransformation struct {
18
- name string
19
- next http.Handler
20
- handlers []types.Handler
20
+ name string
21
+ next http.Handler
22
+ reqHandlers []types.Handler
23
+ respHandlers []types.Handler
21
24
}
22
25
23
26
// Config holds configuration to be passed to the plugin.
@@ -42,7 +45,8 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
42
45
types .Set : set .New ,
43
46
}
44
47
45
- handlers := make ([]types.Handler , 0 , len (config .Rules ))
48
+ reqHandlers := make ([]types.Handler , 0 , len (config .Rules ))
49
+ respHandlers := make ([]types.Handler , 0 , len (config .Rules ))
46
50
47
51
for _ , rule := range config .Rules {
48
52
newHandler , ok := handlerBuilder [rule .Type ]
@@ -59,22 +63,79 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
59
63
return nil , fmt .Errorf ("%w: %s" , err , rule .Name )
60
64
}
61
65
62
- handlers = append (handlers , h )
66
+ if rule .SetOnResponse {
67
+ respHandlers = append (respHandlers , h )
68
+ } else {
69
+ reqHandlers = append (reqHandlers , h )
70
+ }
63
71
}
64
72
65
73
return & HeadersTransformation {
66
- name : name ,
67
- next : next ,
68
- handlers : handlers ,
74
+ name : name ,
75
+ next : next ,
76
+ reqHandlers : reqHandlers ,
77
+ respHandlers : respHandlers ,
69
78
}, nil
70
79
}
71
80
72
81
// Iterate over every header to match the ones specified in the config and
73
82
// return nothing if regexp failed.
74
83
func (u * HeadersTransformation ) ServeHTTP (responseWriter http.ResponseWriter , request * http.Request ) {
75
- for _ , handler := range u .handlers {
84
+ for _ , handler := range u .reqHandlers {
76
85
handler .Handle (responseWriter , request )
77
86
}
78
87
79
- u .next .ServeHTTP (responseWriter , request )
88
+ wrappedResponseWriter := newWrappedResponseWriter (responseWriter , func (rw http.ResponseWriter ) {
89
+ for _ , handler := range u .respHandlers {
90
+ handler .Handle (rw , request )
91
+ }
92
+ })
93
+
94
+ u .next .ServeHTTP (wrappedResponseWriter , request )
95
+ }
96
+
97
+ type wrappedResponseWriter struct {
98
+ rw http.ResponseWriter
99
+ handler func (http.ResponseWriter )
100
+ headerSent bool
101
+ }
102
+
103
+ func newWrappedResponseWriter (rw http.ResponseWriter , handler func (http.ResponseWriter )) http.ResponseWriter {
104
+ return & wrappedResponseWriter {
105
+ rw : rw ,
106
+ handler : handler ,
107
+ headerSent : false ,
108
+ }
109
+ }
110
+
111
+ func (wrw * wrappedResponseWriter ) handleResponseHeader () {
112
+ if wrw .headerSent {
113
+ return
114
+ }
115
+
116
+ wrw .headerSent = true
117
+ wrw .handler (wrw .rw )
118
+ }
119
+
120
+ func (wrw * wrappedResponseWriter ) Header () http.Header {
121
+ return wrw .rw .Header ()
122
+ }
123
+
124
+ func (wrw * wrappedResponseWriter ) Write (p []byte ) (int , error ) {
125
+ wrw .handleResponseHeader ()
126
+ return wrw .rw .Write (p )
127
+ }
128
+
129
+ func (wrw * wrappedResponseWriter ) WriteHeader (statusCode int ) {
130
+ wrw .handleResponseHeader ()
131
+ wrw .rw .WriteHeader (statusCode )
132
+ }
133
+
134
+ func (wrw * wrappedResponseWriter ) Hijack () (net.Conn , * bufio.ReadWriter , error ) {
135
+ hijacker , ok := wrw .rw .(http.Hijacker )
136
+ if ! ok {
137
+ return nil , nil , fmt .Errorf ("%T is not an http.Hijacker" , wrw .rw )
138
+ }
139
+
140
+ return hijacker .Hijack ()
80
141
}
0 commit comments