Skip to content

Commit 99c21b7

Browse files
committed
proxy: add support for mTLS
1 parent bad4a71 commit 99c21b7

File tree

1 file changed

+32
-8
lines changed

1 file changed

+32
-8
lines changed

cmd/proxy/main.go

+32-8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/rand"
66
"crypto/tls"
7+
"crypto/x509"
78
"encoding/binary"
89
"errors"
910
"flag"
@@ -64,16 +65,18 @@ func main() {
6465

6566
func loop(l log.Logger, ctx context.Context) error {
6667
var (
67-
addr string
68-
tlsCert string
69-
tlsKey string
70-
backend string
68+
addr string
69+
tlsCert string
70+
tlsKey string
71+
backend string
72+
mtlsCACerts string
7173
)
7274

7375
flag.StringVar(&addr, "listen", "127.0.0.1:853", "UDP address to listen on.")
74-
flag.StringVar(&tlsCert, "cert", "cert.pem", "TLS certificate path.")
75-
flag.StringVar(&tlsKey, "key", "key.pem", "TLS key path.")
76+
flag.StringVar(&tlsCert, "cert", "server.crt", "Path to server TLS certificate.")
77+
flag.StringVar(&tlsKey, "key", "server.key", "Path to server TLS key.")
7678
flag.StringVar(&backend, "backend", "8.8.4.4:53", "IP of backend server.")
79+
flag.StringVar(&mtlsCACerts, "mtls_ca_certs", "", "Path to CA bundle for mTLS.")
7780

7881
flag.Parse()
7982

@@ -82,12 +85,27 @@ func loop(l log.Logger, ctx context.Context) error {
8285
return fmt.Errorf("load certificate: %w", err)
8386
}
8487

85-
tls := tls.Config{
88+
tlsConfig := tls.Config{
8689
Certificates: []tls.Certificate{cert},
8790
NextProtos: []string{"doq"},
8891
}
8992

90-
listener, err := quic.ListenAddr(addr, &tls, nil)
93+
if mtlsCACerts != "" {
94+
pems, err := os.ReadFile(mtlsCACerts)
95+
if err != nil {
96+
return fmt.Errorf("load mTLS CA certificates: %w", err)
97+
}
98+
pool := x509.NewCertPool()
99+
if ok := pool.AppendCertsFromPEM(pems); !ok {
100+
return fmt.Errorf("load mTLS CA certificates: found no certificate")
101+
}
102+
103+
tlsConfig.ClientCAs = pool
104+
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
105+
106+
}
107+
108+
listener, err := quic.ListenAddr(addr, &tlsConfig, nil)
91109
if err != nil {
92110
return fmt.Errorf("listen: %w", err)
93111
}
@@ -105,6 +123,12 @@ func loop(l log.Logger, ctx context.Context) error {
105123
}
106124

107125
l := log.With(l, "client", session.RemoteAddr())
126+
127+
certs := session.ConnectionState().TLS.PeerCertificates
128+
if len(certs) > 0 {
129+
l = log.With(l, "client_cert_subject", certs[0].Subject)
130+
}
131+
108132
wg.Add(1)
109133
go func() {
110134
handleClient(l, ctx, session, backend)

0 commit comments

Comments
 (0)