diff --git a/client/v3/concurrency/session.go b/client/v3/concurrency/session.go
index 2275e96c972..d1367e6dd12 100644
--- a/client/v3/concurrency/session.go
+++ b/client/v3/concurrency/session.go
@@ -40,14 +40,26 @@ type Session struct {
 // NewSession gets the leased session for a client.
 func NewSession(client *v3.Client, opts ...SessionOption) (*Session, error) {
 	lg := client.GetLogger()
-	ops := &sessionOptions{ttl: defaultSessionTTL, ctx: client.Ctx()}
+	ops := &sessionOptions{
+		ttl: defaultSessionTTL,
+		ctx: client.Ctx(),
+	}
 	for _, opt := range opts {
 		opt(ops, lg)
 	}
 
 	id := ops.leaseID
 	if id == v3.NoLease {
-		resp, err := client.Grant(ops.ctx, int64(ops.ttl))
+		// Only the lease grant is bounded by the creation timeout. The derived
+		// context is scoped to this branch and always cancelled on return, so
+		// its timer is released on both the success and the error path.
+		grantCtx := ops.ctx
+		if ops.creationTimeout > 0 {
+			var cancel context.CancelFunc
+			grantCtx, cancel = context.WithTimeout(ops.ctx, ops.creationTimeout)
+			defer cancel()
+		}
+		resp, err := client.Grant(grantCtx, int64(ops.ttl))
 		if err != nil {
 			return nil, err
 		}
@@ -115,9 +127,10 @@ func (s *Session) Close() error {
 }
 
 type sessionOptions struct {
-	ttl     int
-	leaseID v3.LeaseID
-	ctx     context.Context
+	ttl             int
+	leaseID         v3.LeaseID
+	ctx             context.Context
+	creationTimeout time.Duration
 }
 
 // SessionOption configures Session.
@@ -135,6 +148,21 @@ func WithTTL(ttl int) SessionOption {
 	}
 }
 
+// WithCreationTimeout configures the timeout for creating a new session,
+// i.e. for granting its lease. By default no timeout is used and creating a
+// new session may block until the etcd server is reachable.
+// A timeout <= 0 is ignored: a warning is logged and the previously
+// configured creation timeout, if any, is preserved.
+func WithCreationTimeout(timeout time.Duration) SessionOption {
+	return func(so *sessionOptions, lg *zap.Logger) {
+		if timeout <= 0 {
+			lg.Warn("WithCreationTimeout(): timeout should be > 0, preserving current timeout", zap.Duration("current-session-timeout", so.creationTimeout))
+			return
+		}
+		so.creationTimeout = timeout
+	}
+}
+
 // WithLease specifies the existing leaseID to be used for the session.
 // This is useful in process restart scenario, for example, to reclaim
 // leadership from an election prior to restart.
diff --git a/tests/framework/integration/testing.go b/tests/framework/integration/testing.go
index a211973ddc7..8c1ec1c956b 100644
--- a/tests/framework/integration/testing.go
+++ b/tests/framework/integration/testing.go
@@ -20,6 +20,7 @@ import (
 
 	grpclogsettable "github.com/grpc-ecosystem/go-grpc-middleware/logging/settable"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
 	"go.uber.org/zap/zapcore"
 	"go.uber.org/zap/zapgrpc"
 	"go.uber.org/zap/zaptest"
@@ -28,6 +29,7 @@ import (
 	"go.etcd.io/etcd/client/pkg/v3/verify"
 	clientv3 "go.etcd.io/etcd/client/v3"
 	"go.etcd.io/etcd/server/v3/embed"
+	"go.etcd.io/etcd/tests/v3/framework/testutils"
 	gofail "go.etcd.io/gofail/runtime"
 )
 
@@ -130,6 +132,29 @@ func BeforeTest(t testutil.TB, opts ...TestOption) {
 	os.Chdir(t.TempDir())
 }
 
+// ClientGRPCLoggerObserver installs a gRPC logger that writes to the test log
+// at info level and tees every entry into a LogObserver. The observer is
+// returned so that tests can assert on the gRPC log output.
+func ClientGRPCLoggerObserver(t testutil.TB) *testutils.LogObserver {
+	level := zapcore.InfoLevel
+
+	obCore, logOb := testutils.NewLogObserver(level)
+
+	options := zaptest.WrapOptions(
+		zap.WrapCore(func(oldCore zapcore.Core) zapcore.Core {
+			return zapcore.NewTee(oldCore, obCore)
+		}),
+	)
+
+	grpcLogger.Set(
+		zapgrpc.NewLogger(
+			zaptest.NewLogger(t, zaptest.Level(level), options).
+				Named("grpc-observer"),
+		),
+	)
+	return logOb
+}
+
 func assertInTestContext(t testutil.TB) {
 	if !insideTestContext {
 		t.Errorf("the function can be called only in the test context. Was integration.BeforeTest() called ?")
diff --git a/tests/integration/cluster_test.go b/tests/integration/cluster_test.go
index 29f8ae8dd5d..ca89ddbce04 100644
--- a/tests/integration/cluster_test.go
+++ b/tests/integration/cluster_test.go
@@ -25,10 +25,14 @@ import (
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/assert"
+
 	clientv3 "go.etcd.io/etcd/client/v3"
+	"go.etcd.io/etcd/client/v3/concurrency"
 	"go.etcd.io/etcd/server/v3/etcdserver"
 	"go.etcd.io/etcd/tests/v3/framework/config"
 	"go.etcd.io/etcd/tests/v3/framework/integration"
+	clientv3test "go.etcd.io/etcd/tests/v3/integration/clientv3"
 )
 
 func init() {
@@ -518,3 +522,100 @@ func TestSpeedyTerminate(t *testing.T) {
 	case <-donec:
 	}
 }
+
+// TestCreationTimeout checks that the option WithCreationTimeout
+// sets a timeout for the creation of new sessions in case the cluster
+// shuts down.
+func TestCreationTimeout(t *testing.T) {
+	integration.BeforeTest(t)
+
+	// create new cluster
+	clus := integration.NewCluster(t, &integration.ClusterConfig{Size: 1})
+	defer clus.Terminate(t)
+
+	// create new client
+	cli, err := integration.NewClient(t, clientv3.Config{Endpoints: clus.Client(0).Endpoints(), Username: "user1", Password: "123"})
+	if err != nil {
+		// the deferred clus.Terminate still runs when t.Fatal exits the test
+		t.Fatal(err)
+	}
+	defer cli.Close()
+
+	// ensure the connection is established.
+	clientv3test.MustWaitPinReady(t, cli)
+
+	// terminating the cluster
+	clus.Terminate(t)
+
+	// override the grpc logger
+	logOb := integration.ClientGRPCLoggerObserver(t)
+
+	_, err = concurrency.NewSession(cli, concurrency.WithCreationTimeout(3000*time.Millisecond))
+	assert.ErrorIs(t, err, context.DeadlineExceeded)
+
+	// bound the wait so that a missing log entry fails the test instead of
+	// blocking it until the global test timeout
+	expectCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	_, err = logOb.Expect(expectCtx, "Subchannel Connectivity change to TRANSIENT_FAILURE", 3)
+	assert.NoError(t, err)
+}
+
+// TestTimeoutDoesntAffectSubsequentConnections checks that the option WithCreationTimeout
+// is only used when Session is created
+func TestTimeoutDoesntAffectSubsequentConnections(t *testing.T) {
+	integration.BeforeTest(t)
+
+	// create new cluster
+	clus := integration.NewCluster(t, &integration.ClusterConfig{Size: 1})
+	defer clus.Terminate(t)
+	clus.Members[0].KeepDataDirTerminate = true
+
+	// create new client
+	cli, err := integration.NewClient(t, clientv3.Config{Endpoints: clus.Client(0).Endpoints(), Username: "user1", Password: "123"})
+	if err != nil {
+		// the deferred clus.Terminate still runs when t.Fatal exits the test
+		t.Fatal(err)
+	}
+	defer cli.Close()
+
+	// ensure the connection is established.
+	clientv3test.MustWaitPinReady(t, cli)
+
+	s, err := concurrency.NewSession(cli, concurrency.WithCreationTimeout(1*time.Second))
+	if err != nil {
+		t.Fatalf("NewSession failed: %v", err)
+	}
+
+	// terminating the cluster
+	clus.Members[0].Terminate(t)
+
+	// Buffered and never closed: the goroutine below can always deliver its
+	// result and exit, even when the test has already failed and returned,
+	// and it can never send on a closed channel.
+	errorc := make(chan error, 1)
+	go func() {
+		_, err := cli.Put(s.Ctx(), "sample_key", "sample_value", clientv3.WithLease(s.Lease()))
+		errorc <- err
+	}()
+
+	select {
+	case err := <-errorc:
+		t.Fatalf("Operation put should be blocked forever when the server is unreachable: %v", err)
+	// if Put operation is blocked beyond the timeout specified using WithCreationTimeout,
+	// that timeout is not used by the Put operation
+	case <-time.After(2 * time.Second):
+	}
+
+	// restarting and ensuring that the Put operation will eventually succeed
+	clus.Members[0].Restart(t)
+	clus.Members[0].WaitOK(t)
+	select {
+	case err := <-errorc:
+		if err != nil {
+			t.Errorf("Put failed: %v", err)
+		}
+	case <-time.After(2 * time.Second):
+		t.Fatal("Put function hung even after restarting cluster")
+	}
+}