diff --git a/client/pkg/transport/keepalive_listener_test.go b/client/pkg/transport/keepalive_listener_test.go index 425f53368b54..015392dfbc10 100644 --- a/client/pkg/transport/keepalive_listener_test.go +++ b/client/pkg/transport/keepalive_listener_test.go @@ -56,7 +56,7 @@ func TestNewKeepAliveListener(t *testing.T) { defer del() tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile} tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) - tlscfg, err := tlsInfo.ServerConfig() + tlscfg, err := tlsInfo.ReloadableServerConfig() if err != nil { t.Fatalf("unexpected serverConfig error: %v", err) } diff --git a/client/pkg/transport/listener.go b/client/pkg/transport/listener.go index 574ff35b07c2..af219750eaaa 100644 --- a/client/pkg/transport/listener.go +++ b/client/pkg/transport/listener.go @@ -358,7 +358,7 @@ func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertVali func (info *TLSInfo) startRefresh() { info.refreshOnce.Do( func() { - info.loadTLSConfig() + info.loadServerTlsConfig() if info.RefreshDuration > 0 { info.refreshDone = make(chan struct{}) go info.tlsConfigRefreshLoop() @@ -367,7 +367,7 @@ func (info *TLSInfo) startRefresh() { ) } -func (info *TLSInfo) loadTLSConfig() *tls.Config { +func (info *TLSInfo) loadServerTlsConfig() *tls.Config { if info.Logger != nil { info.Logger.Info("tls config reload from files") } @@ -389,13 +389,33 @@ func (info *TLSInfo) tlsConfigRefreshLoop() { for { select { case <-ticker.C: - info.loadTLSConfig() + info.loadServerTlsConfig() case <-info.refreshDone: return } } } +// baseConfig is called on initial TLS handshake start. +// +// Previously, +// 1. Server has non-empty (*tls.Config).Certificates on client hello +// 2. Server calls (*tls.Config).GetCertificate iff: +// - Server's (*tls.Config).Certificates is not empty, or +// - Client supplies SNI; non-empty (*tls.ClientHelloInfo).ServerName +// +// When (*tls.Config).Certificates is always populated on initial handshake, +// client is expected to provide a valid matching SNI to pass the TLS +// verification, thus trigger server (*tls.Config).GetCertificate to reload +// TLS assets. However, a cert whose SAN field does not include domain names +// but only IP addresses, has empty (*tls.ClientHelloInfo).ServerName, thus +// it was never able to trigger TLS reload on initial handshake; first +// ceritifcate object was being used, never being updated. +// +// Now, (*tls.Config).Certificates is created empty on initial TLS client +// handshake, in order to trigger (*tls.Config).GetCertificate and populate +// rest of the certificates on every new TLS connection, even when client +// SNI is empty (e.g. cert only includes IPs). func (info *TLSInfo) baseConfig() (*tls.Config, error) { if info.KeyFile == "" || info.CertFile == "" { return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile) @@ -516,33 +536,23 @@ func (info *TLSInfo) serverConfig() (*tls.Config, error) { return cfg, nil } -// baseConfig is called on initial TLS handshake start. -// -// Previously, -// 1. Server has non-empty (*tls.Config).Certificates on client hello -// 2. Server calls (*tls.Config).GetCertificate iff: -// - Server's (*tls.Config).Certificates is not empty, or -// - Client supplies SNI; non-empty (*tls.ClientHelloInfo).ServerName -// -// When (*tls.Config).Certificates is always populated on initial handshake, -// client is expected to provide a valid matching SNI to pass the TLS -// verification, thus trigger server (*tls.Config).GetCertificate to reload -// TLS assets. However, a cert whose SAN field does not include domain names -// but only IP addresses, has empty (*tls.ClientHelloInfo).ServerName, thus -// it was never able to trigger TLS reload on initial handshake; first -// ceritifcate object was being used, never being updated. -// -// Now, (*tls.Config).Certificates is created empty on initial TLS client -// handshake, in order to trigger (*tls.Config).GetCertificate and populate -// rest of the certificates on every new TLS connection, even when client -// SNI is empty (e.g. cert only includes IPs). -func (info *TLSInfo) getLatestTLSConfig() (*tls.Config, error) { +// cafiles returns a list of CA file paths. +func (info *TLSInfo) cafiles() []string { + cs := make([]string, 0) + if info.TrustedCAFile != "" { + cs = append(cs, info.TrustedCAFile) + } + return cs +} + +// ReloadableServerConfig generates a tls.Config object for use by an HTTP server. +func (info *TLSInfo) ReloadableServerConfig() (*tls.Config, error) { info.startRefresh() return &tls.Config{ GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { cfg, ok := info.tlsConfig.Load().(*tls.Config) if !ok { - cfg = info.loadTLSConfig() + cfg = info.loadServerTlsConfig() } return cfg, nil }, @@ -578,21 +588,6 @@ func (info *TLSInfo) getLatestTLSConfig() (*tls.Config, error) { }, nil } -// cafiles returns a list of CA file paths. -func (info *TLSInfo) cafiles() []string { - cs := make([]string, 0) - if info.TrustedCAFile != "" { - cs = append(cs, info.TrustedCAFile) - } - return cs -} - -// ServerConfig generates a tls.Config object for use by an HTTP server. -func (info *TLSInfo) ServerConfig() (*tls.Config, error) { - cfg, err := info.getLatestTLSConfig() - return cfg, err -} - // ClientConfig generates a tls.Config object for use by an HTTP client. func (info *TLSInfo) ClientConfig() (*tls.Config, error) { var cfg *tls.Config diff --git a/client/pkg/transport/listener_test.go b/client/pkg/transport/listener_test.go index 9a4c40cad080..9aee099ee7ba 100644 --- a/client/pkg/transport/listener_test.go +++ b/client/pkg/transport/listener_test.go @@ -476,7 +476,7 @@ func TestTLSInfoParseFuncError(t *testing.T) { tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake")) if _, err = tt.info.serverConfig(); err == nil { - t.Errorf("#%d: expected non-nil error from ServerConfig()", i) + t.Errorf("#%d: expected non-nil error from ReloadableServerConfig()", i) } if _, err = tt.info.ClientConfig(); err == nil { @@ -516,7 +516,7 @@ func TestTLSInfoConfigFuncs(t *testing.T) { sCfg, err := tt.info.serverConfig() if err != nil { - t.Errorf("#%d: expected nil error from ServerConfig(), got non-nil: %v", i, err) + t.Errorf("#%d: expected nil error from ReloadableServerConfig(), got non-nil: %v", i, err) } if tt.wantCAs != (sCfg.ClientCAs != nil) { @@ -723,14 +723,27 @@ func TestRootCAReload(t *testing.T) { cli.Get("https://" + ln.Addr().String()) }() - conn, err := ln.Accept() - if err != nil { - t.Fatalf("unexpected Accept error: %v", err) - } - if _, ok := conn.(*tls.Conn); !ok { - t.Error("failed to accept *tls.Conn") + errChan := make(chan error) + go func() { + conn, err := ln.Accept() + if err != nil { + errChan <- err + } + if _, ok := conn.(*tls.Conn); !ok { + errChan <- errors.New("failed to accept *tls.Conn") + } + conn.Close() + errChan <- nil + }() + + select { + case <-time.After(10 * time.Second): + t.Fatalf("timeout accept") + case err := <-errChan: + if err != nil { + t.Fatal(err) + } } - conn.Close() // regenerate rootCA and sign new certs rootCA, _, privKey = createRootCertificateAuthority(rootCAPath, caBytes, t) @@ -758,12 +771,25 @@ func TestRootCAReload(t *testing.T) { go func() { cli.Get("https://" + ln.Addr().String()) }() - conn, err = ln.Accept() - if err != nil { - t.Fatalf("unexpected Accept error: %v", err) - } - if _, ok := conn.(*tls.Conn); !ok { - t.Error("failed to accept *tls.Conn") + + go func() { + conn, err := ln.Accept() + if err != nil { + errChan <- err + } + if _, ok := conn.(*tls.Conn); !ok { + errChan <- errors.New("failed to accept *tls.Conn") + } + conn.Close() + errChan <- nil + }() + + select { + case <-time.After(10 * time.Second): + t.Fatalf("timeout accept") + case err := <-errChan: + if err != nil { + t.Fatal(err) + } } - conn.Close() } diff --git a/client/pkg/transport/listener_tls.go b/client/pkg/transport/listener_tls.go index 37b17ec275ea..c5acb850f3cb 100644 --- a/client/pkg/transport/listener_tls.go +++ b/client/pkg/transport/listener_tls.go @@ -50,7 +50,7 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.L l.Close() return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String()) } - tlscfg, err := tlsinfo.ServerConfig() + tlscfg, err := tlsinfo.ReloadableServerConfig() if err != nil { return nil, err } diff --git a/go.mod b/go.mod index dbdef2f5a058..6d3234b46dc0 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/cespare/xxhash/v2 v2.1.1 // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/creack/pty v1.1.11 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect diff --git a/go.sum b/go.sum index 78811a2aba06..067e5e051d32 100644 --- a/go.sum +++ b/go.sum @@ -65,6 +65,7 @@ github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfc github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/creack/pty v1.1.11 h1:07n33Z8lZxZ2qwegKbObQohDhXDQxiMMz1NOUGYlesw= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index 9381225f79fd..27ec2a57cc1c 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -651,6 +651,7 @@ func (s *server) Close() (err error) { } s.lg.Sync() s.listenerMu.Unlock() + s.tlsInfo.Close() }) s.closeWg.Wait() return err diff --git a/pkg/proxy/server_test.go b/pkg/proxy/server_test.go index f1f186ba9746..06187528fe6f 100644 --- a/pkg/proxy/server_test.go +++ b/pkg/proxy/server_test.go @@ -496,7 +496,7 @@ func testServerHTTP(t *testing.T, secure, delayTx bool) { tlsInfo := createTLSInfo(lg, secure) var tlsConfig *tls.Config if secure { - _, err := tlsInfo.ServerConfig() + _, err := tlsInfo.ReloadableServerConfig() if err != nil { t.Fatal(err) } diff --git a/server/embed/etcd.go b/server/embed/etcd.go index a4ee13e682d0..a5738b3ff433 100644 --- a/server/embed/etcd.go +++ b/server/embed/etcd.go @@ -532,7 +532,7 @@ func (e *Etcd) servePeers() (err error) { ph := etcdhttp.NewPeerHandler(e.GetLogger(), e.Server) var peerTLScfg *tls.Config if !e.cfg.PeerTLSInfo.Empty() { - if peerTLScfg, err = e.cfg.PeerTLSInfo.ServerConfig(); err != nil { + if peerTLScfg, err = e.cfg.PeerTLSInfo.ReloadableServerConfig(); err != nil { return err } } diff --git a/server/embed/serve.go b/server/embed/serve.go index 455abd43d36f..40c4d219ae24 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -144,7 +144,7 @@ func (sctx *serveCtx) serve( } if sctx.secure { - tlscfg, tlsErr := tlsinfo.ServerConfig() + tlscfg, tlsErr := tlsinfo.ReloadableServerConfig() if tlsErr != nil { return tlsErr } diff --git a/server/etcdmain/grpc_proxy.go b/server/etcdmain/grpc_proxy.go index bf69f1b7ff22..7bfe68c1b83b 100644 --- a/server/etcdmain/grpc_proxy.go +++ b/server/etcdmain/grpc_proxy.go @@ -476,7 +476,7 @@ func mustHTTPListener(lg *zap.Logger, m cmux.CMux, tlsinfo *transport.TLSInfo, c return srvhttp, m.Match(cmux.HTTP1()) } - srvTLS, err := tlsinfo.ServerConfig() + srvTLS, err := tlsinfo.ReloadableServerConfig() if err != nil { lg.Fatal("failed to set up TLS", zap.Error(err)) } diff --git a/tests/framework/integration/cluster.go b/tests/framework/integration/cluster.go index f5724b2dfa44..8bfe6703c9e9 100644 --- a/tests/framework/integration/cluster.go +++ b/tests/framework/integration/cluster.go @@ -918,7 +918,7 @@ func (m *Member) Clone(t testutil.TB) *Member { return mm } -// Launch starts a member based on ServerConfig, PeerListeners +// Launch starts a member based on ReloadableServerConfig, PeerListeners // and ClientListeners. func (m *Member) Launch() error { m.Logger.Info( @@ -937,7 +937,7 @@ func (m *Member) Launch() error { var peerTLScfg *tls.Config if m.PeerTLSInfo != nil && !m.PeerTLSInfo.Empty() { - if peerTLScfg, err = m.PeerTLSInfo.ServerConfig(); err != nil { + if peerTLScfg, err = m.PeerTLSInfo.ReloadableServerConfig(); err != nil { return err } } @@ -947,7 +947,7 @@ func (m *Member) Launch() error { tlscfg *tls.Config ) if m.ClientTLSInfo != nil && !m.ClientTLSInfo.Empty() { - tlscfg, err = m.ClientTLSInfo.ServerConfig() + tlscfg, err = m.ClientTLSInfo.ReloadableServerConfig() if err != nil { return err } @@ -1030,7 +1030,7 @@ func (m *Member) Launch() error { hs.Start() } else { info := m.ClientTLSInfo - hs.TLS, err = info.ServerConfig() + hs.TLS, err = info.ReloadableServerConfig() if err != nil { return err }