@@ -18,28 +18,36 @@ package main
18
18
19
19
import (
20
20
"crypto/tls"
21
- "crypto/x509"
22
21
"encoding/json"
23
22
"encoding/pem"
24
23
"fmt"
25
24
"io"
26
- "net"
27
25
"net/http"
28
26
"os"
29
27
"regexp"
30
28
"strconv"
31
29
"strings"
32
30
"time"
33
31
34
- "github.com/paultag/sniff/parser"
35
32
"golang.org/x/net/http2"
36
33
"golang.org/x/net/http2/h2c"
37
34
"golang.org/x/net/websocket"
38
35
39
36
g "sigs.k8s.io/gateway-api/conformance/echo-basic/grpc"
40
37
)
41
38
42
- // RequestAssertions contains information about the request and the Ingress
39
+ type preserveSlashes struct {
40
+ mux http.Handler
41
+ }
42
+
43
+ func (s * preserveSlashes ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
44
+ r .URL .Path = strings .ReplaceAll (r .URL .Path , "//" , "/" )
45
+ s .mux .ServeHTTP (w , r )
46
+ }
47
+
48
+ var context Context
49
+
50
+ // RequestAssertions contains information about the request and the Ingress.
43
51
type RequestAssertions struct {
44
52
Path string `json:"path"`
45
53
Host string `json:"host"`
@@ -63,25 +71,14 @@ type TLSAssertions struct {
63
71
CipherSuite string `json:"cipherSuite"`
64
72
}
65
73
66
- type preserveSlashes struct {
67
- mux http.Handler
68
- }
69
-
70
- func (s * preserveSlashes ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
71
- r .URL .Path = strings .ReplaceAll (r .URL .Path , "//" , "/" )
72
- s .mux .ServeHTTP (w , r )
73
- }
74
-
75
- // Context contains information about the context where the echoserver is running
74
+ // Context contains information about the context where the echoserver is running.
76
75
type Context struct {
77
76
Namespace string `json:"namespace"`
78
77
Ingress string `json:"ingress"`
79
78
Service string `json:"service"`
80
79
Pod string `json:"pod"`
81
80
}
82
81
83
- var context Context
84
-
85
82
func main () {
86
83
if os .Getenv ("GRPC_ECHO_SERVER" ) != "" {
87
84
g .Main ()
@@ -92,6 +89,7 @@ func main() {
92
89
if httpPort == "" {
93
90
httpPort = "3000"
94
91
}
92
+
95
93
h2cPort := os .Getenv ("H2C_PORT" )
96
94
if h2cPort == "" {
97
95
h2cPort = "3001"
@@ -113,7 +111,6 @@ func main() {
113
111
httpMux .HandleFunc ("/health" , healthHandler )
114
112
httpMux .HandleFunc ("/status/" , statusHandler )
115
113
httpMux .HandleFunc ("/" , echoHandler )
116
- httpMux .HandleFunc ("/backendTLS" , echoHandler )
117
114
httpMux .Handle ("/ws" , websocket .Handler (wsHandler ))
118
115
httpHandler := & preserveSlashes {httpMux }
119
116
@@ -130,18 +127,22 @@ func main() {
130
127
go runH2CServer (h2cPort , errchan )
131
128
132
129
// Enable HTTPS if server certificate and private key are given. (TLS_SERVER_CERT, TLS_SERVER_PRIVKEY)
133
- // Enable secure backend if CA certificate and key are given. (CA_CERT, CA_CERT_KEY)
134
- if os .Getenv ("TLS_SERVER_CERT" ) != "" && os .Getenv ("TLS_SERVER_PRIVKEY" ) != "" ||
135
- os .Getenv ("CA_CERT" ) != "" && os .Getenv ("CA_CERT_KEY" ) != "" {
130
+ if os .Getenv ("TLS_SERVER_CERT" ) != "" && os .Getenv ("TLS_SERVER_PRIVKEY" ) != "" {
136
131
go func () {
137
132
fmt .Printf ("Starting server, listening on port %s (https)\n " , httpsPort )
138
- err := listenAndServeTLS (fmt .Sprintf (":%s" , httpsPort ), os .Getenv ("TLS_SERVER_CERT" ), os .Getenv ("TLS_SERVER_PRIVKEY" ), os . Getenv ( "CA_CERT" ), httpHandler )
133
+ err := listenAndServeTLS (fmt .Sprintf (":%s" , httpsPort ), os .Getenv ("TLS_SERVER_CERT" ), os .Getenv ("TLS_SERVER_PRIVKEY" ), httpHandler )
139
134
if err != nil {
140
135
errchan <- err
141
136
}
142
137
}()
143
138
}
144
139
140
+ // Enable secure backend if CA certificate is given. (CA_CERT)
141
+ if os .Getenv ("CA_CERT" ) != "" {
142
+ // Start the backend server and listen on port 9443.
143
+ go runBackendTLSServer ("9443" , errchan )
144
+ }
145
+
145
146
if err := <- errchan ; err != nil {
146
147
panic (fmt .Sprintf ("Failed to start listening: %s\n " , err .Error ()))
147
148
}
@@ -207,30 +208,100 @@ func runH2CServer(h2cPort string, errchan chan<- error) {
207
208
}
208
209
}
209
210
210
- func echoHandler ( w http. ResponseWriter , r * http. Request ) {
211
- var sni string
211
+ // Global variable to store the SNI retrieved at a different level.
212
+ var globalsni string
212
213
214
+ func runBackendTLSServer (port string , errchan chan <- error ) {
215
+ // This handler function runs within the backend server to find the SNI
216
+ // and return it in the RequestAssertions.
217
+ handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
218
+ if strings .Contains (r .RequestURI , "backendTLS" ) {
219
+ // Find the sni in the global (or if needed, a channel/mutex?).
220
+ if globalsni == "" {
221
+ err := fmt .Errorf ("error finding SNI: SNI is empty" )
222
+ // If there are some test cases without SNI, then they must handle this error properly.
223
+ processError (w , err , http .StatusBadRequest )
224
+ }
225
+ requestAssertions := RequestAssertions {
226
+ r .RequestURI ,
227
+ r .Host ,
228
+ r .Method ,
229
+ r .Proto ,
230
+ r .Header ,
231
+
232
+ context ,
233
+
234
+ tlsStateToAssertions (r .TLS ),
235
+ globalsni ,
236
+ }
237
+ processRequestAssertions (requestAssertions , w , r )
238
+ } else {
239
+ // This should never happen, but just in case.
240
+ processError (w , fmt .Errorf ("backend server called without correct uri" ), http .StatusBadRequest )
241
+ }
242
+ })
243
+
244
+ config , err := makeTLSConfig (os .Getenv ("CA_CERT" ))
245
+ if err != nil {
246
+ errchan <- err
247
+ }
248
+ btlsServer := & http.Server {
249
+ Addr : fmt .Sprintf (":%s" , port ),
250
+ Handler : handler ,
251
+ ReadHeaderTimeout : time .Second ,
252
+ TLSConfig : config ,
253
+ }
254
+ fmt .Printf ("Starting server, listening on port %s (btls)\n " , port )
255
+ err = btlsServer .ListenAndServeTLS (os .Getenv ("CA_CERT" ), os .Getenv ("CA_CERT_KEY" ))
256
+ if err != nil {
257
+ fmt .Printf ("Failed to start server: %v\n " , err )
258
+ errchan <- err
259
+ }
260
+ }
261
+
262
+ func makeTLSConfig (cacert string ) (* tls.Config , error ) {
263
+ var config tls.Config
264
+
265
+ if cacert == "" {
266
+ return & config , fmt .Errorf ("empty CA cert specified" )
267
+ }
268
+ cert , err := tls .LoadX509KeyPair (cacert , os .Getenv ("CA_CERT_KEY" ))
269
+ if err != nil {
270
+ return & config , fmt .Errorf ("failed to load key pair: %v" , err )
271
+ }
272
+ certs := []tls.Certificate {cert }
273
+
274
+ // Verify certificate against given CA but also allow unauthenticated connections.
275
+ config .ClientAuth = tls .VerifyClientCertIfGiven
276
+ config .Certificates = certs
277
+ config .GetConfigForClient = func (info * tls.ClientHelloInfo ) (* tls.Config , error ) {
278
+ if info != nil {
279
+ // Store the SNI from the ClientHello into a global variable.
280
+ globalsni = info .ServerName
281
+ if globalsni == "" {
282
+ return nil , fmt .Errorf ("no SNI specified" )
283
+ }
284
+ return nil , nil
285
+ } else {
286
+ return nil , fmt .Errorf ("no client hello available" )
287
+ }
288
+ }
289
+
290
+ return & config , nil
291
+ }
292
+
293
+ func echoHandler (w http.ResponseWriter , r * http.Request ) {
213
294
fmt .Printf ("Echoing back request made to %s to client (%s)\n " , r .RequestURI , r .RemoteAddr )
214
295
215
296
// If the request has form ?delay=[:duration] wait for duration
216
297
// For example, ?delay=10s will cause the response to wait 10s before responding
217
298
err := delayResponse (r )
218
299
if err != nil {
300
+ fmt .Printf ("error : %v\n " , err )
219
301
processError (w , err , http .StatusInternalServerError )
220
302
return
221
303
}
222
304
223
- // If the request was made to URI backendTLS, then get the server name indication and
224
- // add it to the RequestAssertions. It will be echoed back later.
225
- if strings .Contains (r .RequestURI , "backendTLS" ) {
226
- sni , err = sniffForSNI (r .RemoteAddr )
227
- if err != nil {
228
- // TODO: research if for some test cases there won't be SNI available.
229
- processError (w , err , http .StatusBadGateway )
230
- return
231
- }
232
- }
233
-
234
305
requestAssertions := RequestAssertions {
235
306
r .RequestURI ,
236
307
r .Host ,
@@ -241,9 +312,12 @@ func echoHandler(w http.ResponseWriter, r *http.Request) {
241
312
context ,
242
313
243
314
tlsStateToAssertions (r .TLS ),
244
- sni ,
315
+ "" ,
245
316
}
317
+ processRequestAssertions (requestAssertions , w , r )
318
+ }
246
319
320
+ func processRequestAssertions (requestAssertions RequestAssertions , w http.ResponseWriter , r * http.Request ) {
247
321
js , err := json .MarshalIndent (requestAssertions , "" , " " )
248
322
if err != nil {
249
323
processError (w , err , http .StatusInternalServerError )
@@ -289,70 +363,15 @@ func processError(w http.ResponseWriter, err error, code int) { //nolint:unparam
289
363
_ , _ = w .Write (body )
290
364
}
291
365
292
- func listenAndServeTLS (addr string , serverCert string , serverPrivKey string , clientCA string , handler http.Handler ) error {
293
- var config tls.Config
294
-
295
- // Optionally enable client certificate validation when client CA certificates are given.
296
- if clientCA != "" {
297
- ca , err := os .ReadFile (clientCA )
298
- if err != nil {
299
- return err
300
- }
301
-
302
- certPool := x509 .NewCertPool ()
303
- if ok := certPool .AppendCertsFromPEM (ca ); ! ok {
304
- return fmt .Errorf ("unable to append certificate in %q to CA pool" , clientCA )
305
- }
306
-
307
- // Verify certificate against given CA but also allow unauthenticated connections.
308
- config .ClientAuth = tls .VerifyClientCertIfGiven
309
- config .ClientCAs = certPool
310
- }
311
-
366
+ func listenAndServeTLS (addr string , serverCert , serverPrivKey string , handler http.Handler ) error {
312
367
srv := & http.Server { //nolint:gosec
313
- Addr : addr ,
314
- Handler : handler ,
315
- TLSConfig : & config ,
368
+ Addr : addr ,
369
+ Handler : handler ,
316
370
}
317
371
318
372
return srv .ListenAndServeTLS (serverCert , serverPrivKey )
319
373
}
320
374
321
- // sniffForSNI uses the request address to listen for the incoming TLS connection,
322
- // and tries to find the server name indication from that connection.
323
- func sniffForSNI (addr string ) (string , error ) {
324
- var sni string
325
-
326
- // Listen to get the SNI, and store in config.
327
- listener , err := net .Listen ("tcp" , addr )
328
- if err != nil {
329
- return "" , err
330
- }
331
- defer listener .Close ()
332
-
333
- for {
334
- conn , err := listener .Accept ()
335
- if err != nil {
336
- return "" , err
337
- }
338
- data := make ([]byte , 4096 )
339
- _ , err = conn .Read (data )
340
- if err != nil {
341
- return "" , fmt .Errorf ("could not read socket: %v" , err )
342
- }
343
- // Take an incoming TLS Client Hello and return the SNI name.
344
- sni , err = parser .GetHostname (data )
345
- if err != nil {
346
- return "" , fmt .Errorf ("error getting SNI: %v" , err )
347
- }
348
- if sni == "" {
349
- return "" , fmt .Errorf ("no server name indication found" )
350
- } else { //nolint:revive
351
- return sni , nil
352
- }
353
- }
354
- }
355
-
356
375
func tlsStateToAssertions (connectionState * tls.ConnectionState ) * TLSAssertions {
357
376
if connectionState != nil {
358
377
var state TLSAssertions
0 commit comments