diff --git a/client/v3/concurrency/session.go b/client/v3/concurrency/session.go index 2275e96c972c..d1367e6dd12f 100644 --- a/client/v3/concurrency/session.go +++ b/client/v3/concurrency/session.go @@ -40,14 +40,23 @@ type Session struct { // NewSession gets the leased session for a client. func NewSession(client *v3.Client, opts ...SessionOption) (*Session, error) { lg := client.GetLogger() - ops := &sessionOptions{ttl: defaultSessionTTL, ctx: client.Ctx()} + ops := &sessionOptions{ + ttl: defaultSessionTTL, + ctx: client.Ctx(), + } for _, opt := range opts { opt(ops, lg) } + var cancel context.CancelFunc + sessionCreationCtx := ops.ctx + if ops.creationTimeout > 0 { + sessionCreationCtx, cancel = context.WithTimeout(ops.ctx, ops.creationTimeout) + } + id := ops.leaseID if id == v3.NoLease { - resp, err := client.Grant(ops.ctx, int64(ops.ttl)) + resp, err := client.Grant(sessionCreationCtx, int64(ops.ttl)) if err != nil { return nil, err } @@ -115,9 +124,10 @@ func (s *Session) Close() error { } type sessionOptions struct { - ttl int - leaseID v3.LeaseID - ctx context.Context + ttl int + leaseID v3.LeaseID + ctx context.Context + creationTimeout time.Duration } // SessionOption configures Session. @@ -135,6 +145,19 @@ func WithTTL(ttl int) SessionOption { } } +// WithCreationTimeout configures the timeout for creating a new session. +// If timeout is <= 0, no timeout will be used, and the creating new session +// will be blocked forever until the etcd server is reachable. +func WithCreationTimeout(timeout time.Duration) SessionOption { + return func(so *sessionOptions, lg *zap.Logger) { + if timeout > 0 { + so.creationTimeout = timeout + } else { + lg.Warn("WithCreationTimeout(): timeout should be > 0, preserving current timeout", zap.Int64("current-session-timeout", int64(so.creationTimeout))) + } + } +} + // WithLease specifies the existing leaseID to be used for the session. // This is useful in process restart scenario, for example, to reclaim // leadership from an election prior to restart. diff --git a/tests/framework/integration/testing.go b/tests/framework/integration/testing.go index b0580dcd670e..4a839668dd55 100644 --- a/tests/framework/integration/testing.go +++ b/tests/framework/integration/testing.go @@ -20,6 +20,7 @@ import ( grpc_logsettable "github.com/grpc-ecosystem/go-grpc-middleware/logging/settable" "github.com/stretchr/testify/require" + "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zapgrpc" "go.uber.org/zap/zaptest" @@ -30,6 +31,7 @@ import ( "go.etcd.io/etcd/client/pkg/v3/verify" clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/server/v3/embed" + "go.etcd.io/etcd/tests/v3/framework/testutils" ) var grpc_logger grpc_logsettable.SettableLoggerV2 @@ -131,6 +133,26 @@ func BeforeTest(t testutil.TB, opts ...TestOption) { os.Chdir(t.TempDir()) } +func ClientGRPCLoggerObserver(t testutil.TB) *testutils.LogObserver { + level := zapcore.InfoLevel + + obCore, logOb := testutils.NewLogObserver(level) + + options := zaptest.WrapOptions( + zap.WrapCore(func(oldCore zapcore.Core) zapcore.Core { + return zapcore.NewTee(oldCore, obCore) + }), + ) + + grpc_logger.Set( + zapgrpc.NewLogger( + zaptest.NewLogger(t, zaptest.Level(level), options). + Named("grpc-observer"), + ), + ) + return logOb +} + func assertInTestContext(t testutil.TB) { if !insideTestContext { t.Errorf("the function can be called only in the test context. Was integration.BeforeTest() called ?") diff --git a/tests/integration/clientv3/concurrency/session_test.go b/tests/integration/clientv3/concurrency/session_test.go index b17991179751..562d4db6bdb5 100644 --- a/tests/integration/clientv3/concurrency/session_test.go +++ b/tests/integration/clientv3/concurrency/session_test.go @@ -23,6 +23,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/client/v3/concurrency" integration2 "go.etcd.io/etcd/tests/v3/framework/integration" + clientv3test "go.etcd.io/etcd/tests/v3/integration/clientv3" ) func TestSessionOptions(t *testing.T) { @@ -111,3 +112,74 @@ func TestSessionCtx(t *testing.T) { } assert.Equal(t, childCtx.Err(), context.Canceled) } + +// TestCreationTimeout checks that the option WithCreationTimeout +// sets a timeout for the creation of new sessions +func TestCreationTimeout(t *testing.T) { + integration2.BeforeTest(t) + + // create new cluster + clus := integration2.NewCluster(t, &integration2.ClusterConfig{Size: 1}) + + // create new client + cli, err := integration2.NewClient(t, clientv3.Config{Endpoints: []string{clus.Members[0].GRPCURL()}}) + if err != nil { + clus.Terminate(t) + t.Fatal(err) + } + defer cli.Close() + + // ensure the connection is established. + clientv3test.MustWaitPinReady(t, cli) + + // terminating the cluster + clus.Terminate(t) + + // override the grpc logger + logOb := integration2.ClientGRPCLoggerObserver(t) + + _, err = concurrency.NewSession(cli, concurrency.WithCreationTimeout(3000*time.Millisecond)) + assert.Equal(t, err, context.DeadlineExceeded) + + _, err = logOb.Expect(context.Background(), "Subchannel Connectivity change to TRANSIENT_FAILURE", 3) + assert.Nil(t, err) +} + +// TestCreationTimeout checks that the option WithCreationTimeout +// sets a timeout for the creation of new sessions +func TestTimeoutDoesntAffectSubsequentConnections(t *testing.T) { + integration2.BeforeTest(t) + + // create new cluster + clus := integration2.NewCluster(t, &integration2.ClusterConfig{Size: 1}) + + // create new client + cli, err := integration2.NewClient(t, clientv3.Config{Endpoints: []string{clus.Members[0].GRPCURL()}}) + if err != nil { + clus.Terminate(t) + t.Fatal(err) + } + defer cli.Close() + + // ensure the connection is established. + clientv3test.MustWaitPinReady(t, cli) + + s, err := concurrency.NewSession(cli, concurrency.WithCreationTimeout(1*time.Second)) + + // terminating the cluster + clus.Terminate(t) + + donec := make(chan struct{}) + go func() { + defer close(donec) + _, _ = cli.Put(s.Ctx(), "sample_key", "sample_value", clientv3.WithLease(s.Lease())) + }() + + select { + case <-donec: + t.Fatal("operation timed out using WithCreationTimeout") + // if Put operation is blocked beyond the timeout specified using WithCreationTimeout, + // that timeout is not used by the Put operation + case <-time.After(2 * time.Second): + } +}