Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
1. add timeout to tests
2. Rename ServerConfig -> ReloadableServerConfig
3. Move comment to right place
4. Close TlsInfo in etcd server close
5. squash commits
  • Loading branch information
yishuT committed Nov 8, 2021
1 parent 5571b87 commit df8f6b6
Show file tree
Hide file tree
Showing 12 changed files with 90 additions and 66 deletions.
2 changes: 1 addition & 1 deletion client/pkg/transport/keepalive_listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestNewKeepAliveListener(t *testing.T) {
defer del()
tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile}
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
tlscfg, err := tlsInfo.ServerConfig()
tlscfg, err := tlsInfo.ReloadableServerConfig()
if err != nil {
t.Fatalf("unexpected serverConfig error: %v", err)
}
Expand Down
75 changes: 35 additions & 40 deletions client/pkg/transport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertVali
func (info *TLSInfo) startRefresh() {
info.refreshOnce.Do(
func() {
info.loadTLSConfig()
info.loadServerTlsConfig()
if info.RefreshDuration > 0 {
info.refreshDone = make(chan struct{})
go info.tlsConfigRefreshLoop()
Expand All @@ -367,7 +367,7 @@ func (info *TLSInfo) startRefresh() {
)
}

func (info *TLSInfo) loadTLSConfig() *tls.Config {
func (info *TLSInfo) loadServerTlsConfig() *tls.Config {
if info.Logger != nil {
info.Logger.Info("tls config reload from files")
}
Expand All @@ -389,13 +389,33 @@ func (info *TLSInfo) tlsConfigRefreshLoop() {
for {
select {
case <-ticker.C:
info.loadTLSConfig()
info.loadServerTlsConfig()
case <-info.refreshDone:
return
}
}
}

// baseConfig is called on initial TLS handshake start.
//
// Previously,
// 1. Server has non-empty (*tls.Config).Certificates on client hello
// 2. Server calls (*tls.Config).GetCertificate iff:
// - Server's (*tls.Config).Certificates is not empty, or
// - Client supplies SNI; non-empty (*tls.ClientHelloInfo).ServerName
//
// When (*tls.Config).Certificates is always populated on initial handshake,
// client is expected to provide a valid matching SNI to pass the TLS
// verification, thus trigger server (*tls.Config).GetCertificate to reload
// TLS assets. However, a cert whose SAN field does not include domain names
// but only IP addresses, has empty (*tls.ClientHelloInfo).ServerName, thus
// it was never able to trigger TLS reload on initial handshake; first
// ceritifcate object was being used, never being updated.
//
// Now, (*tls.Config).Certificates is created empty on initial TLS client
// handshake, in order to trigger (*tls.Config).GetCertificate and populate
// rest of the certificates on every new TLS connection, even when client
// SNI is empty (e.g. cert only includes IPs).
func (info *TLSInfo) baseConfig() (*tls.Config, error) {
if info.KeyFile == "" || info.CertFile == "" {
return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
Expand Down Expand Up @@ -516,33 +536,23 @@ func (info *TLSInfo) serverConfig() (*tls.Config, error) {
return cfg, nil
}

// baseConfig is called on initial TLS handshake start.
//
// Previously,
// 1. Server has non-empty (*tls.Config).Certificates on client hello
// 2. Server calls (*tls.Config).GetCertificate iff:
// - Server's (*tls.Config).Certificates is not empty, or
// - Client supplies SNI; non-empty (*tls.ClientHelloInfo).ServerName
//
// When (*tls.Config).Certificates is always populated on initial handshake,
// client is expected to provide a valid matching SNI to pass the TLS
// verification, thus trigger server (*tls.Config).GetCertificate to reload
// TLS assets. However, a cert whose SAN field does not include domain names
// but only IP addresses, has empty (*tls.ClientHelloInfo).ServerName, thus
// it was never able to trigger TLS reload on initial handshake; first
// ceritifcate object was being used, never being updated.
//
// Now, (*tls.Config).Certificates is created empty on initial TLS client
// handshake, in order to trigger (*tls.Config).GetCertificate and populate
// rest of the certificates on every new TLS connection, even when client
// SNI is empty (e.g. cert only includes IPs).
func (info *TLSInfo) getLatestTLSConfig() (*tls.Config, error) {
// cafiles returns a list of CA file paths.
func (info *TLSInfo) cafiles() []string {
cs := make([]string, 0)
if info.TrustedCAFile != "" {
cs = append(cs, info.TrustedCAFile)
}
return cs
}

// ReloadableServerConfig generates a tls.Config object for use by an HTTP server.
func (info *TLSInfo) ReloadableServerConfig() (*tls.Config, error) {
info.startRefresh()
return &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
cfg, ok := info.tlsConfig.Load().(*tls.Config)
if !ok {
cfg = info.loadTLSConfig()
cfg = info.loadServerTlsConfig()
}
return cfg, nil
},
Expand Down Expand Up @@ -578,21 +588,6 @@ func (info *TLSInfo) getLatestTLSConfig() (*tls.Config, error) {
}, nil
}

// cafiles returns a list of CA file paths.
func (info *TLSInfo) cafiles() []string {
cs := make([]string, 0)
if info.TrustedCAFile != "" {
cs = append(cs, info.TrustedCAFile)
}
return cs
}

// ServerConfig generates a tls.Config object for use by an HTTP server.
func (info *TLSInfo) ServerConfig() (*tls.Config, error) {
cfg, err := info.getLatestTLSConfig()
return cfg, err
}

// ClientConfig generates a tls.Config object for use by an HTTP client.
func (info *TLSInfo) ClientConfig() (*tls.Config, error) {
var cfg *tls.Config
Expand Down
58 changes: 42 additions & 16 deletions client/pkg/transport/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ func TestTLSInfoParseFuncError(t *testing.T) {
tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake"))

if _, err = tt.info.serverConfig(); err == nil {
t.Errorf("#%d: expected non-nil error from ServerConfig()", i)
t.Errorf("#%d: expected non-nil error from ReloadableServerConfig()", i)
}

if _, err = tt.info.ClientConfig(); err == nil {
Expand Down Expand Up @@ -516,7 +516,7 @@ func TestTLSInfoConfigFuncs(t *testing.T) {

sCfg, err := tt.info.serverConfig()
if err != nil {
t.Errorf("#%d: expected nil error from ServerConfig(), got non-nil: %v", i, err)
t.Errorf("#%d: expected nil error from ReloadableServerConfig(), got non-nil: %v", i, err)
}

if tt.wantCAs != (sCfg.ClientCAs != nil) {
Expand Down Expand Up @@ -723,14 +723,27 @@ func TestRootCAReload(t *testing.T) {
cli.Get("https://" + ln.Addr().String())
}()

conn, err := ln.Accept()
if err != nil {
t.Fatalf("unexpected Accept error: %v", err)
}
if _, ok := conn.(*tls.Conn); !ok {
t.Error("failed to accept *tls.Conn")
errChan := make(chan error)
go func() {
conn, err := ln.Accept()
if err != nil {
errChan <- err
}
if _, ok := conn.(*tls.Conn); !ok {
errChan <- errors.New("failed to accept *tls.Conn")
}
conn.Close()
errChan <- nil
}()

select {
case <-time.After(10 * time.Second):
t.Fatalf("timeout accept")
case err := <-errChan:
if err != nil {
t.Fatal(err)
}
}
conn.Close()

// regenerate rootCA and sign new certs
rootCA, _, privKey = createRootCertificateAuthority(rootCAPath, caBytes, t)
Expand Down Expand Up @@ -758,12 +771,25 @@ func TestRootCAReload(t *testing.T) {
go func() {
cli.Get("https://" + ln.Addr().String())
}()
conn, err = ln.Accept()
if err != nil {
t.Fatalf("unexpected Accept error: %v", err)
}
if _, ok := conn.(*tls.Conn); !ok {
t.Error("failed to accept *tls.Conn")

go func() {
conn, err := ln.Accept()
if err != nil {
errChan <- err
}
if _, ok := conn.(*tls.Conn); !ok {
errChan <- errors.New("failed to accept *tls.Conn")
}
conn.Close()
errChan <- nil
}()

select {
case <-time.After(10 * time.Second):
t.Fatalf("timeout accept")
case err := <-errChan:
if err != nil {
t.Fatal(err)
}
}
conn.Close()
}
2 changes: 1 addition & 1 deletion client/pkg/transport/listener_tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.L
l.Close()
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
}
tlscfg, err := tlsinfo.ServerConfig()
tlscfg, err := tlsinfo.ReloadableServerConfig()
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ require (
github.com/cespare/xxhash/v2 v2.1.1 // indirect
github.com/coreos/go-semver v0.3.0 // indirect
github.com/coreos/go-systemd/v22 v22.3.2 // indirect
github.com/creack/pty v1.1.11 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfc
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.11 h1:07n33Z8lZxZ2qwegKbObQohDhXDQxiMMz1NOUGYlesw=
github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
Expand Down
1 change: 1 addition & 0 deletions pkg/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ func (s *server) Close() (err error) {
}
s.lg.Sync()
s.listenerMu.Unlock()
s.tlsInfo.Close()
})
s.closeWg.Wait()
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ func testServerHTTP(t *testing.T, secure, delayTx bool) {
tlsInfo := createTLSInfo(lg, secure)
var tlsConfig *tls.Config
if secure {
_, err := tlsInfo.ServerConfig()
_, err := tlsInfo.ReloadableServerConfig()
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion server/embed/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ func (e *Etcd) servePeers() (err error) {
ph := etcdhttp.NewPeerHandler(e.GetLogger(), e.Server)
var peerTLScfg *tls.Config
if !e.cfg.PeerTLSInfo.Empty() {
if peerTLScfg, err = e.cfg.PeerTLSInfo.ServerConfig(); err != nil {
if peerTLScfg, err = e.cfg.PeerTLSInfo.ReloadableServerConfig(); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion server/embed/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (sctx *serveCtx) serve(
}

if sctx.secure {
tlscfg, tlsErr := tlsinfo.ServerConfig()
tlscfg, tlsErr := tlsinfo.ReloadableServerConfig()
if tlsErr != nil {
return tlsErr
}
Expand Down
2 changes: 1 addition & 1 deletion server/etcdmain/grpc_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ func mustHTTPListener(lg *zap.Logger, m cmux.CMux, tlsinfo *transport.TLSInfo, c
return srvhttp, m.Match(cmux.HTTP1())
}

srvTLS, err := tlsinfo.ServerConfig()
srvTLS, err := tlsinfo.ReloadableServerConfig()
if err != nil {
lg.Fatal("failed to set up TLS", zap.Error(err))
}
Expand Down
8 changes: 4 additions & 4 deletions tests/framework/integration/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ func (m *Member) Clone(t testutil.TB) *Member {
return mm
}

// Launch starts a member based on ServerConfig, PeerListeners
// Launch starts a member based on ReloadableServerConfig, PeerListeners
// and ClientListeners.
func (m *Member) Launch() error {
m.Logger.Info(
Expand All @@ -937,7 +937,7 @@ func (m *Member) Launch() error {

var peerTLScfg *tls.Config
if m.PeerTLSInfo != nil && !m.PeerTLSInfo.Empty() {
if peerTLScfg, err = m.PeerTLSInfo.ServerConfig(); err != nil {
if peerTLScfg, err = m.PeerTLSInfo.ReloadableServerConfig(); err != nil {
return err
}
}
Expand All @@ -947,7 +947,7 @@ func (m *Member) Launch() error {
tlscfg *tls.Config
)
if m.ClientTLSInfo != nil && !m.ClientTLSInfo.Empty() {
tlscfg, err = m.ClientTLSInfo.ServerConfig()
tlscfg, err = m.ClientTLSInfo.ReloadableServerConfig()
if err != nil {
return err
}
Expand Down Expand Up @@ -1030,7 +1030,7 @@ func (m *Member) Launch() error {
hs.Start()
} else {
info := m.ClientTLSInfo
hs.TLS, err = info.ServerConfig()
hs.TLS, err = info.ReloadableServerConfig()
if err != nil {
return err
}
Expand Down

0 comments on commit df8f6b6

Please sign in to comment.